From 6ad17658e6261678541de9c55b6758c97a187d03 Mon Sep 17 00:00:00 2001 From: Markus Scheidgen <markus.scheidgen@gmail.com> Date: Wed, 3 Nov 2021 20:55:16 +0100 Subject: [PATCH] Added exclude from search option to aggregations. #573, #575 --- nomad/app/v1/models.py | 13 +++ nomad/search.py | 108 +++++++++++++++++---- tests/app/v1/routers/common.py | 128 +++++++++++++++++++++++++ tests/app/v1/routers/test_entries.py | 25 ++++- tests/app/v1/routers/test_materials.py | 23 ++++- 5 files changed, 275 insertions(+), 22 deletions(-) diff --git a/nomad/app/v1/models.py b/nomad/app/v1/models.py index 9f00da6abc..c56493e0e9 100644 --- a/nomad/app/v1/models.py +++ b/nomad/app/v1/models.py @@ -788,6 +788,19 @@ class QuantityAggregation(AggregationBase): The manatory name of the quantity for the aggregation. Aggregations can only be computed for those search metadata that have discrete values; an aggregation buckets entries that have the same value for this quantity.''')) + exclude_from_search: bool = Field( + False, description=strip(''' + If set to true, top-level search criteria involving the aggregation quantity will not + be applied for this aggregation. Therefore, the aggregation will return all + values for the quantity, even if the possible values were filtered by the query. + + There are two limitations. This is only supported with queries that start with a + dictionary. It will not work for queries with a boolean operator. It can only + exclude top-level criteria at the root of the query dictionary. Nested criteria, + e.g. within complex and/or constructs, cannot be considered. Using this might also + prohibit pagination with page_after_value on aggregations in the same request. 
+ ''') + ) class BucketAggregation(QuantityAggregation): diff --git a/nomad/search.py b/nomad/search.py index c5607acc02..f78ab5c15d 100644 --- a/nomad/search.py +++ b/nomad/search.py @@ -437,9 +437,20 @@ def validate_quantity( return quantity +def _create_es_must(queries: Dict[str, EsQuery]): + # dictionary is like an "and" of all items in the dict + if len(queries) == 0: + return Q() + + if len(queries) == 1: + return list(queries.values())[0] + + return Q('bool', must=list(queries.values())) + + def validate_api_query( query: Query, doc_type: DocumentType, owner_query: EsQuery, - prefix: str = None) -> EsQuery: + prefix: str = None, results_dict: Dict[str, EsQuery] = None) -> EsQuery: ''' Creates an ES query based on the API's query model. This needs to be a normalized query expression with explicit objects for logical, set, and comparison operators. @@ -460,6 +471,11 @@ def validate_api_query( materials queries. prefix: An optional prefix that is added to all quantity names. Used for recursion. + results_dict: + If an empty dictionary is given and the query is a mapping, the top-level + criteria from this mapping will be added as individual es queries. The + keys will be the mapping keys and values the respective es queries. A logical + and (or es "must") would result in the overall resulting es query. Returns: A elasticsearch dsl query object. 
@@ -564,11 +580,20 @@ def validate_api_query( return Q() if len(query) == 1: - key = next(iter(query)) - return validate_criteria(key, query[key]) + name = next(iter(query)) + es_criteria_query = validate_criteria(name, query[name]) + if results_dict is not None: + results_dict[name] = es_criteria_query + return es_criteria_query - return Q('bool', must=[ - validate_criteria(name, value) for name, value in query.items()]) + es_criteria_queries = [] + for name, value in query.items(): + es_criteria_query = validate_criteria(name, value) + es_criteria_queries.append(es_criteria_query) + if results_dict is not None: + results_dict[name] = es_criteria_query + + return Q('bool', must=es_criteria_queries) raise NotImplementedError() @@ -595,7 +620,8 @@ def validate_pagination(pagination: Pagination, doc_type: DocumentType, loc: Lis def _api_to_es_aggregation( - es_search: Search, name: str, agg: AggregationBase, doc_type: DocumentType) -> A: + es_search: Search, name: str, agg: AggregationBase, doc_type: DocumentType, + post_agg_queries: Dict[str, EsQuery]) -> A: ''' Creates an ES aggregation based on the API's aggregation model. 
''' @@ -603,6 +629,12 @@ def _api_to_es_aggregation( agg_name = f'agg:{name}' es_aggs = es_search.aggs + if post_agg_queries: + filter = post_agg_queries + if isinstance(agg, QuantityAggregation) and agg.exclude_from_search: + filter = {name: query for name, query in post_agg_queries.items() if name != agg.quantity} + es_aggs = es_aggs.bucket(f'{agg_name}:filtered', A('filter', filter=_create_es_must(filter))) + if isinstance(agg, StatisticsAggregation): for metric_name in agg.metrics: metrics = doc_type.metrics @@ -620,13 +652,11 @@ def _api_to_es_aggregation( return agg = cast(QuantityAggregation, agg) - longest_nested_key = None quantity = validate_quantity(agg.quantity, doc_type=doc_type, loc=['aggregation', 'quantity']) - for nested_key in doc_type.nested_object_keys: if agg.quantity.startswith(nested_key): - es_aggs = es_search.aggs.bucket('nested_agg:%s' % name, 'nested', path=nested_key) + es_aggs = es_aggs.bucket('nested_agg:%s' % name, 'nested', path=nested_key) longest_nested_key = nested_key es_agg = None @@ -674,6 +704,11 @@ def _api_to_es_aggregation( } if page_after_value is not None: + if post_agg_queries: + raise QueryValidationError( + f'aggregation page_after_value cannot be used with exclude_from_search in the same request', + loc=['aggregations', name, 'terms', 'pagination', 'page_after_value']) + if order_quantity is None: composite['after'] = {name: page_after_value} else: @@ -770,6 +805,11 @@ def _es_to_api_aggregation( the given aggregation. 
''' es_aggs = es_response.aggs + + filtered_agg_name = f'agg:{name}:filtered' + if filtered_agg_name in es_response.aggs: + es_aggs = es_aggs[f'agg:{name}:filtered'] + aggregation_dict = agg.dict(by_alias=True) if isinstance(agg, StatisticsAggregation): @@ -785,7 +825,7 @@ def _es_to_api_aggregation( longest_nested_key = None for nested_key in doc_type.nested_object_keys: if agg.quantity.startswith(nested_key): - es_aggs = es_response.aggs[f'nested_agg:{name}'] + es_aggs = es_aggs[f'nested_agg:{name}'] longest_nested_key = nested_key has_no_pagination = getattr(agg, 'pagination', None) is None @@ -907,22 +947,24 @@ def search( doc_type = index.doc_type - # owner and query + # owner owner_query = _owner_es_query(owner=owner, user_id=user_id, doc_type=doc_type) + # query if query is None: query = {} + es_query_dict: Dict[str, EsQuery] = {} if isinstance(query, EsQuery): es_query = cast(EsQuery, query) else: es_query = validate_api_query( - cast(Query, query), doc_type=doc_type, owner_query=owner_query) + cast(Query, query), doc_type=doc_type, owner_query=owner_query, + results_dict=es_query_dict) if doc_type != entry_type: - es_query &= Q('nested', path='entries', query=owner_query) - else: - es_query &= owner_query + owner_query = Q('nested', path='entries', query=owner_query) + es_query &= owner_query # pagination if pagination is None: @@ -933,7 +975,6 @@ def search( search = Search(index=index.index_name) - search = search.query(es_query) # TODO this depends on doc_type if pagination.order_by is None: pagination.order_by = doc_type.id_field @@ -974,15 +1015,44 @@ def search( search = search.source(includes=required.include, excludes=required.exclude) # aggregations - for name, agg in aggregations.items(): - _api_to_es_aggregation(search, name, _specific_agg(agg), doc_type=doc_type) + aggs = [(name, _specific_agg(agg)) for name, agg in aggregations.items()] + post_agg_queries: Dict[str, EsQuery] = {} + excluded_agg_quantities = { + agg.quantity + for _, agg in 
aggs + if isinstance(agg, QuantityAggregation) and agg.exclude_from_search} + + if len(excluded_agg_quantities) > 0: + if not isinstance(query, dict): + # "exclude_from_search" only works for top-level mapping queries + raise QueryValidationError( + f'the query has to be a dictionary if there is an aggregation with exclude_from_search', + loc=['query']) + + pre_agg_queries = { + quantity: es_query + for quantity, es_query in es_query_dict.items() + if quantity not in excluded_agg_quantities} + post_agg_queries = { + quantity: es_query + for quantity, es_query in es_query_dict.items() + if quantity in excluded_agg_quantities} + + search = search.post_filter(_create_es_must(post_agg_queries)) + search = search.query(_create_es_must(pre_agg_queries) & owner_query) + + else: + search = search.query(es_query) # pylint: disable=no-member + + for name, agg in aggs: + _api_to_es_aggregation( + search, name, agg, doc_type=doc_type, post_agg_queries=post_agg_queries) # execute try: es_response = search.execute() except RequestError as e: raise SearchError(e) - more_response_data = {} # pagination diff --git a/tests/app/v1/routers/common.py b/tests/app/v1/routers/common.py index b3b69ea0e8..0916c7b6fe 100644 --- a/tests/app/v1/routers/common.py +++ b/tests/app/v1/routers/common.py @@ -333,6 +333,134 @@ def aggregation_test_parameters(entity_id: str, material_prefix: str, entry_pref ] +def aggregation_exclude_from_search_test_parameters(entry_prefix: str, total_per_entity: int, total: int): + entry_id = f'{entry_prefix}entry_id' + upload_id = f'{entry_prefix}upload_id' + program_name = f'{entry_prefix}results.method.simulation.program_name' + + return [ + pytest.param( + { + f'{entry_id}:any': ['id_01'] + }, + [ + { + 'exclude_from_search': True, + 'quantity': entry_id + } + ], + [10], 1, 200, + id='exclude' + ), + pytest.param( + { + f'{entry_id}:any': ['id_01'] + }, + [ + { + 'exclude_from_search': False, + 'quantity': entry_id + } + ], + [total_per_entity], 1, 200, + 
id='dont-exclude' + ), + pytest.param( + { + f'{entry_id}:any': ['id_01'], + upload_id: 'id_published', + program_name: 'VASP' + }, + [ + { + 'exclude_from_search': True, + 'quantity': entry_id + }, + { + 'exclude_from_search': True, + 'quantity': upload_id + } + ], + [10, 1], 1, 200, + id='two-aggs' + ), + pytest.param( + { + f'{entry_id}:any': ['id_01'] + }, + [ + { + 'exclude_from_search': True, + 'quantity': entry_id + }, + { + 'exclude_from_search': False, + 'quantity': entry_id + } + ], + [10, total_per_entity], 1, 200, + id='two-aggs-same-quantity' + ), + pytest.param( + {}, + [ + { + 'exclude_from_search': True, + 'quantity': entry_id + } + ], + [10], total, 200, + id='not-in-query' + ), + pytest.param( + {}, + [ + { + 'exclude_from_search': True, + 'quantity': entry_id, + 'pagination': { + 'page_size': 20 + } + } + ], + [20], total, 200, + id='with-pagination' + ), + pytest.param( + { + 'or': [{entry_id: 'id_01'}] + }, + [ + { + 'exclude_from_search': True, + 'quantity': entry_id + } + ], + [0], 0, 422, + id='non-dict-query' + ), + pytest.param( + { + f'{entry_id}:any': ['id_01'] + }, + [ + { + 'exclude_from_search': True, + 'quantity': entry_id + }, + { + 'quantity': entry_id, + 'pagination': { + 'page_after_value': 'id_published' + } + } + ], + [0], 0, 422, + id='with-page-after-value' + ) + ] + + def assert_response(response, status_code=None): ''' General assertions for status_code and error messages ''' if status_code and response.status_code != status_code: diff --git a/tests/app/v1/routers/test_entries.py b/tests/app/v1/routers/test_entries.py index 4a2dde5fd3..0e1d777c4a 100644 --- a/tests/app/v1/routers/test_entries.py +++ b/tests/app/v1/routers/test_entries.py @@ -29,8 +29,8 @@ from tests.utils import ExampleData from tests.test_files import example_mainfile_contents, append_raw_files # pylint: disable=unused-import from .common import ( - assert_response, assert_base_metadata_response, assert_metadata_response, - assert_required, 
assert_aggregations, assert_pagination, + aggregation_exclude_from_search_test_parameters, assert_response, assert_base_metadata_response, + assert_metadata_response, assert_required, assert_aggregations, assert_pagination, perform_metadata_test, post_query_test_parameters, get_query_test_parameters, perform_owner_test, owner_test_parameters, pagination_test_parameters, aggregation_test_parameters) @@ -368,6 +368,27 @@ def test_entries_aggregations(client, data, test_user_auth, aggregation, total, default_key='entry_id') +@pytest.mark.parametrize( + 'query,aggs,agg_lengths,total,status_code', + aggregation_exclude_from_search_test_parameters(entry_prefix='', total_per_entity=1, total=23)) +def test_entries_aggregations_exclude_from_search(client, data, query, aggs, agg_lengths, total, status_code): + aggs = {f'agg_{i}': {'terms': agg} for i, agg in enumerate(aggs)} + + response_json = perform_entries_metadata_test( + client, owner='visible', + query=query, aggregations=aggs, + pagination=dict(page_size=0), + status_code=status_code, http_method='post') + + if response_json is None: + return + + assert response_json['pagination']['total'] == total + for i, length in enumerate(agg_lengths): + response_agg = response_json['aggregations'][f'agg_{i}']['terms'] + assert len(response_agg['data']) == length + + @pytest.mark.parametrize('required, status_code', [ pytest.param({'include': ['entry_id', 'upload_id']}, 200, id='include'), pytest.param({'include': ['results.*', 'upload_id']}, 200, id='include-section'), diff --git a/tests/app/v1/routers/test_materials.py b/tests/app/v1/routers/test_materials.py index 32ed5fb514..970e68b09e 100644 --- a/tests/app/v1/routers/test_materials.py +++ b/tests/app/v1/routers/test_materials.py @@ -24,7 +24,7 @@ from nomad.metainfo.elasticsearch_extension import material_entry_type from tests.test_files import example_mainfile_contents # pylint: disable=unused-import from .common import ( - assert_pagination, assert_metadata_response, 
assert_required, assert_aggregations, + aggregation_exclude_from_search_test_parameters, assert_pagination, assert_metadata_response, assert_required, assert_aggregations, perform_metadata_test, perform_owner_test, owner_test_parameters, post_query_test_parameters, get_query_test_parameters, pagination_test_parameters, aggregation_test_parameters) @@ -74,6 +74,27 @@ def test_materials_aggregations(client, data, test_user_auth, aggregation, total default_key='material_id') +@pytest.mark.parametrize( + 'query,aggs,agg_lengths,total,status_code', + aggregation_exclude_from_search_test_parameters(entry_prefix='entries.', total_per_entity=3, total=6)) +def test_materials_aggregations_exclude_from_search(client, data, query, aggs, agg_lengths, total, status_code): + aggs = {f'agg_{i}': {'terms': agg} for i, agg in enumerate(aggs)} + + response_json = perform_materials_metadata_test( + client, owner='visible', + query=query, aggregations=aggs, + pagination=dict(page_size=0), + status_code=status_code, http_method='post') + + if response_json is None: + return + + assert response_json['pagination']['total'] == total + for i, length in enumerate(agg_lengths): + response_agg = response_json['aggregations'][f'agg_{i}']['terms'] + assert len(response_agg['data']) == length + + @pytest.mark.parametrize('required, status_code', [ pytest.param({'include': ['material_id', program_name]}, 200, id='include'), pytest.param({'include': ['entries.*', program_name]}, 200, id='include-section'), -- GitLab