Commit 309f3da1 authored by Markus Scheidgen

Replaced archive API pagination with a composite aggregation approach.

parent 11df2a5e
Pipeline #75101 failed with stages in 22 minutes and 28 seconds
......@@ -3,7 +3,7 @@ from nomad.client import ArchiveQuery
from nomad.metainfo import units
# this will not be necessary, once this is the official NOMAD version
config.client.url = 'https://labdev-nomad.esc.rzg.mpg.de/fairdi/nomad/testing-major/api'
config.client.url = 'http://labdev-nomad.esc.rzg.mpg.de/fairdi/nomad/testing-major/api'
query = ArchiveQuery(
query={
......
......@@ -233,19 +233,21 @@ class ArchiveQueryResource(Resource):
See ``/repo`` endpoint for documentation on the search
parameters.
This endpoint uses pagination (see /repo) or id aggregation to handle large result
sets over multiple requests.
Use aggregation.after and aggregation.per_page to request a
certain page with id aggregation.
The actual data are in ``results``. Supplementary Python (curl) code to
execute the search is in ``python`` (``curl``).
'''
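A minimal sketch of a request body that uses the id aggregation described in the docstring above; the field names mirror this diff, while the query contents are illustrative only:
# illustrative request body for POST /archive/query with id aggregation
data = {
    'query': {},                        # any /repo search parameters
    'aggregation': {'per_page': 100}    # omit 'after' for the first page
}
# for subsequent pages, copy the 'after' value (of the form '<upload_id>:<calc_id>')
# from the previous response's 'aggregation' object into data['aggregation']['after']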
try:
data_in = request.get_json()
scroll = data_in.get('scroll', None)
if scroll:
scroll_id = scroll.get('scroll_id')
scroll = True
aggregation = data_in.get('aggregation', None)
pagination = data_in.get('pagination', {})
page = pagination.get('page', 1)
per_page = pagination.get('per_page', 10 if not scroll else 1000)
per_page = pagination.get('per_page', 10)
query = data_in.get('query', {})
......@@ -270,20 +272,19 @@ class ArchiveQueryResource(Resource):
search_request.owner('all')
apply_search_parameters(search_request, query)
search_request.include('calc_id', 'upload_id', 'with_embargo', 'published', 'parser_name')
if not aggregation:
search_request.include('calc_id', 'upload_id', 'with_embargo', 'published', 'parser_name')
try:
if scroll:
results = search_request.execute_scrolled(
scroll_id=scroll_id, size=per_page, order_by='upload_id')
results['scroll']['scroll'] = True
if aggregation:
results = search_request.execute_aggregated(
after=aggregation.get('after'), per_page=aggregation.get('per_page', 1000),
includes=['with_embargo', 'published', 'parser_name'])
else:
results = search_request.execute_paginated(
per_page=per_page, page=page, order_by='upload_id')
except search.ScrollIdNotFound:
abort(400, 'The given scroll_id does not exist.')
except KeyError as e:
abort(400, str(e))
......
......@@ -65,10 +65,18 @@ scroll_model = api.model('Scroll', {
'total': fields.Integer(default=0, description='The total amount of hits for the search.'),
'scroll_id': fields.String(default=None, allow_null=True, description='The scroll_id that can be used to retrieve the next page.'),
'size': fields.Integer(default=0, help='The size of the returned scroll page.')})
''' Model used in responses with scroll. '''
aggregation_model = api.model('Aggregation', {
'after': fields.String(description='The after key for the current request.', allow_null=True),
'total': fields.Integer(default=0, description='The total amount of hits for the search.'),
'per_page': fields.Integer(default=0, help='The size of the requested page.', allow_null=True)})
''' Model used in responses with id aggregation. '''
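For orientation, a hypothetical response fragment matching this model; all concrete values are made up:
# hypothetical 'aggregation' part of an /archive/query response
example_response = {
    'aggregation': {'after': '<upload_id>:<calc_id>', 'total': 12345, 'per_page': 1000},
    'results': ['...']  # list of per-entry archive data
}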
search_model_fields = {
'pagination': fields.Nested(pagination_model, allow_null=True, skip_none=True),
'scroll': fields.Nested(scroll_model, allow_null=True, skip_none=True),
'aggregation': fields.Nested(aggregation_model, allow_null=True),
'results': fields.List(fields.Raw(allow_null=True, skip_none=True), description=(
'A list of search results. Each result is a dict with quantity names as keys and '
'values as values'), allow_null=True, skip_none=True),
......
......@@ -98,6 +98,10 @@ sub-sections return lists of further objects. Here we navigate the sections ``se
sub-section ``section_system`` to access the quantity ``energy_total``. This quantity is a
number with an attached unit (Joule), which can be converted to something else (e.g. Hartree).
The created query object keeps all results in memory. Keep this in mind when you are
accessing a large number of query results. You should use :func:`ArchiveQuery.clear`
to remove results you no longer need.
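A minimal sketch of freeing processed results while iterating, assuming a ``query`` object created as in the example above; ``process`` is a placeholder for your own code:
# free results that have already been processed to bound memory use
for i, result in enumerate(query):
    process(result)            # placeholder for your own processing
    if i % 100 == 99:
        query.clear(i + 1)     # drop cached results before index i + 1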
The NOMAD Metainfo
__________________
......@@ -234,8 +238,6 @@ class ArchiveQuery(collections.abc.Sequence):
url: Optional, override the default NOMAD API url.
username: Optional, allows authenticated access.
password: Optional, allows authenticated access.
scroll: Use the scroll API to iterate through results. This is required when you
are accessing many thousands of results. By default, the pagination API is used.
per_page: Determine how many results are downloaded per page (or scroll window).
Default is 10.
max: Optionally determine the maximum amount of downloaded archives. The iteration
......@@ -250,12 +252,11 @@ class ArchiveQuery(collections.abc.Sequence):
self,
query: dict = None, required: dict = None,
url: str = None, username: str = None, password: str = None,
scroll: bool = False, per_page: int = 10, max: int = None,
per_page: int = 10, max: int = None,
raise_errors: bool = False,
authentication: Union[Dict[str, str], KeycloakAuthenticator] = None):
self.scroll = scroll
self._scroll_id = None
self._after = None
self.page = 1
self.per_page = per_page
self.max = max
......@@ -326,14 +327,9 @@ class ArchiveQuery(collections.abc.Sequence):
'''
url = '%s/%s/%s' % (self.url, 'archive', 'query')
if self.scroll:
scroll_config = self.query.setdefault('scroll', {'scroll': True})
if self._scroll_id is not None:
scroll_config['scroll_id'] = self._scroll_id
else:
self.query.setdefault('pagination', {}).update(
page=self.page, per_page=self.per_page)
aggregation = self.query.setdefault('aggregation', {'per_page': self.per_page})
if self._after is not None:
aggregation['after'] = self._after
response = requests.post(url, headers=self.authentication, json=self.query)
if response.status_code != 200:
......@@ -351,15 +347,9 @@ class ArchiveQuery(collections.abc.Sequence):
if not isinstance(data, dict):
data = data()
if self.scroll:
scroll = data['scroll']
self._scroll_id = scroll['scroll_id']
self._total = scroll['total']
else:
pagination = data['pagination']
self._total = pagination['total']
self.page = pagination['page'] + 1
aggregation = data['aggregation']
self._after = aggregation.get('after')
self._total = aggregation['total']
if self.max is not None:
self._capped_total = min(self.max, self._total)
......@@ -385,6 +375,11 @@ class ArchiveQuery(collections.abc.Sequence):
# fails in test due to mocked requests library
pass
if self._after is None:
# there are no more search results, we need to avoid further calls
self._capped_total = len(self._results)
self._total = len(self._results)
def __repr__(self):
if self._total == -1:
self.call_api()
......@@ -425,6 +420,20 @@ class ArchiveQuery(collections.abc.Sequence):
return self._statistics
def clear(self, index: int = None):
'''
Remove cached results. The results are replaced with None in this object. If you
keep references to the results elsewhere, garbage collection might not reclaim
them.
Arguments:
index: Remove all results up to the given index. Default is to
remove all results.
'''
for i, _ in enumerate(self._results[:index]):
self._results[i] = None
def query_archive(*args, **kwargs):
return ArchiveQuery(*args, **kwargs)
......
......@@ -461,7 +461,7 @@ class SearchRequest:
def execute(self):
'''
Exectutes without returning actual results. Only makes sense if the request
Executes without returning actual results. Only makes sense if the request
was configured for statistics or quantity values.
'''
search = self._search.query(self.q)[0:0]
......@@ -581,6 +581,66 @@ class SearchRequest:
return dict(scroll=scroll_info, results=results)
def execute_aggregated(
self, after: str = None, per_page: int = 1000, includes: List[str] = None):
'''
Uses a composite aggregation on top of the search to go through the result
set. This allows going arbitrarily deep without using scroll. However, it will
only return results with ``upload_id``, ``calc_id`` and the given
quantities. The results are ordered by ``upload_id``.
Arguments:
after: The key that determines the start of the current page. This after
key is returned with each response. Use None (default) for the first
request.
per_page: The size of each page.
includes: A list of quantity names that should be returned in addition to
``upload_id`` and ``calc_id``.
'''
upload_id_agg = A('terms', field="upload_id")
calc_id_agg = A('terms', field="calc_id")
composite = dict(
sources=[dict(upload_id=upload_id_agg), dict(calc_id=calc_id_agg)],
size=per_page)
if after is not None:
upload_id, calc_id = after.split(':')
composite['after'] = dict(upload_id=upload_id, calc_id=calc_id)
composite_agg = self._search.aggs.bucket('ids', 'composite', **composite)
if includes is not None:
composite_agg.metric('examples', A('top_hits', size=1, _source=dict(includes=includes)))
search = self._search.query(self.q)[0:0]
response = search.execute()
ids = response['aggregations']['ids']
if 'after_key' in ids:
after_dict = ids['after_key']
after = '%s:%s' % (after_dict['upload_id'], after_dict['calc_id'])
else:
after = None
id_agg_info = dict(total=response['hits']['total'], after=after, per_page=per_page)
def transform_result(es_result):
result = dict(
upload_id=es_result['key']['upload_id'],
calc_id=es_result['key']['calc_id'])
if includes is not None:
source = es_result['examples']['hits']['hits'][0]['_source']
for key in source:
result[key] = source[key]
return result
results = [
transform_result(item) for item in ids['buckets']]
return dict(aggregation=id_agg_info, results=results)
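A minimal usage sketch of the paging pattern this method enables, mirroring the tests further below; ``handle`` is a placeholder for your own processing:
# illustrative: page through all results via the composite aggregation
request = SearchRequest(domain='dft')
after = None
while True:
    page = request.execute_aggregated(after=after, per_page=1000)
    for result in page['results']:
        handle(result)
    after = page['aggregation']['after']
    if after is None:   # no after key: the result set is exhausted
        break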
def _response(self, response, with_hits: bool = False) -> Dict[str, Any]:
'''
Prepares a response object covering the total number of results, hits, statistics,
......
......@@ -682,12 +682,12 @@ class TestArchive(UploadFilesBasedTests):
assert rv.status_code == 200
assert_zip_file(rv, files=1)
def test_archive_query(self, api, published_wo_user_metadata):
def test_archive_query_paginated(self, api, published_wo_user_metadata):
schema = {
'section_run': {
'section_single_configuration_calculation': {
'energy_total': '*'}}}
data = {'results': [schema], 'per_page': 5}
data = {'results': [schema], 'pagination': {'per_page': 5}}
uri = '/archive/query'
rv = api.post(uri, content_type='application/json', data=json.dumps(data))
......@@ -716,6 +716,30 @@ class TestArchive(UploadFilesBasedTests):
rv = api.post(uri, content_type='application/json', data=json.dumps(dict(per_page=5, raise_errors=False)))
assert rv.status_code == 200
def test_archive_query_aggregated(self, api, published_wo_user_metadata):
uri = '/archive/query'
schema = {
'section_run': {
'section_single_configuration_calculation': {
'energy_total': '*'}}}
query = {'results': [schema], 'aggregation': {'per_page': 1}}
count = 0
while True:
rv = api.post(uri, content_type='application/json', data=json.dumps(query))
assert rv.status_code == 200
data = rv.get_json()
results = data.get('results', None)
count += len(results)
after = data['aggregation']['after']
if after is None:
break
query['aggregation']['after'] = after
assert count > 0
class TestMetainfo():
@pytest.mark.parametrize('package', ['common', 'vasp', 'general.experimental', 'eels'])
......
......@@ -125,6 +125,28 @@ def test_search_scroll(elastic, example_search_data):
assert 'scroll_id' not in results['scroll']
def test_search_aggregated(elastic, example_search_data):
request = SearchRequest(domain='dft')
results = request.execute_aggregated()
after = results['aggregation']['after']
assert results['aggregation']['total'] == 1
assert len(results['results']) == 1
assert 'calc_id' in results['results'][0]
assert 'upload_id' in results['results'][0]
assert after is not None
results = request.execute_aggregated(after=after)
assert results['aggregation']['total'] == 1
assert len(results['results']) == 0
assert results['aggregation']['after'] is None
def test_search_aggregated_includes(elastic, example_search_data):
request = SearchRequest(domain='dft')
results = request.execute_aggregated(includes=['with_embargo'])
assert 'with_embargo' in results['results'][0]
def test_domain(elastic, example_ems_search_data):
assert len(list(SearchRequest(domain='ems').execute_scan())) > 0
assert len(list(SearchRequest(domain='ems').domain().execute_scan())) > 0
......