Commit 6ad17658 authored by Markus Scheidgen

Added exclude from search option to aggregations. #573, #575

parent e029faf8
Pipeline #114491 passed with stages in 29 minutes and 21 seconds
......@@ -788,6 +788,19 @@ class QuantityAggregation(AggregationBase):
The mandatory name of the quantity for the aggregation. Aggregations
can only be computed for those search metadata that have discrete values;
an aggregation buckets entries that have the same value for this quantity.'''))
exclude_from_search: bool = Field(
False, description=strip('''
If set to true, top-level search criteria involving the aggregation quantity will not
be applied for this aggregation. The aggregation will therefore return all
values for the quantity, even if the possible values were filtered by the query.
There are two limitations: this only works for queries given as a dictionary (it
will not work for queries with a top-level boolean operator), and only criteria at
the root of the query dictionary can be excluded. Nested criteria,
e.g. within complex and/or constructs, cannot be considered. Using this option may also
prohibit pagination with page_after_value on aggregations in the same request.
''')
)
class BucketAggregation(QuantityAggregation):
......
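For illustration, a minimal sketch of how a client might use the new option, assuming the usual POST entries/query endpoint and a placeholder base URL (quantity names follow the tests below):

import requests

base_url = 'http://localhost/api/v1'  # placeholder, adjust to the actual installation

# The query restricts hits to VASP entries, but the aggregation excludes that
# criterion from its own computation, so it still buckets over all program names.
response = requests.post(f'{base_url}/entries/query', json={
    'query': {'results.method.simulation.program_name': 'VASP'},
    'aggregations': {
        'programs': {
            'terms': {
                'quantity': 'results.method.simulation.program_name',
                'exclude_from_search': True}}},
    'pagination': {'page_size': 0}})

# 'data' is the list of buckets returned for the aggregation
buckets = response.json()['aggregations']['programs']['terms']['data']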
......@@ -437,9 +437,20 @@ def validate_quantity(
return quantity
def _create_es_must(queries: Dict[str, EsQuery]):
# dictionary is like an "and" of all items in the dict
if len(queries) == 0:
return Q()
if len(queries) == 1:
return list(queries.values())[0]
return Q('bool', must=list(queries.values()))
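For orientation, a minimal sketch of what the helper returns in its three cases (the criteria values are illustrative elasticsearch_dsl queries):

from elasticsearch_dsl import Q

_create_es_must({})  # Q(), i.e. match all
_create_es_must({'upload_id': Q('match', upload_id='id_01')})  # the single query, unchanged
_create_es_must({
    'upload_id': Q('match', upload_id='id_01'),
    'entry_id': Q('match', entry_id='id_02')})
# -> Q('bool', must=[Q('match', upload_id='id_01'), Q('match', entry_id='id_02')])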
def validate_api_query(
query: Query, doc_type: DocumentType, owner_query: EsQuery,
prefix: str = None) -> EsQuery:
prefix: str = None, results_dict: Dict[str, EsQuery] = None) -> EsQuery:
'''
Creates an ES query based on the API's query model. This needs to be a normalized
query expression with explicit objects for logical, set, and comparison operators.
......@@ -460,6 +471,11 @@ def validate_api_query(
materials queries.
prefix:
An optional prefix that is added to all quantity names. Used for recursion.
results_dict:
If an empty dictionary is given and the query is a mapping, the top-level
criteria from this mapping will be added to it as individual es queries. The
keys are the mapping keys and the values the respective es queries. Combining
them with a logical and (an es "must") yields the overall resulting es query.
Returns:
An elasticsearch dsl query object.
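A minimal sketch of the results_dict contract (quantity names are illustrative; doc_type and owner_query as prepared in search()):

criteria: Dict[str, EsQuery] = {}
es_query = validate_api_query(
    {'upload_id': 'id_published', 'results.method.simulation.program_name': 'VASP'},
    doc_type=doc_type, owner_query=owner_query, results_dict=criteria)
# criteria now maps each top-level key to its own es query; combining its
# values with _create_es_must(criteria) yields a query equivalent to es_query.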
......@@ -564,11 +580,20 @@ def validate_api_query(
return Q()
if len(query) == 1:
key = next(iter(query))
return validate_criteria(key, query[key])
name = next(iter(query))
es_criteria_query = validate_criteria(name, query[name])
if results_dict is not None:
results_dict[name] = es_criteria_query
return es_criteria_query
return Q('bool', must=[
validate_criteria(name, value) for name, value in query.items()])
es_criteria_queries = []
for name, value in query.items():
es_criteria_query = validate_criteria(name, value)
es_criteria_queries.append(es_criteria_query)
if results_dict is not None:
results_dict[name] = es_criteria_query
return Q('bool', must=es_criteria_queries)
raise NotImplementedError()
......@@ -595,7 +620,8 @@ def validate_pagination(pagination: Pagination, doc_type: DocumentType, loc: Lis
def _api_to_es_aggregation(
es_search: Search, name: str, agg: AggregationBase, doc_type: DocumentType) -> A:
es_search: Search, name: str, agg: AggregationBase, doc_type: DocumentType,
post_agg_queries: Dict[str, EsQuery]) -> A:
'''
Creates an ES aggregation based on the API's aggregation model.
'''
......@@ -603,6 +629,12 @@ def _api_to_es_aggregation(
agg_name = f'agg:{name}'
es_aggs = es_search.aggs
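# With exclude_from_search, the aggregation gets wrapped in a filter bucket that
# re-applies all post-aggregation criteria except the one on the aggregated
# quantity itself. The resulting request body roughly looks like (illustration):
#   "aggs": {"agg:<name>:filtered": {"filter": {...remaining criteria...},
#            "aggs": {"agg:<name>": {...the actual aggregation...}}}}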
if post_agg_queries:
filter = post_agg_queries
if isinstance(agg, QuantityAggregation) and agg.exclude_from_search:
filter = {name: query for name, query in post_agg_queries.items() if name != agg.quantity}
es_aggs = es_aggs.bucket(f'{agg_name}:filtered', A('filter', filter=_create_es_must(filter)))
if isinstance(agg, StatisticsAggregation):
for metric_name in agg.metrics:
metrics = doc_type.metrics
......@@ -620,13 +652,11 @@ def _api_to_es_aggregation(
return
agg = cast(QuantityAggregation, agg)
longest_nested_key = None
quantity = validate_quantity(agg.quantity, doc_type=doc_type, loc=['aggregation', 'quantity'])
for nested_key in doc_type.nested_object_keys:
if agg.quantity.startswith(nested_key):
es_aggs = es_search.aggs.bucket('nested_agg:%s' % name, 'nested', path=nested_key)
es_aggs = es_aggs.bucket('nested_agg:%s' % name, 'nested', path=nested_key)
longest_nested_key = nested_key
es_agg = None
......@@ -674,6 +704,11 @@ def _api_to_es_aggregation(
}
if page_after_value is not None:
if post_agg_queries:
raise QueryValidationError(
'aggregation page_after_value cannot be used with exclude_from_search in the same request',
loc=['aggregations', name, 'terms', 'pagination', 'page_after_value'])
if order_quantity is None:
composite['after'] = {name: page_after_value}
else:
......@@ -770,6 +805,11 @@ def _es_to_api_aggregation(
the given aggregation.
'''
es_aggs = es_response.aggs
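# If the aggregation was wrapped in an exclude_from_search filter bucket
# (see _api_to_es_aggregation), unwrap it so the rest of the conversion
# can proceed unchanged.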
filtered_agg_name = f'agg:{name}:filtered'
if filtered_agg_name in es_response.aggs:
es_aggs = es_aggs[f'agg:{name}:filtered']
aggregation_dict = agg.dict(by_alias=True)
if isinstance(agg, StatisticsAggregation):
......@@ -785,7 +825,7 @@ def _es_to_api_aggregation(
longest_nested_key = None
for nested_key in doc_type.nested_object_keys:
if agg.quantity.startswith(nested_key):
es_aggs = es_response.aggs[f'nested_agg:{name}']
es_aggs = es_aggs[f'nested_agg:{name}']
longest_nested_key = nested_key
has_no_pagination = getattr(agg, 'pagination', None) is None
......@@ -907,22 +947,24 @@ def search(
doc_type = index.doc_type
# owner and query
# owner
owner_query = _owner_es_query(owner=owner, user_id=user_id, doc_type=doc_type)
# query
if query is None:
query = {}
es_query_dict: Dict[str, EsQuery] = {}
if isinstance(query, EsQuery):
es_query = cast(EsQuery, query)
else:
es_query = validate_api_query(
cast(Query, query), doc_type=doc_type, owner_query=owner_query)
cast(Query, query), doc_type=doc_type, owner_query=owner_query,
results_dict=es_query_dict)
if doc_type != entry_type:
es_query &= Q('nested', path='entries', query=owner_query)
else:
es_query &= owner_query
owner_query = Q('nested', path='entries', query=owner_query)
es_query &= owner_query
# pagination
if pagination is None:
......@@ -933,7 +975,6 @@ def search(
search = Search(index=index.index_name)
search = search.query(es_query)
# TODO this depends on doc_type
if pagination.order_by is None:
pagination.order_by = doc_type.id_field
......@@ -974,15 +1015,44 @@ def search(
search = search.source(includes=required.include, excludes=required.exclude)
# aggregations
for name, agg in aggregations.items():
_api_to_es_aggregation(search, name, _specific_agg(agg), doc_type=doc_type)
aggs = [(name, _specific_agg(agg)) for name, agg in aggregations.items()]
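# Criteria on quantities of aggregations marked exclude_from_search are applied
# as a post_filter: hits and the pagination total are still restricted, but the
# criteria are applied after aggregations are computed. Each aggregation then
# re-applies them selectively via the filter buckets created in
# _api_to_es_aggregation, so it sees all values of its own quantity while still
# honoring the remaining criteria.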
post_agg_queries: Dict[str, EsQuery] = {}
excluded_agg_quantities = {
agg.quantity
for _, agg in aggs
if isinstance(agg, QuantityAggregation) and agg.exclude_from_search}
if len(excluded_agg_quantities) > 0:
if not isinstance(query, dict):
# "exclude_from_search" only work for toplevel mapping queries
raise QueryValidationError(
f'the query has to be a dictionary if there is an aggregation with exclude_from_search',
loc=['query'])
pre_agg_queries = {
quantity: es_query
for quantity, es_query in es_query_dict.items()
if quantity not in excluded_agg_quantities}
post_agg_queries = {
quantity: es_query
for quantity, es_query in es_query_dict.items()
if quantity in excluded_agg_quantities}
search = search.post_filter(_create_es_must(post_agg_queries))
search = search.query(_create_es_must(pre_agg_queries) & owner_query)
else:
search = search.query(es_query) # pylint: disable=no-member
for name, agg in aggs:
_api_to_es_aggregation(
search, name, agg, doc_type=doc_type, post_agg_queries=post_agg_queries)
# execute
try:
es_response = search.execute()
except RequestError as e:
raise SearchError(e)
more_response_data = {}
# pagination
......
......@@ -333,6 +333,134 @@ def aggregation_test_parameters(entity_id: str, material_prefix: str, entry_pref
]
def aggregation_exclude_from_search_test_parameters(entry_prefix: str, total_per_entity: int, total: int):
entry_id = f'{entry_prefix}entry_id'
upload_id = f'{entry_prefix}upload_id'
program_name = f'{entry_prefix}results.method.simulation.program_name'
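# Each pytest.param provides: the API query, the list of terms aggregations,
# the expected number of buckets per aggregation (agg_lengths), the expected
# pagination total, and the expected HTTP status code.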
return [
pytest.param(
{
f'{entry_id}:any': ['id_01']
},
[
{
'exclude_from_search': True,
'quantity': entry_id
}
],
[10], 1, 200,
id='exclude'
),
pytest.param(
{
f'{entry_id}:any': ['id_01']
},
[
{
'exclude_from_search': False,
'quantity': entry_id
}
],
[total_per_entity], 1, 200,
id='dont-exclude'
),
pytest.param(
{
f'{entry_id}:any': ['id_01'],
upload_id: 'id_published',
program_name: 'VASP'
},
[
{
'exclude_from_search': True,
'quantity': entry_id
},
{
'exclude_from_search': True,
'quantity': upload_id
}
],
[10, 1], 1, 200,
id='two-aggs'
),
pytest.param(
{
f'{entry_id}:any': ['id_01']
},
[
{
'exclude_from_search': True,
'quantity': entry_id
},
{
'exclude_from_search': False,
'quantity': entry_id
}
],
[10, total_per_entity], 1, 200,
id='two-aggs-same-quantity'
),
pytest.param(
{},
[
{
'exclude_from_search': True,
'quantity': entry_id
}
],
[10], total, 200,
id='not-in-query'
),
pytest.param(
{},
[
{
'exclude_from_search': True,
'quantity': entry_id,
'pagination': {
'page_size': 20
}
}
],
[20], total, 200,
id='with-pagination'
),
pytest.param(
{
'or': [{entry_id: 'id_01'}]
},
[
{
'exclude_from_search': True,
'quantity': entry_id
}
],
[0], 0, 422,
id='non-dict-query'
),
pytest.param(
{
f'{entry_id}:any': ['id_01']
},
[
{
'exclude_from_search': True,
'quantity': entry_id
},
{
'quantity': entry_id,
'pagination': {
'page_after_value': 'id_published'
}
}
],
[0], 0, 422,
id='with-page-after-value'
)
]
def assert_response(response, status_code=None):
''' General assertions for status_code and error messages '''
if status_code and response.status_code != status_code:
......
......@@ -29,8 +29,8 @@ from tests.utils import ExampleData
from tests.test_files import example_mainfile_contents, append_raw_files # pylint: disable=unused-import
from .common import (
assert_response, assert_base_metadata_response, assert_metadata_response,
assert_required, assert_aggregations, assert_pagination,
aggregation_exclude_from_search_test_parameters, assert_response, assert_base_metadata_response,
assert_metadata_response, assert_required, assert_aggregations, assert_pagination,
perform_metadata_test, post_query_test_parameters, get_query_test_parameters,
perform_owner_test, owner_test_parameters, pagination_test_parameters,
aggregation_test_parameters)
......@@ -368,6 +368,27 @@ def test_entries_aggregations(client, data, test_user_auth, aggregation, total,
default_key='entry_id')
@pytest.mark.parametrize(
'query,aggs,agg_lengths,total,status_code',
aggregation_exclude_from_search_test_parameters(entry_prefix='', total_per_entity=1, total=23))
def test_entries_aggregations_exclude_from_search(client, data, query, aggs, agg_lengths, total, status_code):
aggs = {f'agg_{i}': {'terms': agg} for i, agg in enumerate(aggs)}
response_json = perform_entries_metadata_test(
client, owner='visible',
query=query, aggregations=aggs,
pagination=dict(page_size=0),
status_code=status_code, http_method='post')
if response_json is None:
return
assert response_json['pagination']['total'] == total
for i, length in enumerate(agg_lengths):
response_agg = response_json['aggregations'][f'agg_{i}']['terms']
assert len(response_agg['data']) == length
@pytest.mark.parametrize('required, status_code', [
pytest.param({'include': ['entry_id', 'upload_id']}, 200, id='include'),
pytest.param({'include': ['results.*', 'upload_id']}, 200, id='include-section'),
......
......@@ -24,7 +24,7 @@ from nomad.metainfo.elasticsearch_extension import material_entry_type
from tests.test_files import example_mainfile_contents # pylint: disable=unused-import
from .common import (
assert_pagination, assert_metadata_response, assert_required, assert_aggregations,
aggregation_exclude_from_search_test_parameters, assert_pagination, assert_metadata_response, assert_required, assert_aggregations,
perform_metadata_test, perform_owner_test, owner_test_parameters,
post_query_test_parameters, get_query_test_parameters, pagination_test_parameters,
aggregation_test_parameters)
......@@ -74,6 +74,27 @@ def test_materials_aggregations(client, data, test_user_auth, aggregation, total
default_key='material_id')
@pytest.mark.parametrize(
'query,aggs,agg_lengths,total,status_code',
aggregation_exclude_from_search_test_parameters(entry_prefix='entries.', total_per_entity=3, total=6))
def test_materials_aggregations_exclude_from_search(client, data, query, aggs, agg_lengths, total, status_code):
aggs = {f'agg_{i}': {'terms': agg} for i, agg in enumerate(aggs)}
response_json = perform_materials_metadata_test(
client, owner='visible',
query=query, aggregations=aggs,
pagination=dict(page_size=0),
status_code=status_code, http_method='post')
if response_json is None:
return
assert response_json['pagination']['total'] == total
for i, length in enumerate(agg_lengths):
response_agg = response_json['aggregations'][f'agg_{i}']['terms']
assert len(response_agg['data']) == length
@pytest.mark.parametrize('required, status_code', [
pytest.param({'include': ['material_id', program_name]}, 200, id='include'),
pytest.param({'include': ['entries.*', program_name]}, 200, id='include-section'),
......