From 6ad17658e6261678541de9c55b6758c97a187d03 Mon Sep 17 00:00:00 2001
From: Markus Scheidgen <markus.scheidgen@gmail.com>
Date: Wed, 3 Nov 2021 20:55:16 +0100
Subject: [PATCH] Added exclude from search option to aggregations. #573, #575

---
 nomad/app/v1/models.py                 |  13 +++
 nomad/search.py                        | 108 +++++++++++++++++----
 tests/app/v1/routers/common.py         | 128 +++++++++++++++++++++++++
 tests/app/v1/routers/test_entries.py   |  25 ++++-
 tests/app/v1/routers/test_materials.py |  23 ++++-
 5 files changed, 275 insertions(+), 22 deletions(-)

diff --git a/nomad/app/v1/models.py b/nomad/app/v1/models.py
index 9f00da6abc..c56493e0e9 100644
--- a/nomad/app/v1/models.py
+++ b/nomad/app/v1/models.py
@@ -788,6 +788,19 @@ class QuantityAggregation(AggregationBase):
         The manatory name of the quantity for the aggregation. Aggregations
         can only be computed for those search metadata that have discrete values;
         an aggregation buckets entries that have the same value for this quantity.'''))
+    exclude_from_search: bool = Field(
+        False, description=strip('''
+        If set to true, top-level search criteria involving the aggregation quantity, will not
+        be applied for this aggregation. Therefore, the aggregation will return all
+        values for the quantity, even if the possible values where filtered by the query.
+
+        There are two limitations. This is only supported with queries that start with a
+        dictionary. It will not work for queries with a boolean operator. It can only
+        exclude top-level criteria at the root of the query dictionary. Nested criteria,
+        e.g. within complex and/or constructs, cannot be considered. Using this might also
+        prohibit pagination with page_after_value on aggregations in the same request.
+        ''')
+    )
 
 
 class BucketAggregation(QuantityAggregation):
diff --git a/nomad/search.py b/nomad/search.py
index c5607acc02..f78ab5c15d 100644
--- a/nomad/search.py
+++ b/nomad/search.py
@@ -437,9 +437,20 @@ def validate_quantity(
     return quantity
 
 
+def _create_es_must(queries: Dict[str, EsQuery]):
+    ''' Combines the es queries of the given dict like a logical "and" (es bool/must). '''
+    es_queries = list(queries.values())
+    if not es_queries:
+        return Q()
+    if len(es_queries) == 1:
+        return es_queries[0]
+
+    return Q('bool', must=es_queries)
+
+
 def validate_api_query(
         query: Query, doc_type: DocumentType, owner_query: EsQuery,
-        prefix: str = None) -> EsQuery:
+        prefix: str = None, results_dict: Dict[str, EsQuery] = None) -> EsQuery:
     '''
     Creates an ES query based on the API's query model. This needs to be a normalized
     query expression with explicit objects for logical, set, and comparison operators.
