diff --git a/nomad/api/common.py b/nomad/api/common.py
index 9e65daeefc2432b13fa866e11b343e5193e3e87b..055f717f50872cd7a04cff719b0eee303b3ca522 100644
--- a/nomad/api/common.py
+++ b/nomad/api/common.py
@@ -24,7 +24,7 @@ from .app import api
 pagination_model = api.model('Pagination', {
     'total': fields.Integer(description='Number of total elements.'),
     'page': fields.Integer(description='Number of the current page, starting with 0.'),
-    'per_page': fields.Integer(description='Number of elements per page.'),
+    'per_page': fields.Integer(description='Number of elements per page.')
 })
 """ Model used in responsed with pagination. """
diff --git a/nomad/api/repo.py b/nomad/api/repo.py
index ecf57ccc034dfc88339e322f81675242c9193293..5f3d3f265aaf96b6272d3625ad3fd53eee212859 100644
--- a/nomad/api/repo.py
+++ b/nomad/api/repo.py
@@ -64,6 +64,7 @@ repo_calcs_model = api.model('RepoCalculations', {
     'results': fields.List(fields.Raw, description=(
         'A list of search results. Each result is a dict with quantitie names as key and '
         'values as values')),
+    'scroll_id': fields.String(description='Id of the current scroll view in scroll-based search.'),
     'aggregations': fields.Raw(description=(
         'A dict with all aggregations. Each aggregation is dictionary with the amount as '
         'value and quantity value as key.'))
@@ -73,6 +74,10 @@
 repo_request_parser = pagination_request_parser.copy()
 repo_request_parser.add_argument(
     'owner', type=str, help='Specify which calcs to return: ``all``, ``user``, ``staging``, default is ``all``')
+repo_request_parser.add_argument(
+    'scroll', type=bool, help='Enable scrolling')
+repo_request_parser.add_argument(
+    'scroll_id', type=str, help='The id of the current scrolling window to use.')
 
 for search_quantity in search.search_quantities.keys():
     _, _, description = search.search_quantities[search_quantity]
@@ -96,11 +101,27 @@ class RepoCalcsResource(Resource):
         you will be given a list of all possible values and the number of entries
         that have the certain value. You can also use these aggregations on an empty
         search to determine the possible values.
+
+        The pagination parameters allow you to determine which page to return via the
+        ``page`` and ``per_page`` parameters. Pagination, however, is limited to the first
+        100k hits (depending on ES configuration). An alternative to pagination is to use
+        ``scroll`` and ``scroll_id``. With ``scroll`` you will get a ``scroll_id`` on
+        the first request. Each call with ``scroll`` and the respective ``scroll_id`` will
+        return the next ``per_page`` (here the default is 1000) results. Scrolling, however,
+        ignores ordering and does not return aggregations. The scroll view used in the
+        background will stay alive for 5 minutes between requests.
+
+        The search will return aggregations on a predefined set of quantities. Aggregations
+        will tell you what quantity values exist and how many entries match those values.
+
+        Ordering is determined by the ``order_by`` and ``order`` parameters.
""" try: + scroll = bool(request.args.get('scroll', False)) + scroll_id = request.args.get('scroll_id', None) page = int(request.args.get('page', 1)) - per_page = int(request.args.get('per_page', 10)) + per_page = int(request.args.get('per_page', 10 if not scroll else 1000)) order = int(request.args.get('order', -1)) except Exception: abort(400, message='bad parameter types') @@ -136,14 +157,31 @@ class RepoCalcsResource(Resource): data = dict(**request.args) data.pop('owner', None) - data.update(per_page=per_page, page=page, order=order, order_by=order_by) + data.pop('scroll', None) + data.pop('scroll_id', None) + data.pop('per_page', None) + data.pop('page', None) + data.pop('order', None) + data.pop('order_by', None) + + if scroll: + data.update(scroll_id=scroll_id, size=per_page) + else: + data.update(per_page=per_page, page=page, order=order, order_by=order_by) try: - total, results, aggregations = search.aggregate_search(q=q, **data) + if scroll: + page = -1 + scroll_id, total, results = search.scroll_search(q=q, **data) + aggregations = None + else: + scroll_id = None + total, results, aggregations = search.aggregate_search(q=q, **data) except KeyError as e: abort(400, str(e)) return dict( pagination=dict(total=total, page=page, per_page=per_page), results=results, + scroll_id=scroll_id, aggregations=aggregations), 200 diff --git a/nomad/migration.py b/nomad/migration.py index 5b6bb77b9bc8e97a9c0f08045156d4088f2bc6dd..a8089b97b3efb87d532ac10e6339ee94b62d9521 100644 --- a/nomad/migration.py +++ b/nomad/migration.py @@ -792,10 +792,15 @@ class NomadCOEMigration: # verify upload against source calcs_in_search = 0 with utils.timer(logger, 'varyfied upload against source calcs'): - for page in range(1, math.ceil(upload_total_calcs / per_page) + 1): - search = self.nomad( - 'repo.search', page=page, per_page=per_page, upload_id=upload.upload_id, - order_by='mainfile') + scroll_id = 'first' + while scroll_id is not None: + scroll_args: Dict[str, Any] = dict(scroll=True) + if scroll_id != 'first': + scroll_args['scroll_id'] = scroll_id + + search = self.nomad('repo.search', upload_id=upload.upload_id, **scroll_args) + + scroll_id = search.scroll_id for calc in search.results: calcs_in_search += 1 diff --git a/nomad/search.py b/nomad/search.py index 3500808c0b63af7fcede52bcd8ffc1505cdfabc5..ecd902e56a03bd55df1f60af2b1f864891ebd5be 100644 --- a/nomad/search.py +++ b/nomad/search.py @@ -32,6 +32,9 @@ path_analyzer = analyzer( class AlreadyExists(Exception): pass +class ElasticSearchError(Exception): pass + + class User(InnerDoc): @classmethod @@ -254,25 +257,7 @@ elastic field and description. """ -def aggregate_search( - page: int = 1, per_page: int = 10, order_by: str = 'formula', order: int = -1, - q: Q = None, **kwargs) -> Tuple[int, List[dict], Dict[str, Dict[str, int]]]: - """ - Performs a search and returns paginated search results and aggregation bucket sizes - based on key quantities. - - Arguments: - page: The page to return starting with page 1 - per_page: Results per page - q: An *elasticsearch_dsl* query used to further filter the results (via `and`) - aggregations: A customized list of aggregations to perform. Keys are index fields, - and values the amount of buckets to return. Only works on *keyword* field. - **kwargs: Quantity, value pairs to search for. - - Returns: A tuple with the total hits, an array with the results, an dictionary with - the aggregation data. 
- """ - +def _construct_search(q: Q = None, **kwargs) -> Search: search = Search(index=config.elastic.index_name) if q is not None: @@ -296,6 +281,82 @@ def aggregate_search( for item in items: search = search.query(Q(query_type, **{field: item})) + search = search.source(exclude=['quantities']) + + return search + + +def scroll_search( + scroll_id: str = None, size: int = 1000, scroll: str = u'5m', + q: Q = None, **kwargs) -> Tuple[str, int, List[dict]]: + """ + Alternative search based on ES scroll API. Can be used similar to + :func:`aggregate_search`, but pagination is replaced with scrolling, no ordering, + and no aggregation information is given. + + Scrolling is done by calling this function again and again with the same ``scoll_id``. + Each time, this function will return the next batch of search results. + + See see :func:`aggregate_search` for additional ``kwargs`` + + Arguments: + scroll_id: The scroll id to receive the next batch from. None will create a new + scroll. + size: The batch size in number of hits. + scroll: The time the scroll should be kept alive (i.e. the time between requests + to this method) in ES time units. Default is 5 minutes. + """ + es = infrastructure.elastic_client + + if scroll_id is None: + # initiate scroll + search = _construct_search(q, **kwargs) + resp = es.search(body=search.to_dict(), scroll=scroll, size=size) # pylint: disable=E1123 + + scroll_id = resp.get('_scroll_id') + if scroll_id is None: + # no results for search query + return None, 0, [] + else: + resp = es.scroll(scroll_id, scroll=scroll) # pylint: disable=E1123 + + total = resp['hits']['total'] + results = [hit['_source'] for hit in resp['hits']['hits']] + + # since we are using the low level api here, we should check errors + if resp["_shards"]["successful"] < resp["_shards"]["total"]: + utils.get_logger(__name__).error('es operation was unsuccessful on at least one shard') + raise ElasticSearchError('es operation was unsuccessful on at least one shard') + + if len(results) == 0: + es.clear_scroll(body={'scroll_id': [scroll_id]}, ignore=(404, )) # pylint: disable=E1123 + return None, total, [] + + return scroll_id, total, results + + +def aggregate_search( + page: int = 1, per_page: int = 10, order_by: str = 'formula', order: int = -1, + q: Q = None, aggregations: Dict[str, int] = aggregations, + **kwargs) -> Tuple[int, List[dict], Dict[str, Dict[str, int]]]: + """ + Performs a search and returns paginated search results and aggregation bucket sizes + based on key quantities. + + Arguments: + page: The page to return starting with page 1 + per_page: Results per page + q: An *elasticsearch_dsl* query used to further filter the results (via `and`) + aggregations: A customized list of aggregations to perform. Keys are index fields, + and values the amount of buckets to return. Only works on *keyword* field. + **kwargs: Quantity, value pairs to search for. + + Returns: A tuple with the total hits, an array with the results, an dictionary with + the aggregation data. 
+ """ + + search = _construct_search(q, **kwargs) + for aggregation, size in aggregations.items(): if aggregation == 'authors': search.aggs.bucket(aggregation, A('terms', field='authors.name_keyword', size=size)) @@ -306,8 +367,6 @@ def aggregate_search( raise KeyError('Unknown order quantity %s' % order_by) search = search.sort(order_by if order == 1 else '-%s' % order_by) - search = search.source(exclude=['quantities']) - response = search[(page - 1) * per_page: page * per_page].execute() # pylint: disable=E1101 total_results = response.hits.total diff --git a/tests/test_api.py b/tests/test_api.py index 17c504d7ff1eb8e7955983918bb03813b9145b26..c8f4a20b0354186cbf480f850b6ec81f9a45039a 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -656,6 +656,32 @@ class TestRepo(UploadFilesBasedTests): assert len(results) == 2 assert results[0]['calc_id'] == first + @pytest.mark.parametrize('n_results, size', [(2, None), (2, 5), (1, 1)]) + def test_search_scroll(self, client, example_elastic_calcs, no_warn, n_results, size): + if size is not None: + rv = client.get('/repo/?scroll=1,&per_page=%d' % size) + else: + rv = client.get('/repo/?scroll=1') + + assert rv.status_code == 200 + data = json.loads(rv.data) + results = data.get('results', None) + assert data['pagination']['total'] == 2 + assert results is not None + assert len(results) == n_results + scroll_id = data.get('scroll_id', None) + assert scroll_id is not None + + has_another_page = False + while scroll_id is not None: + rv = client.get('/repo/?scroll=1&scroll_id=%s' % scroll_id) + data = json.loads(rv.data) + scroll_id = data.get('scroll_id', None) + has_another_page |= len(data.get('results')) > 0 + + if n_results < 2: + assert has_another_page + def test_search_user_authrequired(self, client, example_elastic_calcs, no_warn): rv = client.get('/repo/?owner=user') assert rv.status_code == 401 diff --git a/tests/test_search.py b/tests/test_search.py index 06da8cc96de2122bfcd57e8dc30254ddcebe1de6..2ef2d7fb2dd5dedc42fd58dd9e205a6ec689f4ab 100644 --- a/tests/test_search.py +++ b/tests/test_search.py @@ -15,7 +15,7 @@ from elasticsearch_dsl import Q from nomad import datamodel, search, processing, parsing, infrastructure, config, coe_repo -from nomad.search import Entry, aggregate_search, authors +from nomad.search import Entry, aggregate_search, authors, scroll_search def test_init_mapping(elastic): @@ -60,6 +60,22 @@ def test_search(elastic, normalized: parsing.LocalBackend): assert 'quantities' not in hits[0] +def test_scroll(elastic, normalized: parsing.LocalBackend): + calc_with_metadata = normalized.to_calc_with_metadata() + create_entry(calc_with_metadata) + refresh_index() + + scroll_id, total, hits = scroll_search() + assert total == 1 + assert len(hits) == 1 + assert scroll_id is not None + + scroll_id, total, hits = scroll_search(scroll_id=scroll_id) + assert total == 1 + assert scroll_id is None + assert len(hits) == 0 + + def test_authors(elastic, normalized: parsing.LocalBackend, test_user: coe_repo.User, other_test_user: coe_repo.User): calc_with_metadata = normalized.to_calc_with_metadata() calc_with_metadata.uploader = test_user.to_popo()