Commit 309f3da1 authored by Markus Scheidgen

Replaced archive API pagination with a composite aggregation approach.

parent 11df2a5e
Pipeline #75101 failed with stages in 22 minutes and 28 seconds
@@ -3,7 +3,7 @@ from nomad.client import ArchiveQuery
 from nomad.metainfo import units

 # this will not be necessary, once this is the official NOMAD version
-config.client.url = 'https://labdev-nomad.esc.rzg.mpg.de/fairdi/nomad/testing-major/api'
+config.client.url = 'http://labdev-nomad.esc.rzg.mpg.de/fairdi/nomad/testing-major/api'

 query = ArchiveQuery(
     query={
...
@@ -233,19 +233,21 @@ class ArchiveQueryResource(Resource):
         See ``/repo`` endpoint for documentation on the search
         parameters.

+        This endpoint uses pagination (see /repo) or id aggregation to handle large result
+        sets over multiple requests.
+        Use ``aggregation.after`` and ``aggregation.per_page`` to request a
+        certain page with id aggregation.
+
         The actual data are in results and a supplementary python code (curl) to
         execute search is in python (curl).
         '''
         try:
             data_in = request.get_json()
-            scroll = data_in.get('scroll', None)
-            if scroll:
-                scroll_id = scroll.get('scroll_id')
-                scroll = True
+            aggregation = data_in.get('aggregation', None)
             pagination = data_in.get('pagination', {})
             page = pagination.get('page', 1)
-            per_page = pagination.get('per_page', 10 if not scroll else 1000)
+            per_page = pagination.get('per_page', 10)

             query = data_in.get('query', {})
@@ -270,20 +272,19 @@ class ArchiveQueryResource(Resource):
         search_request.owner('all')
         apply_search_parameters(search_request, query)
-        search_request.include('calc_id', 'upload_id', 'with_embargo', 'published', 'parser_name')
+        if not aggregation:
+            search_request.include('calc_id', 'upload_id', 'with_embargo', 'published', 'parser_name')

         try:
-            if scroll:
-                results = search_request.execute_scrolled(
-                    scroll_id=scroll_id, size=per_page, order_by='upload_id')
-                results['scroll']['scroll'] = True
+            if aggregation:
+                results = search_request.execute_aggregated(
+                    after=aggregation.get('after'), per_page=aggregation.get('per_page', 1000),
+                    includes=['with_embargo', 'published', 'parser_name'])
             else:
                 results = search_request.execute_paginated(
                     per_page=per_page, page=page, order_by='upload_id')
-        except search.ScrollIdNotFound:
-            abort(400, 'The given scroll_id does not exist.')
         except KeyError as e:
             abort(400, str(e))
...
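As an aside, a minimal client-side sketch of paging through this endpoint with id aggregation; the base URL, the empty query, and the process function are placeholder assumptions, while the request and response shapes follow the endpoint code above:

    import requests

    base_url = 'https://nomad.example.org/api'  # hypothetical deployment URL

    def process(results):
        pass  # placeholder for your own result handling

    body = {'query': {}, 'aggregation': {'per_page': 1000}}
    while True:
        response = requests.post('%s/archive/query' % base_url, json=body)
        response.raise_for_status()
        data = response.json()
        process(data['results'])
        after = data['aggregation'].get('after')
        if after is None:
            break  # no after key: the result set is exhausted
        body['aggregation']['after'] = after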
@@ -65,10 +65,18 @@ scroll_model = api.model('Scroll', {
     'total': fields.Integer(default=0, description='The total amount of hits for the search.'),
     'scroll_id': fields.String(default=None, allow_null=True, description='The scroll_id that can be used to retrieve the next page.'),
     'size': fields.Integer(default=0, help='The size of the returned scroll page.')})
+''' Model used in responses with scroll. '''
+
+aggregation_model = api.model('Aggregation', {
+    'after': fields.String(description='The after key for the current request.', allow_null=True),
+    'total': fields.Integer(default=0, description='The total amount of hits for the search.'),
+    'per_page': fields.Integer(default=0, help='The size of the requested page.', allow_null=True)})
+''' Model used in responses with id aggregation. '''

 search_model_fields = {
     'pagination': fields.Nested(pagination_model, allow_null=True, skip_none=True),
     'scroll': fields.Nested(scroll_model, allow_null=True, skip_none=True),
+    'aggregation': fields.Nested(aggregation_model, allow_null=True),
     'results': fields.List(fields.Raw(allow_null=True, skip_none=True), description=(
         'A list of search results. Each result is a dict with quantitie names as key and '
         'values as values'), allow_null=True, skip_none=True),
...
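To make the new model concrete, the aggregation object in a response would look roughly like this; the values are illustrative only, but the after key format follows execute_aggregated below:

    aggregation = {
        'after': 'some_upload_id:some_calc_id',  # None once the last page is reached
        'total': 42,
        'per_page': 1000}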
@@ -98,6 +98,10 @@ sub-sections return lists of further objects. Here we navigate the sections ``se
 sub-section ``section_system`` to access the quantity ``energy_total``. This quantity is a
 number with an attached unit (Joule), which can be converted to something else (e.g. Hartree).

+The created query object keeps all results in memory. Keep this in mind when you are
+accessing a large number of query results. You should use :func:`ArchiveQuery.clear`
+to remove unnecessary results.

 The NOMAD Metainfo
 __________________
@@ -234,8 +238,6 @@ class ArchiveQuery(collections.abc.Sequence):
         url: Optional, override the default NOMAD API url.
         username: Optional, allows authenticated access.
         password: Optional, allows authenticated access.
-        scroll: Use the scroll API to iterate through results. This is required when you
-            are accessing many 1000 results. By default, the pagination API is used.
         per_page: Determine how many results are downloaded per page (or scroll window).
             Default is 10.
         max: Optionally determine the maximum amount of downloaded archives. The iteration
@@ -250,12 +252,11 @@ class ArchiveQuery(collections.abc.Sequence):
             self,
             query: dict = None, required: dict = None,
             url: str = None, username: str = None, password: str = None,
-            scroll: bool = False, per_page: int = 10, max: int = None,
+            per_page: int = 10, max: int = None,
             raise_errors: bool = False,
             authentication: Union[Dict[str, str], KeycloakAuthenticator] = None):

-        self.scroll = scroll
-        self._scroll_id = None
+        self._after = None
         self.page = 1
         self.per_page = per_page
         self.max = max
@@ -326,14 +327,9 @@ class ArchiveQuery(collections.abc.Sequence):
         '''
         url = '%s/%s/%s' % (self.url, 'archive', 'query')

-        if self.scroll:
-            scroll_config = self.query.setdefault('scroll', {'scroll': True})
-            if self._scroll_id is not None:
-                scroll_config['scroll_id'] = self._scroll_id
-        else:
-            self.query.setdefault('pagination', {}).update(
-                page=self.page, per_page=self.per_page)
+        aggregation = self.query.setdefault('aggregation', {'per_page': self.per_page})
+        if self._after is not None:
+            aggregation['after'] = self._after

         response = requests.post(url, headers=self.authentication, json=self.query)
         if response.status_code != 200:
@@ -351,15 +347,9 @@ class ArchiveQuery(collections.abc.Sequence):
         if not isinstance(data, dict):
             data = data()

-        if self.scroll:
-            scroll = data['scroll']
-            self._scroll_id = scroll['scroll_id']
-            self._total = scroll['total']
-        else:
-            pagination = data['pagination']
-            self._total = pagination['total']
-            self.page = pagination['page'] + 1
+        aggregation = data['aggregation']
+        self._after = aggregation.get('after')
+        self._total = aggregation['total']

         if self.max is not None:
             self._capped_total = min(self.max, self._total)
@@ -385,6 +375,11 @@ class ArchiveQuery(collections.abc.Sequence):
             # fails in test due to mocked requests library
             pass

+        if self._after is None:
+            # there are no more search results, we need to avoid further calls
+            self._capped_total = len(self._results)
+            self._total = len(self._results)
+
     def __repr__(self):
         if self._total == -1:
             self.call_api()
@@ -425,6 +420,20 @@ class ArchiveQuery(collections.abc.Sequence):
         return self._statistics

+    def clear(self, index: int = None):
+        '''
+        Remove cached results. The results are replaced with None in this object. If you
+        keep references to the results elsewhere, the garbage collection might not catch
+        those.
+
+        Arguments:
+            index: Remove all results up to (excluding) the given index. Default is to
+                remove all results.
+        '''
+        for i, _ in enumerate(self._results[:index]):
+            self._results[i] = None


 def query_archive(*args, **kwargs):
     return ArchiveQuery(*args, **kwargs)
...
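A usage sketch for the new clear method, assuming that iterating an ArchiveQuery downloads results page by page as implemented above; the query parameters are illustrative:

    from nomad.client import ArchiveQuery

    query = ArchiveQuery(query={'atoms': ['Si']})  # illustrative search parameters
    for i, result in enumerate(query):
        # ... work with the result here ...
        if i % 1000 == 999:
            query.clear(i + 1)  # replace the first i + 1 cached results with None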
@@ -461,7 +461,7 @@ class SearchRequest:
     def execute(self):
         '''
-        Exectutes without returning actual results. Only makes sense if the request
+        Executes without returning actual results. Only makes sense if the request
         was configured for statistics or quantity values.
         '''
         search = self._search.query(self.q)[0:0]
@@ -581,6 +581,66 @@ class SearchRequest:
         return dict(scroll=scroll_info, results=results)

+    def execute_aggregated(
+            self, after: str = None, per_page: int = 1000, includes: List[str] = None):
+        '''
+        Uses a composite aggregation on top of the search to go through the result
+        set. This allows going arbitrarily deep without using scroll. But, it will
+        only return results with ``upload_id``, ``calc_id`` and the given
+        quantities. The results will be 'ordered' by ``upload_id``.
+
+        Arguments:
+            after: The key that determines the start of the current page. This after
+                key is returned with each response. Use None (default) for the first
+                request.
+            per_page: The size of each page.
+            includes: A list of quantity names that should be returned in addition to
+                ``upload_id`` and ``calc_id``.
+        '''
+        upload_id_agg = A('terms', field="upload_id")
+        calc_id_agg = A('terms', field="calc_id")
+
+        composite = dict(
+            sources=[dict(upload_id=upload_id_agg), dict(calc_id=calc_id_agg)],
+            size=per_page)
+        if after is not None:
+            upload_id, calc_id = after.split(':')
+            composite['after'] = dict(upload_id=upload_id, calc_id=calc_id)
+
+        composite_agg = self._search.aggs.bucket('ids', 'composite', **composite)
+        if includes is not None:
+            composite_agg.metric('examples', A('top_hits', size=1, _source=dict(includes=includes)))
+
+        search = self._search.query(self.q)[0:0]
+        response = search.execute()
+
+        ids = response['aggregations']['ids']
+        if 'after_key' in ids:
+            after_dict = ids['after_key']
+            after = '%s:%s' % (after_dict['upload_id'], after_dict['calc_id'])
+        else:
+            after = None
+
+        id_agg_info = dict(total=response['hits']['total'], after=after, per_page=per_page)
+
+        def transform_result(es_result):
+            result = dict(
+                upload_id=es_result['key']['upload_id'],
+                calc_id=es_result['key']['calc_id'])
+
+            if includes is not None:
+                source = es_result['examples']['hits']['hits'][0]['_source']
+                for key in source:
+                    result[key] = source[key]
+
+            return result
+
+        results = [
+            transform_result(item) for item in ids['buckets']]
+
+        return dict(aggregation=id_agg_info, results=results)

     def _response(self, response, with_hits: bool = False) -> Dict[str, Any]:
         '''
         Prepares a response object covering the total number of results, hits, statistics,
...
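For illustration, a sketch that drains a full result set with execute_aggregated, mirroring the tests further below; the import path and the handle function are assumptions:

    from nomad.search import SearchRequest  # assumed import path

    def handle(upload_id, calc_id):
        pass  # placeholder for per-result processing

    request = SearchRequest(domain='dft')
    after = None
    while True:
        page = request.execute_aggregated(after=after, per_page=1000)
        for result in page['results']:
            handle(result['upload_id'], result['calc_id'])
        after = page['aggregation']['after']
        if after is None:
            break  # no after_key was returned: all buckets consumed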
@@ -682,12 +682,12 @@ class TestArchive(UploadFilesBasedTests):
         assert rv.status_code == 200
         assert_zip_file(rv, files=1)

-    def test_archive_query(self, api, published_wo_user_metadata):
+    def test_archive_query_paginated(self, api, published_wo_user_metadata):
         schema = {
             'section_run': {
                 'section_single_configuration_calculation': {
                     'energy_total': '*'}}}
-        data = {'results': [schema], 'per_page': 5}
+        data = {'results': [schema], 'pagination': {'per_page': 5}}

         uri = '/archive/query'
         rv = api.post(uri, content_type='application/json', data=json.dumps(data))
@@ -716,6 +716,30 @@ class TestArchive(UploadFilesBasedTests):
         rv = api.post(uri, content_type='application/json', data=json.dumps(dict(per_page=5, raise_errors=False)))
         assert rv.status_code == 200

+    def test_archive_query_aggregated(self, api, published_wo_user_metadata):
+        uri = '/archive/query'
+        schema = {
+            'section_run': {
+                'section_single_configuration_calculation': {
+                    'energy_total': '*'}}}
+        query = {'results': [schema], 'aggregation': {'per_page': 1}}
+
+        count = 0
+        while True:
+            rv = api.post(uri, content_type='application/json', data=json.dumps(query))
+            assert rv.status_code == 200
+            data = rv.get_json()
+
+            results = data.get('results', None)
+            count += len(results)
+
+            after = data['aggregation']['after']
+            if after is None:
+                break
+
+            query['aggregation']['after'] = after
+
+        assert count > 0


 class TestMetainfo():
     @pytest.mark.parametrize('package', ['common', 'vasp', 'general.experimental', 'eels'])
...
@@ -125,6 +125,28 @@ def test_search_scroll(elastic, example_search_data):
     assert 'scroll_id' not in results['scroll']

+def test_search_aggregated(elastic, example_search_data):
+    request = SearchRequest(domain='dft')
+    results = request.execute_aggregated()
+
+    after = results['aggregation']['after']
+    assert results['aggregation']['total'] == 1
+    assert len(results['results']) == 1
+    assert 'calc_id' in results['results'][0]
+    assert 'upload_id' in results['results'][0]
+    assert after is not None
+
+    results = request.execute_aggregated(after=after)
+    assert results['aggregation']['total'] == 1
+    assert len(results['results']) == 0
+    assert results['aggregation']['after'] is None
+
+
+def test_search_aggregated_includes(elastic, example_search_data):
+    request = SearchRequest(domain='dft')
+    results = request.execute_aggregated(includes=['with_embargo'])
+
+    assert 'with_embargo' in results['results'][0]


 def test_domain(elastic, example_ems_search_data):
     assert len(list(SearchRequest(domain='ems').execute_scan())) > 0
     assert len(list(SearchRequest(domain='ems').domain().execute_scan())) > 0
...