@@ -460,6 +471,11 @@ def validate_api_query(
             materials queries.
         prefix:
             An optional prefix that is added to all quantity names. Used for recursion.
+        results_dict:
+            If an empty dictionary is given and the query is a mapping, the top-level
+            criteria from this mapping will be added as individual es queries. The
+            keys will be the mapping keys and values the respective es queries. A logical
+            and (or es "must") would result in the overall resulting es query.
 
     Returns:
         A elasticsearch dsl query object.
@@ -564,11 +580,20 @@ def validate_api_query(
             return Q()
 
         if len(query) == 1:
-            key = next(iter(query))
-            return validate_criteria(key, query[key])
+            name = next(iter(query))
+            es_criteria_query = validate_criteria(name, query[name])
+            if results_dict is not None:
+                results_dict[name] = es_criteria_query
+            return es_criteria_query
 
-        return Q('bool', must=[
-            validate_criteria(name, value) for name, value in query.items()])
+        es_criteria_queries = []
+        for name, value in query.items():
+            es_criteria_query = validate_criteria(name, value)
+            es_criteria_queries.append(es_criteria_query)
+            if results_dict is not None:
+                results_dict[name] = es_criteria_query
+
+        return Q('bool', must=es_criteria_queries)
 
     raise NotImplementedError()
 
@@ -595,7 +620,8 @@ def validate_pagination(pagination: Pagination, doc_type: DocumentType, loc: Lis
 
 
 def _api_to_es_aggregation(
-        es_search: Search, name: str, agg: AggregationBase, doc_type: DocumentType) -> A:
+        es_search: Search, name: str, agg: AggregationBase, doc_type: DocumentType,
+        post_agg_queries: Dict[str, EsQuery]) -> A:
     '''
     Creates an ES aggregation based on the API's aggregation model.
     '''
@@ -603,6 +629,12 @@ def _api_to_es_aggregation(
     agg_name = f'agg:{name}'
     es_aggs = es_search.aggs
 
+    if post_agg_queries:
+        filter = post_agg_queries
+        if isinstance(agg, QuantityAggregation) and agg.exclude_from_search:
+            filter = {name: query for name, query in post_agg_queries.items() if name != agg.quantity}
+        es_aggs = es_aggs.bucket(f'{agg_name}:filtered', A('filter', filter=_create_es_must(filter)))
+
     if isinstance(agg, StatisticsAggregation):
         for metric_name in agg.metrics:
             metrics = doc_type.metrics
@@ -620,13 +652,11 @@ def _api_to_es_aggregation(
         return
 
     agg = cast(QuantityAggregation, agg)
-
     longest_nested_key = None
     quantity = validate_quantity(agg.quantity, doc_type=doc_type, loc=['aggregation', 'quantity'])
-
     for nested_key in doc_type.nested_object_keys:
         if agg.quantity.startswith(nested_key):
-            es_aggs = es_search.aggs.bucket('nested_agg:%s' % name, 'nested', path=nested_key)
+            es_aggs = es_aggs.bucket('nested_agg:%s' % name, 'nested', path=nested_key)
             longest_nested_key = nested_key
 
     es_agg = None
@@ -674,6 +704,11 @@ def _api_to_es_aggregation(
                 }
 
             if page_after_value is not None:
+                if post_agg_queries:
+                    raise QueryValidationError(
+                        f'aggregation page_after_value cannot be used with exclude_from_search in the same request',
+                        loc=['aggregations', name, 'terms', 'pagination', 'page_after_value'])
+
                 if order_quantity is None:
                     composite['after'] = {name: page_after_value}
                 else:
@@ -770,6 +805,11 @@ def _es_to_api_aggregation(
     the given aggregation.
     '''
     es_aggs = es_response.aggs
+
+    filtered_agg_name = f'agg:{name}:filtered'
+    if filtered_agg_name in es_response.aggs:
+        es_aggs = es_aggs[f'agg:{name}:filtered']
+
     aggregation_dict = agg.dict(by_alias=True)
 
     if isinstance(agg, StatisticsAggregation):
@@ -785,7 +825,7 @@ def _es_to_api_aggregation(
     longest_nested_key = None
     for nested_key in doc_type.nested_object_keys:
         if agg.quantity.startswith(nested_key):
-            es_aggs = es_response.aggs[f'nested_agg:{name}']
+            es_aggs = es_aggs[f'nested_agg:{name}']
             longest_nested_key = nested_key
 
     has_no_pagination = getattr(agg, 'pagination', None) is None
@@ -907,22 +947,24 @@ def search(
 
     doc_type = index.doc_type
 
-    # owner and query
+    # owner
     owner_query = _owner_es_query(owner=owner, user_id=user_id, doc_type=doc_type)
 
+    # query
     if query is None:
         query = {}
 
+    es_query_dict: Dict[str, EsQuery] = {}
     if isinstance(query, EsQuery):
         es_query = cast(EsQuery, query)
     else:
         es_query = validate_api_query(
-            cast(Query, query), doc_type=doc_type, owner_query=owner_query)
+            cast(Query, query), doc_type=doc_type, owner_query=owner_query,
+            results_dict=es_query_dict)
 
     if doc_type != entry_type:
-        es_query &= Q('nested', path='entries', query=owner_query)
-    else:
-        es_query &= owner_query
+        owner_query = Q('nested', path='entries', query=owner_query)
+    es_query &= owner_query
 
     # pagination
     if pagination is None:
@@ -933,7 +975,6 @@ def search(
 
     search = Search(index=index.index_name)
 
-    search = search.query(es_query)
     # TODO this depends on doc_type
     if pagination.order_by is None:
         pagination.order_by = doc_type.id_field
@@ -974,15 +1015,44 @@ def search(
         search = search.source(includes=required.include, excludes=required.exclude)
 
     # aggregations
-    for name, agg in aggregations.items():
-        _api_to_es_aggregation(search, name, _specific_agg(agg), doc_type=doc_type)
+    aggs = [(name, _specific_agg(agg)) for name, agg in aggregations.items()]
+    post_agg_queries: Dict[str, EsQuery] = {}
+    excluded_agg_quantities = {
+        agg.quantity
+        for _, agg in aggs
+        if isinstance(agg, QuantityAggregation) and agg.exclude_from_search}
+
+    if len(excluded_agg_quantities) > 0:
+        if not isinstance(query, dict):
+            # "exclude_from_search" only works for top-level mapping queries
+            raise QueryValidationError(
+                f'the query has to be a dictionary if there is an aggregation with exclude_from_search',
+                loc=['query'])
+
+        pre_agg_queries = {
+            quantity: es_query
+            for quantity, es_query in es_query_dict.items()
+            if quantity not in excluded_agg_quantities}
+        post_agg_queries = {
+            quantity: es_query
+            for quantity, es_query in es_query_dict.items()
+            if quantity in excluded_agg_quantities}
+
+        search = search.post_filter(_create_es_must(post_agg_queries))
+        search = search.query(_create_es_must(pre_agg_queries) & owner_query)
+
+    else:
+        search = search.query(es_query)  # pylint: disable=no-member
+
+    for name, agg in aggs:
+        _api_to_es_aggregation(
+            search, name, agg, doc_type=doc_type, post_agg_queries=post_agg_queries)
 
     # execute
     try:
         es_response = search.execute()
     except RequestError as e:
         raise SearchError(e)
-
     more_response_data = {}
 
     # pagination
diff --git a/tests/app/v1/routers/common.py b/tests/app/v1/routers/common.py
index b3b69ea0e8..0916c7b6fe 100644
--- a/tests/app/v1/routers/common.py
+++ b/tests/app/v1/routers/common.py
@@ -333,6 +333,134 @@ def aggregation_test_parameters(entity_id: str, material_prefix: str, entry_pref
     ]
 
 
+def aggregation_exclude_from_search_test_parameters(entry_prefix: str, total_per_entity: int, total: int):
+    ''' Test parameters (query, aggs, agg_lengths, total, status_code) for exclude_from_search aggregations. '''
+    entry_id = f'{entry_prefix}entry_id'
+    upload_id = f'{entry_prefix}upload_id'
+    program_name = f'{entry_prefix}results.method.simulation.program_name'
+    return [
+        pytest.param(
+            {
+                f'{entry_id}:any': ['id_01']
+            },
+            [
+                {
+                    'exclude_from_search': True,
+                    'quantity': entry_id
+                }
+            ],
+            [10], 1, 200,
+            id='exclude'
+        ),
+        pytest.param(
+            {
+                f'{entry_id}:any': ['id_01']
+            },
+            [
+                {
+                    'exclude_from_search': False,
+                    'quantity': entry_id
+                }
+            ],
+            [total_per_entity], 1, 200,
+            id='dont-exclude'
+        ),
+        pytest.param(
+            {
+                f'{entry_id}:any': ['id_01'],
+                upload_id: 'id_published',
+                program_name: 'VASP'
+            },
+            [
+                {
+                    'exclude_from_search': True,
+                    'quantity': entry_id
+                },
+                {
+                    'exclude_from_search': True,
+                    'quantity': upload_id
+                }
+            ],
+            [10, 1], 1, 200,
+            id='two-aggs'
+        ),
+        pytest.param(
+            {
+                f'{entry_id}:any': ['id_01']
+            },
+            [
+                {
+                    'exclude_from_search': True,
+                    'quantity': entry_id
+                },
+                {
+                    'exclude_from_search': False,
+                    'quantity': entry_id
+                }
+            ],
+            [10, total_per_entity], 1, 200,
+            id='two-aggs-same-quantity'
+        ),
+        pytest.param(
+            {},
+            [
+                {
+                    'exclude_from_search': True,
+                    'quantity': entry_id
+                }
+            ],
+            [10], total, 200,
+            id='not-in-query'
+        ),
+        pytest.param(
+            {},
+            [
+                {
+                    'exclude_from_search': True,
+                    'quantity': entry_id,
+                    'pagination': {
+                        'page_size': 20
+                    }
+                }
+            ],
+            [20], total, 200,
+            id='with-pagination'
+        ),
+        pytest.param(
+            {
+                'or': [{entry_id: 'id_01'}]
+            },
+            [
+                {
+                    'exclude_from_search': True,
+                    'quantity': entry_id
+                }
+            ],
+            [0], 0, 422,
+            id='non-dict-query'
+        ),
+        pytest.param(
+            {
+                f'{entry_id}:any': ['id_01']
+            },
+            [
+                {
+                    'exclude_from_search': True,
+                    'quantity': entry_id
+                },
+                {
+                    'quantity': entry_id,
+                    'pagination': {
+                        'page_after_value': 'id_published'
+                    }
+                }
+            ],
+            [0], 0, 422,
+            id='with-page-after-value'
+        )
+    ]
+
+
 def assert_response(response, status_code=None):
     ''' General assertions for status_code and error messages '''
     if status_code and response.status_code != status_code:
diff --git a/tests/app/v1/routers/test_entries.py b/tests/app/v1/routers/test_entries.py
index 4a2dde5fd3..0e1d777c4a 100644
--- a/tests/app/v1/routers/test_entries.py
+++ b/tests/app/v1/routers/test_entries.py
@@ -29,8 +29,8 @@ from tests.utils import ExampleData
 from tests.test_files import example_mainfile_contents, append_raw_files  # pylint: disable=unused-import
 
 from .common import (
-    assert_response, assert_base_metadata_response, assert_metadata_response,
-    assert_required, assert_aggregations, assert_pagination,
+    aggregation_exclude_from_search_test_parameters, assert_response, assert_base_metadata_response,
+    assert_metadata_response, assert_required, assert_aggregations, assert_pagination,
     perform_metadata_test, post_query_test_parameters, get_query_test_parameters,
     perform_owner_test, owner_test_parameters, pagination_test_parameters,
     aggregation_test_parameters)
@@ -368,6 +368,27 @@ def test_entries_aggregations(client, data, test_user_auth, aggregation, total,
             default_key='entry_id')
 
 
+@pytest.mark.parametrize(
+    'query,aggs,agg_lengths,total,status_code',
+    aggregation_exclude_from_search_test_parameters(entry_prefix='', total_per_entity=1, total=23))
+def test_entries_aggregations_exclude_from_search(client, data, query, aggs, agg_lengths, total, status_code):
+    ''' Checks the total and the number of buckets of each aggregation when using exclude_from_search. '''
+    aggs = {f'agg_{i}': {'terms': agg} for i, agg in enumerate(aggs)}
+    response_json = perform_entries_metadata_test(
+        client, owner='visible',
+        query=query, aggregations=aggs,
+        pagination=dict(page_size=0),
+        status_code=status_code, http_method='post')
+
+    if response_json is None:  # NOTE(review): presumably the helper returns None for error responses
+        return
+
+    assert response_json['pagination']['total'] == total
+    for i, length in enumerate(agg_lengths):
+        response_agg = response_json['aggregations'][f'agg_{i}']['terms']
+        assert len(response_agg['data']) == length  # inside the loop: check every aggregation, not only the last
+
+
 @pytest.mark.parametrize('required, status_code', [
     pytest.param({'include': ['entry_id', 'upload_id']}, 200, id='include'),
     pytest.param({'include': ['results.*', 'upload_id']}, 200, id='include-section'),
diff --git a/tests/app/v1/routers/test_materials.py b/tests/app/v1/routers/test_materials.py
index 32ed5fb514..970e68b09e 100644
--- a/tests/app/v1/routers/test_materials.py
+++ b/tests/app/v1/routers/test_materials.py
@@ -24,7 +24,7 @@ from nomad.metainfo.elasticsearch_extension import material_entry_type
 from tests.test_files import example_mainfile_contents  # pylint: disable=unused-import
 
 from .common import (
-    assert_pagination, assert_metadata_response, assert_required, assert_aggregations,
+    aggregation_exclude_from_search_test_parameters, assert_pagination, assert_metadata_response, assert_required, assert_aggregations,
     perform_metadata_test, perform_owner_test, owner_test_parameters,
     post_query_test_parameters, get_query_test_parameters, pagination_test_parameters,
     aggregation_test_parameters)
@@ -74,6 +74,27 @@ def test_materials_aggregations(client, data, test_user_auth, aggregation, total
             default_key='material_id')
 
 
+@pytest.mark.parametrize(
+    'query,aggs,agg_lengths,total,status_code',
+    aggregation_exclude_from_search_test_parameters(entry_prefix='entries.', total_per_entity=3, total=6))
+def test_materials_aggregations_exclude_from_search(client, data, query, aggs, agg_lengths, total, status_code):
+    ''' Checks the total and the number of buckets of each aggregation when using exclude_from_search. '''
+    aggs = {f'agg_{i}': {'terms': agg} for i, agg in enumerate(aggs)}
+    response_json = perform_materials_metadata_test(
+        client, owner='visible',
+        query=query, aggregations=aggs,
+        pagination=dict(page_size=0),
+        status_code=status_code, http_method='post')
+
+    if response_json is None:  # NOTE(review): presumably the helper returns None for error responses
+        return
+
+    assert response_json['pagination']['total'] == total
+    for i, length in enumerate(agg_lengths):
+        response_agg = response_json['aggregations'][f'agg_{i}']['terms']
+        assert len(response_agg['data']) == length  # inside the loop: check every aggregation, not only the last
+
+
 @pytest.mark.parametrize('required, status_code', [
     pytest.param({'include': ['material_id', program_name]}, 200, id='include'),
     pytest.param({'include': ['entries.*', program_name]}, 200, id='include-section'),
-- 
GitLab