Commit 77dfb147 authored by Markus Scheidgen

Fixes around archive querying.

parent 661cb128
Pipeline #72367 failed with stages in 17 minutes and 51 seconds
@@ -30,7 +30,7 @@ import urllib.parse
import metainfo
from nomad.files import UploadFiles, Restricted
from nomad.archive import query_archive
from nomad.archive import query_archive, ArchiveQueryError
from nomad import search, config
from nomad.app import common
@@ -265,7 +265,7 @@ class ArchiveQueryResource(Resource):
search_request.owner('all')
apply_search_parameters(search_request, query)
search_request.include('calc_id', 'upload_id', 'with_embargo', 'parser_name')
search_request.include('calc_id', 'upload_id', 'with_embargo', 'published', 'parser_name')
try:
if scroll:
@@ -288,11 +288,18 @@ class ArchiveQueryResource(Resource):
calcs = results['results']
upload_files = None
current_upload_id = None
check_restricted = g.user is not None
for entry in calcs:
with_embargo = entry['with_embargo']
if (not entry['published'] or with_embargo) and not check_restricted:
continue
upload_id = entry['upload_id']
calc_id = entry['calc_id']
if upload_files is None or current_upload_id != upload_id:
if upload_files is not None:
check_restricted = g.user is not None # the user might have access to restricted data in the next upload
upload_files.close()
upload_files = UploadFiles.get(upload_id, create_authorization_predicate(upload_id))
@@ -302,7 +309,7 @@ class ArchiveQueryResource(Resource):
current_upload_id = upload_id
if entry['with_embargo']:
if with_embargo:
access = 'restricted'
else:
access = 'public'
@@ -315,10 +322,11 @@ class ArchiveQueryResource(Resource):
'archive': query_archive(
archive, {calc_id: query_schema})[calc_id]
})
except ArchiveQueryError as e:
abort(400, str(e))
except Restricted:
# optimize and not access restricted for same upload again
pass
check_restricted = False
if upload_files is not None:
upload_files.close()
......
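The per-upload short cut above is easier to see outside the diff. A minimal sketch of the control flow, with hypothetical `entries`, `user`, and `load_archive` standing in for the request context:

```python
class Restricted(Exception):
    ''' Raised when the user may not read an entry (as in nomad.files). '''


def collect_archives(entries, user, load_archive):
    # entries are sorted so that entries of the same upload are adjacent
    results = []
    current_upload_id = None
    check_restricted = user is not None
    for entry in entries:
        # without a user (or after a Restricted hit below), skip anything
        # unpublished or embargoed without even opening the upload files
        if (not entry['published'] or entry['with_embargo']) and not check_restricted:
            continue
        if current_upload_id != entry['upload_id']:
            # new upload: the user might have access here even if the
            # previous upload was restricted, so re-enable the check
            check_restricted = user is not None
            current_upload_id = entry['upload_id']
        try:
            results.append(load_archive(entry))
        except Restricted:
            # everything else in this upload is restricted for this user
            # too; keep skipping cheaply instead of raising per entry
            check_restricted = False
    return results
```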
# Copyright 2018 Markus Scheidgen, Alvin Noe Ladines
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Iterable, Any, Tuple, Dict, BinaryIO, Union, List, cast
from io import BytesIO, BufferedReader
from collections.abc import Mapping, Sequence
@@ -29,6 +43,15 @@ def adjust_uuid_size(uuid):
class ArchiveError(Exception):
''' An error that indicates a broken archive. '''
pass
class ArchiveQueryError(Exception):
'''
An error that indicates that an archive query is either not valid or does not fit
the queried archive.
'''
pass
@@ -511,60 +534,88 @@ def read_archive(file_or_path: str, **kwargs) -> ArchiveReader:
return ArchiveReader(file_or_path, **kwargs)
def query_archive(f_or_archive_reader: Union[str, ArchiveReader, BytesIO], query_dict: dict):
__query_archive_key_pattern = re.compile(r'(\w+)(\[([-?0-9]*)(:([-?0-9]*))?\])?')
def query_archive(f_or_archive_reader: Union[str, ArchiveReader, BytesIO], query_dict: dict, **kwargs):
def _to_son(data):
if isinstance(data, (ArchiveList, List)):
data = [_to_son(item) for item in data]
elif isinstance(data, ArchiveObject):
data = data.to_dict()
return data
def _load_data(query_dict: Dict[str, Any], archive_item: ArchiveObject, main_section: bool = False):
if not isinstance(query_dict, dict):
if isinstance(archive_item, ArchiveObject):
return archive_item.to_dict()
elif isinstance(archive_item, ArchiveList):
return list(archive_item)
else:
return archive_item
return _to_son(archive_item)
res = {}
result = {}
for key, val in query_dict.items():
key = key.strip()
# process array indices
match = re.match(r'(\w+)\[([-?0-9:]+)\]', key)
match = __query_archive_key_pattern.match(key)
index: Tuple[int, int] = None
if match:
archive_key = match.group(1)
index_str = match.group(2)
match = re.match(r'([-?0-9]*):([-?0-9]*)', index_str)
if match:
index = (
0 if match.group(1) == '' else int(match.group(1)),
None if match.group(2) == '' else int(match.group(2)))
key = match.group(1)
if match.group(2) is not None:
first_index, last_index = None, None
group = match.group(3)
first_index = None if group == '' else int(group)
if match.group(4) is not None:
group = match.group(5)
last_index = None if group == '' else int(group)
index = (0 if first_index is None else first_index, last_index)
else:
index = (first_index, first_index + 1) # one item
else:
index = int(index_str) # type: ignore
key = archive_key
index = None
else:
archive_key = key
index = None
raise ArchiveQueryError('invalid key format: %s' % key)
# support for shorter uuids
archive_key = key.split('[')[0]
if main_section:
archive_key = adjust_uuid_size(key)
else:
archive_key = key
try:
archive_child = archive_item[archive_key]
is_list = isinstance(archive_child, (ArchiveList, list))
if index is None and is_list:
index = (0, None)
elif index is not None and not is_list:
raise ArchiveQueryError('cannot use list key on non-list %s' % key)
if index is None:
res[key] = _load_data(val, archive_item[archive_key])
elif isinstance(index, int):
res[key] = _load_data(val, archive_item[archive_key])[index]
pass
else:
res[key] = _load_data(val, archive_item[archive_key])[index[0]: index[1]]
archive_child = archive_child[index[0]: index[1]]
except Exception:
if isinstance(archive_child, (ArchiveList, list)):
result[key] = [_load_data(val, item) for item in archive_child]
else:
result[key] = _load_data(val, archive_child)
except (KeyError, IndexError):
continue
return res
return result
if isinstance(f_or_archive_reader, ArchiveReader):
return _load_data(query_dict, f_or_archive_reader, True)
elif isinstance(f_or_archive_reader, (BytesIO, str)):
with ArchiveReader(f_or_archive_reader) as archive:
with ArchiveReader(f_or_archive_reader, **kwargs) as archive:
return _load_data(query_dict, archive, True)
else:
......
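The accepted key syntax is defined entirely by `__query_archive_key_pattern`. A small illustration of what the groups capture; the pattern is copied from the diff, everything else is illustrative:

```python
import re

# copied from the diff: a name, optionally followed by [first] or [first:last]
pattern = re.compile(r'(\w+)(\[([-?0-9]*)(:([-?0-9]*))?\])?')

for key in ('ss1', 'ss1[0]', 'ss1[1:]', 'ss1[1:-1]'):
    m = pattern.match(key)
    # group(1) is the section name, group(3) the first index ('' if omitted),
    # group(5) the last index; group(4) is only set for slice syntax with ':'
    print(key, '->', m.group(1), repr(m.group(3)), repr(m.group(5)))

# ss1       -> ss1 None None    (no index: the whole list is returned)
# ss1[0]    -> ss1 '0' None     (single index, normalized to the slice (0, 1))
# ss1[1:]   -> ss1 '1' ''       (open slice, normalized to (1, None))
# ss1[1:-1] -> ss1 '1' '-1'
```

Keys that do not match the pattern at all now raise `ArchiveQueryError('invalid key format: ...')` instead of being passed through as plain names.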
# Copyright 2018 Markus Scheidgen
# Copyright 2018 Markus Scheidgen, Alvin Noe Ladines
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......
@@ -82,12 +82,13 @@ class ArchiveQuery(Sequence):
self,
query: dict = None, query_schema: dict = None,
url: str = None, username: str = None, password: str = None,
scroll: bool = False,
scroll: bool = False, per_page: int = 10,
authentication: Union[Dict[str, str], KeycloakAuthenticator] = None, **kwargs):
self.scroll = scroll
self._scroll_id = None
self._page = 1
self._per_page = per_page
self.query: Dict[str, Any] = {
'query': {}
@@ -133,7 +134,8 @@ class ArchiveQuery(Sequence):
scroll_config['scroll_id'] = self._scroll_id
else:
self.query.setdefault('pagination', {})['page'] = self._page
self.query.setdefault('pagination', {}).update(
page=self._page, per_page=self._per_page)
response = requests.post(url, headers=self.authentication, json=self.query)
if response.status_code != 200:
......
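On the client side, the new `per_page` argument simply flows into the pagination block of each non-scroll request. A hedged usage sketch, assuming `ArchiveQuery` is importable from `nomad.client`; the query values are placeholders:

```python
from nomad.client import ArchiveQuery  # assumed import path

query = ArchiveQuery(
    query={'upload_id': ['some_upload_id']},                # placeholder filter
    query_schema={'section_run': {'section_system': '*'}},
    per_page=100)                                           # default is now 10

# each page request now sends
#   {'pagination': {'page': <n>, 'per_page': 100}, ...}
# where before only 'page' was set and the server default applied
```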
# Copyright 2018 Markus Scheidgen, Alvin Noe Ladines
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Dict, Any
import pytest
import msgpack
from io import BytesIO
import os.path
from nomad import utils, config
from nomad.archive import TOCPacker, write_archive, read_archive, ArchiveReader, query_archive
from nomad.archive import TOCPacker, write_archive, read_archive, ArchiveReader, ArchiveQueryError, query_archive
def create_example_uuid(index: int = 0):
@@ -198,35 +212,49 @@ def test_read_archive_multi(example_uuid, example_entry, use_blocked_toc):
assert reader.get(create_example_uuid(i)) is not None
def test_query():
payload = {
'c1': {
's1': {
'ss1': [{'p1': 1.0, 'p2': 'x'}, {'p1': 1.5, 'p2': 'y'}]
},
's2': {'p1': ['a', 'b']}
test_query_example: Dict[Any, Any] = {
'c1': {
's1': {
'ss1': [{'p1': 1.0, 'p2': 'x'}, {'p1': 1.5, 'p2': 'y'}]
},
'c2': {
's1': {'ss1': [{'p1': 2.0}]},
's2': {'p1': ['c', 'd']}
}
's2': [{'p1': ['a', 'b'], 'p2': True}]
},
'c2': {
's1': {
'ss1': [{'p1': 2.0}]
},
's2': [{'p1': ['c', 'd']}]
}
}
@pytest.mark.parametrize('query,ref', [
({'c1': '*'}, {'c1': test_query_example['c1']}),
({'c1': '*', 'c2': {'s1': '*'}}, {'c1': test_query_example['c1'], 'c2': {'s1': test_query_example['c2']['s1']}}),
({'c2': {'s1': {'ss1[0]': '*'}}}, {'c2': {'s1': {'ss1': test_query_example['c2']['s1']['ss1'][0:1]}}}),
({'c1': {'s1': {'ss1[1:]': '*'}}}, {'c1': {'s1': {'ss1': test_query_example['c1']['s1']['ss1'][1:]}}}),
({'c1': {'s1': {'ss1[:2]': '*'}}}, {'c1': {'s1': {'ss1': test_query_example['c1']['s1']['ss1'][:2]}}}),
({'c1': {'s1': {'ss1[0:2]': '*'}}}, {'c1': {'s1': {'ss1': test_query_example['c1']['s1']['ss1'][0:2]}}}),
({'c1': {'s1': {'ss1[-2]': '*'}}}, {'c1': {'s1': {'ss1': test_query_example['c1']['s1']['ss1'][-2:-1]}}}),
({'c1': {'s1': {'ss1[:-1]': '*'}}}, {'c1': {'s1': {'ss1': test_query_example['c1']['s1']['ss1'][:-1]}}}),
({'c1': {'s1': {'ss1[1:-1]': '*'}}}, {'c1': {'s1': {'ss1': test_query_example['c1']['s1']['ss1'][1:-1]}}}),
({'c2': {'s1': {'ss1[-3:-1]': '*'}}}, {'c2': {'s1': {'ss1': test_query_example['c2']['s1']['ss1'][-3:-1]}}}),
({'c1': {'s2[0]': {'p1': '*'}}}, {'c1': {'s2': [{'p1': test_query_example['c1']['s2'][0]['p1']}]}}),
({'c1': {'s3': '*'}}, {'c1': {}}),
({'c1': {'s1[0]': '*'}}, ArchiveQueryError())
])
def test_query(query, ref):
f = BytesIO()
write_archive(f, 2, [(k, v) for k, v in payload.items()], entry_toc_depth=1)
write_archive(f, 2, [(k, v) for k, v in test_query_example.items()], entry_toc_depth=1)
packed_archive = f.getbuffer()
f = BytesIO(packed_archive)
assert query_archive(f, {'c1': '*'}) == {'c1': payload['c1']}
assert query_archive(f, {'c1': '*', 'c2': {'s1': '*'}}) == {'c1': payload['c1'], 'c2': {'s1': payload['c2']['s1']}}
assert query_archive(f, {'c2': {'s1': {'ss1[0]': '*'}}}) == {'c2': {'s1': {'ss1': payload['c2']['s1']['ss1'][0]}}}
assert query_archive(f, {'c1': {'s1': {'ss1[1:]': '*'}}}) == {'c1': {'s1': {'ss1': payload['c1']['s1']['ss1'][1:]}}}
assert query_archive(f, {'c1': {'s1': {'ss1[:2]': '*'}}}) == {'c1': {'s1': {'ss1': payload['c1']['s1']['ss1'][:2]}}}
assert query_archive(f, {'c1': {'s1': {'ss1[0:2]': '*'}}}) == {'c1': {'s1': {'ss1': payload['c1']['s1']['ss1'][0:2]}}}
assert query_archive(f, {'c1': {'s1': {'ss1[-2]': '*'}}}) == {'c1': {'s1': {'ss1': payload['c1']['s1']['ss1'][-2]}}}
assert query_archive(f, {'c1': {'s1': {'ss1[:-1]': '*'}}}) == {'c1': {'s1': {'ss1': payload['c1']['s1']['ss1'][:-1]}}}
assert query_archive(f, {'c1': {'s1': {'ss1[1:-1]': '*'}}}) == {'c1': {'s1': {'ss1': payload['c1']['s1']['ss1'][1:-1]}}}
assert query_archive(f, {'c2': {'s1': {'ss1[-3:-1]': '*'}}}) == {'c2': {'s1': {'ss1': payload['c2']['s1']['ss1'][-3:-1]}}}
if isinstance(ref, Exception):
with pytest.raises(ref.__class__):
query_archive(f, query)
else:
assert query_archive(f, query) == ref
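Besides the move to `pytest.mark.parametrize`, the rewritten assertions pin down a behavioral change: a single index such as `ss1[0]` now yields a one-item list (it is read as the slice `[0:1]`), where the old test expected the bare item. A minimal sketch against `test_query_example`:

```python
from io import BytesIO

from nomad.archive import write_archive, query_archive

f = BytesIO()
write_archive(f, 2, list(test_query_example.items()), entry_toc_depth=1)
f = BytesIO(f.getbuffer())

result = query_archive(f, {'c2': {'s1': {'ss1[0]': '*'}}})
# old expectation: {'c2': {'s1': {'ss1': {'p1': 2.0}}}}  -- the item itself
# new expectation: a one-item list, since ss1[0] is treated as ss1[0:1]
assert result == {'c2': {'s1': {'ss1': [{'p1': 2.0}]}}}
```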
def test_read_springer():
......
@@ -3,6 +3,7 @@ import pytest
from nomad.client import query_archive
from nomad.metainfo import MSection, SubSection
from nomad.datamodel import EntryArchive
from nomad.datamodel.metainfo.public import section_run
from tests.app.test_app import BlueprintClient
@@ -21,16 +22,21 @@ def api(client, monkeypatch):
def assert_results(
results: List[MSection],
sub_sections: List[SubSection] = None,
sub_section_defs: List[SubSection] = None,
total=1):
assert len(results) == total
for result in results:
if sub_sections:
for sub_section in result.m_def.all_sub_sections.values():
if sub_section in sub_sections:
assert len(result.m_get_sub_sections(sub_section)) > 0
else:
assert len(result.m_get_sub_sections(sub_section)) == 0
assert result.m_def == EntryArchive.m_def
if sub_section_defs:
current = result
for sub_section_def in sub_section_defs:
for other_sub_section_def in current.m_def.all_sub_sections.values():
if other_sub_section_def != sub_section_def:
assert len(current.m_get_sub_sections(other_sub_section_def)) == 0
sub_sections = current.m_get_sub_sections(sub_section_def)
assert len(sub_sections) > 0
current = sub_sections[0]
def test_query(api, published_wo_user_metadata):
@@ -41,11 +47,13 @@ def test_query_query(api, published_wo_user_metadata):
assert_results(query_archive(query=dict(upload_id=[published_wo_user_metadata.upload_id])))
def test_query_schema(api, published_wo_user_metadata):
q_schema = {'section_run': {'section_system': '*'}}
assert_results(
query_archive(query_schema=q_schema),
sub_sections=[section_run.section_system])
@pytest.mark.parametrize('q_schema,sub_sections', [
({'section_run': '*'}, [EntryArchive.section_run]),
({'section_run': {'section_system': '*'}}, [EntryArchive.section_run, section_run.section_system]),
({'section_run[0]': {'section_system': '*'}}, [EntryArchive.section_run, section_run.section_system])
])
def test_query_schema(api, published_wo_user_metadata, q_schema, sub_sections):
assert_results(query_archive(query_schema=q_schema), sub_section_defs=sub_sections)
def test_query_scroll(api, published_wo_user_metadata):
......