diff --git a/nomad/app/v1/models.py b/nomad/app/v1/models.py index b1f7c9453f28f7ec4198ca0cb8d1e91a56ebd2bc..a9f5ed467256dacbfbb99b26cbdde7155e6287a2 100644 --- a/nomad/app/v1/models.py +++ b/nomad/app/v1/models.py @@ -114,9 +114,10 @@ class Any_(NoneEmptyBaseModel, extra=Extra.forbid): class Range(BaseModel, extra=Extra.forbid): - """Represents a finite range which can have open or closed ends. Supports + ''' + Represents a finite range which can have open or closed ends. Supports several datatypes that have a well-defined comparison operator. - """ + ''' @root_validator def check_range_is_valid(cls, values): # pylint: disable=no-self-argument lt = values.get('lt') @@ -145,6 +146,19 @@ class Range(BaseModel, extra=Extra.forbid): gte: Optional[ComparableValue] = Field(None) +ops = { + 'lte': Range, + 'lt': Range, + 'gte': Range, + 'gt': Range, + 'all': All, + 'none': None_, + 'any': Any_ +} + +CriteriaValue = Union[Value, List[Value], Range, Any_, All, None_, Dict[str, Any]] + + class LogicalOperator(NoneEmptyBaseModel): @validator('op', check_fields=False) def validate_query(cls, query): # pylint: disable=no-self-argument @@ -167,6 +181,7 @@ class Not(LogicalOperator): class Nested(BaseModel): + prefix: str query: 'Query' @validator('query') @@ -174,19 +189,22 @@ class Nested(BaseModel): return _validate_query(query) -ops = { - 'lte': Range, - 'lt': Range, - 'gte': Range, - 'gt': Range, - 'all': All, - 'none': None_, - 'any': Any_ -} +class Criteria(BaseModel, extra=Extra.forbid): + name: str + value: CriteriaValue + + @validator('value') + def validate_query(cls, value, values): # pylint: disable=no-self-argument + name, value = _validate_criteria_value(values['name'], value) + values['name'] = name + return value + + +class Empty(BaseModel, extra=Extra.forbid): + pass -QueryParameterValue = Union[Value, List[Value], Range, Any_, All, None_, Nested, Dict[str, Any]] -Query = Union[And, Or, Not, Mapping[str, QueryParameterValue]] +Query = Union[And, Or, Not, Nested, Criteria, Empty, Mapping[str, CriteriaValue]] And.update_forward_refs() @@ -297,28 +315,29 @@ class WithQuery(BaseModel): return _validate_query(query) +def _validate_criteria_value(name: str, value: CriteriaValue): + if ':' in name: + quantity, qualifier = name.split(':') + else: + quantity, qualifier = name, None + + if qualifier is not None: + assert qualifier in ops, 'unknown quantity qualifier %s' % qualifier + return quantity, ops[qualifier](**{qualifier: value}) # type: ignore + elif isinstance(value, list): + return quantity, All(all=value) + else: + return quantity, value + + def _validate_query(query: Query): if isinstance(query, dict): for key, value in list(query.items()): - # Note, we loop over a list of items, not query.items(). This is because we - # may modify the query in the loop. - if isinstance(value, dict): - value = Nested(query=value) - - if ':' in key: - quantity, qualifier = key.split(':') - else: - quantity, qualifier = key, None - - if qualifier is not None: + quantity, value = _validate_criteria_value(key, value) + if quantity != key: assert quantity not in query, 'a quantity can only appear once in a query' - assert qualifier in ops, 'unknown quantity qualifier %s' % qualifier del(query[key]) - query[quantity] = ops[qualifier](**{qualifier: value}) # type: ignore - elif isinstance(value, list): - query[quantity] = All(all=value) - else: - query[quantity] = value + query[quantity] = value return query @@ -827,8 +846,9 @@ class TermsAggregation(BucketAggregation): ''')) size: Optional[pydantic.conint(gt=0)] = Field( # type: ignore None, description=strip(''' - Only the data few values are returned for each API call. Pagination allows to - get the next set of values based on the last value in the last call. + The ammount of aggregation values is limited. This allows you to configure the + maximum number of aggregated values to return. If you need to exaust all + possible value, use `pagination`. ''')) value_filter: Optional[pydantic.constr(regex=r'^[a-zA-Z0-9_\-\s]+$')] = Field( # type: ignore None, description=strip(''' diff --git a/nomad/search.py b/nomad/search.py index a7e525f2d17b5f3b89fc29ed0cd059abc9ac67bb..aaf11d835e493e67c403d04edae417b6465b8089 100644 --- a/nomad/search.py +++ b/nomad/search.py @@ -32,7 +32,7 @@ update the v1 materials index according to the performed changes. TODO this is o partially implemented. ''' -from typing import Union, List, Iterable, Any, cast, Dict, Iterator, Generator +from typing import Union, List, Iterable, Any, cast, Dict, Iterator, Generator, Callable import json import elasticsearch from elasticsearch.exceptions import TransportError, RequestError @@ -42,10 +42,10 @@ from pydantic.error_wrappers import ErrorWrapper from nomad import config, infrastructure, utils from nomad import datamodel +from nomad.app.v1 import models from nomad.datamodel import EntryArchive, EntryMetadata -from nomad.app.v1 import models as api_models from nomad.app.v1.models import ( - AggregationPagination, MetadataPagination, Pagination, PaginationResponse, + AggregationPagination, Criteria, MetadataPagination, Pagination, PaginationResponse, QuantityAggregation, Query, MetadataRequired, MetadataResponse, Aggregation, StatisticsAggregation, StatisticsAggregationResponse, Value, AggregationBase, TermsAggregation, BucketAggregation, HistogramAggregation, @@ -283,7 +283,7 @@ def _es_to_entry_dict(hit, required: MetadataRequired = None) -> Dict[str, Any]: return entry_dict -def _api_to_es_query(query: api_models.Query) -> Q: +def _api_to_es_query(query: models.Query) -> Q: ''' Creates an ES query based on the API's query model. This needs to be a normalized query expression with explicit objects for logical, set, and comparison operators. @@ -291,29 +291,29 @@ def _api_to_es_query(query: api_models.Query) -> Q: needs to be resolved via the respective pydantic validator. There is also no validation of quantities and types. ''' - def quantity_to_es(name: str, value: api_models.Value) -> Q: + def quantity_to_es(name: str, value: models.Value) -> Q: # TODO depends on keyword or not, value might need normalization, etc. quantity = entry_type.quantities[name] return Q('match', **{quantity.search_field: value}) - def parameter_to_es(name: str, value: api_models.QueryParameterValue) -> Q: + def parameter_to_es(name: str, value: models.CriteriaValue) -> Q: - if isinstance(value, api_models.All): + if isinstance(value, models.All): return Q('bool', must=[ quantity_to_es(name, item) for item in value.op]) - if isinstance(value, api_models.Any_): + if isinstance(value, models.Any_): return Q('bool', should=[ quantity_to_es(name, item) for item in value.op]) - if isinstance(value, api_models.None_): + if isinstance(value, models.None_): return Q('bool', must_not=[ quantity_to_es(name, item) for item in value.op]) - if isinstance(value, api_models.Range): + if isinstance(value, models.Range): quantity = entry_type.quantities[name] return Q('range', **{quantity.search_field: value.dict( exclude_unset=True, @@ -325,21 +325,24 @@ def _api_to_es_query(query: api_models.Query) -> Q: quantity_to_es(name, item) for item in value]) - return quantity_to_es(name, cast(api_models.Value, value)) + return quantity_to_es(name, cast(models.Value, value)) - def query_to_es(query: api_models.Query) -> Q: - if isinstance(query, api_models.LogicalOperator): - if isinstance(query, api_models.And): + def query_to_es(query: models.Query) -> Q: + if isinstance(query, models.LogicalOperator): + if isinstance(query, models.And): return Q('bool', must=[query_to_es(operand) for operand in query.op]) - if isinstance(query, api_models.Or): + if isinstance(query, models.Or): return Q('bool', should=[query_to_es(operand) for operand in query.op]) - if isinstance(query, api_models.Not): + if isinstance(query, models.Not): return Q('bool', must_not=query_to_es(query.op)) raise NotImplementedError() + if isinstance(query, models.Empty): + return Q() + if not isinstance(query, dict): raise NotImplementedError() @@ -437,25 +440,127 @@ def validate_quantity( return quantity -def _create_es_must(queries: Dict[str, EsQuery]): - # dictionary is like an "and" of all items in the dict - if len(queries) == 0: - return Q() +def normalize_api_query(query: Query, doc_type: DocumentType, prefix: str = None) -> Query: + ''' + Normalizes the given query. Should be applied before validate_api_query, which + expects a normalized query. Normalization will + - replace nested dicts with`models.And` or `models.Nested` instances + - introduce `models.Nested` if necessary + - replace dicts with `models.And` queries. + + After normalization there should be no dicts or `*:(any|all|none)` values in the query. + ''' + def normalize_criteria(name, value: models.CriteriaValue, prefix: str) -> Query: + if prefix is not None: + full_name = f'{prefix}.{name}' + else: + full_name = name + + prefixes = [] + name_wo_prefix = name + nested_prefix = None + for nested_key in doc_type.nested_object_keys: + if nested_key == prefix: + continue + + if full_name.startswith(f'{nested_key}'): + if prefix is None or not prefix.startswith(nested_key): + prefixes.append(nested_key) + name_wo_prefix = full_name[len(nested_key) + 1:] + nested_prefix = nested_key + + if full_name == nested_key: + break + name = name_wo_prefix + + query: Query = None + if isinstance(value, dict): + query = models.And(**{'and': [ + normalize_criteria(k if name == '' else f'{name}.{k}', v, nested_prefix) + for k, v in value.items()]}) + + else: + query = Criteria(name=name, value=value) + + for prefix in reversed(prefixes): + query = models.Nested(prefix=prefix, query=query) + + return query + + def normalize_query(query: Query): + return normalize_api_query(query, doc_type=doc_type, prefix=prefix) + + if isinstance(query, dict): + if len(query) is None: + return models.Empty() + + if len(query) == 1: + name = next(iter(query)) + return normalize_criteria(name, query[name], prefix) + + return models.And(**{'and': [ + normalize_criteria(name, value, prefix) for name, value in query.items()]}) + + if isinstance(query, models.And): + return models.And(**{'and': [normalize_query(op) for op in query.op]}) + + if isinstance(query, models.Or): + return models.Or(**{'or': [normalize_query(op) for op in query.op]}) + + if isinstance(query, models.Not): + return models.Not(**{'not': normalize_query(query.op)}) + + if isinstance(query, models.Nested): + return models.Nested( + prefix=query.prefix, + query=normalize_api_query(query, doc_type=doc_type, prefix=query.prefix)) + + if isinstance(query, (models.Empty, models.Criteria)): + return query + + raise NotImplementedError(f'Query type {query.__class__} is not supported') - if len(queries) == 1: - return list(queries.values())[0] - return Q('bool', must=list(queries.values())) +def remove_quantity_from_query(query: Query, quantity: str, prefix=None): + ''' + Removes all criteria with the given quantity from the query. Query has to be + normalized. Remove is done by replacing respective criteria with an empty query. + ''' + + if isinstance(query, models.And): + return models.And(**{'and': [remove_quantity_from_query(op, quantity, prefix) for op in query.op]}) + + if isinstance(query, models.Or): + return models.Or(**{'or': [remove_quantity_from_query(op, quantity, prefix) for op in query.op]}) + + if isinstance(query, models.Not): + return models.Not(**{'not': remove_quantity_from_query(query.op, quantity, prefix)}) + + if isinstance(query, models.Nested): + return models.Nested( + prefix=query.prefix, + query=remove_quantity_from_query(query.query, quantity, prefix=query.prefix)) + + if isinstance(query, models.Empty): + return query + + if isinstance(query, models.Criteria): + name = query.name + if prefix is not None: + name = f'{prefix}.{name}' + if name == quantity: + return models.Empty() + + return query + + raise NotImplementedError(f'Query type {query.__class__} is not supported') def validate_api_query( - query: Query, doc_type: DocumentType, owner_query: EsQuery, - prefix: str = None, results_dict: Dict[str, EsQuery] = None) -> EsQuery: + query: Query, doc_type: DocumentType, owner_query: EsQuery, prefix: str = None) -> EsQuery: ''' Creates an ES query based on the API's query model. This needs to be a normalized - query expression with explicit objects for logical, set, and comparison operators. - Shorthand notations ala ``quantity:operator`` are not supported here; this - needs to be resolved via the respective pydantic validator. + query. However, this function performs validation of quantities and types and raises a QueryValidationError accordingly. This exception is populated with pydantic @@ -471,11 +576,6 @@ def validate_api_query( materials queries. prefix: An optional prefix that is added to all quantity names. Used for recursion. - results_dict: - If an empty dictionary is given and the query is a mapping, the top-level - criteria from this mapping will be added as individual es queries. The - keys will be the mapping keys and values the respective es queries. A logical - and (or es "must") would result in the overall resulting es query. Returns: A elasticsearch dsl query object. @@ -484,6 +584,9 @@ def validate_api_query( ''' def match(name: str, value: Value) -> EsQuery: + if prefix is not None: + name = f'{prefix}.{name}' + if name == 'optimade_filter': value = str(value) from nomad.app.optimade import filterparser @@ -504,98 +607,60 @@ def validate_api_query( query, doc_type=doc_type, owner_query=owner_query, prefix=prefix) def validate_criteria(name: str, value: Any): - if prefix is not None: - name = f'{prefix}.{name}' - - # handle prefix and nested queries - for nested_key in doc_type.nested_object_keys: - if len(name) < len(nested_key): - break - if not name.startswith(nested_key): - continue - if prefix is not None and prefix.startswith(nested_key): - continue - if nested_key == name and isinstance(value, api_models.Nested): - continue - - value = api_models.Nested(query={name[len(nested_key) + 1:]: value}) - name = nested_key - break - - if isinstance(value, api_models.All): + if isinstance(value, models.All): return Q('bool', must=[match(name, item) for item in value.op]) - elif isinstance(value, api_models.Any_): + elif isinstance(value, models.Any_): return Q('bool', should=[match(name, item) for item in value.op]) - elif isinstance(value, api_models.None_): + elif isinstance(value, models.None_): return Q('bool', must_not=[match(name, item) for item in value.op]) - elif isinstance(value, api_models.Range): + elif isinstance(value, models.Range): quantity = validate_quantity(name, None, doc_type=doc_type) return Q('range', **{quantity.search_field: value.dict( exclude_unset=True, )}) - elif isinstance(value, (api_models.And, api_models.Or, api_models.Not)): + elif isinstance(value, (models.And, models.Or, models.Not)): return validate_query(value) - elif isinstance(value, api_models.Nested): - sub_doc_type = material_entry_type if name == 'entries' else doc_type - - sub_query = validate_api_query( - value.query, doc_type=sub_doc_type, prefix=name, owner_query=owner_query) - - if name in doc_type.nested_object_keys: - if name == 'entries': - sub_query &= owner_query - return Q('nested', path=name, query=sub_query) - else: - return sub_query - # list of values is treated as an "all" over the items elif isinstance(value, list): return Q('bool', must=[match(name, item) for item in value]) elif isinstance(value, dict): - assert False, ( - 'Using dictionaries as criteria values directly is not supported. Use the ' - 'Nested model.') + raise NotImplementedError() else: return match(name, value) - if isinstance(query, api_models.And): + if isinstance(query, models.And): return Q('bool', must=[validate_query(operand) for operand in query.op]) - if isinstance(query, api_models.Or): + if isinstance(query, models.Or): return Q('bool', should=[validate_query(operand) for operand in query.op]) - if isinstance(query, api_models.Not): + if isinstance(query, models.Not): return Q('bool', must_not=validate_query(query.op)) - if isinstance(query, dict): - # dictionary is like an "and" of all items in the dict - if len(query) == 0: - return Q() + if isinstance(query, models.Nested): + sub_doc_type = material_entry_type if query.prefix == 'entries' else doc_type + sub_query = validate_api_query( + query.query, doc_type=sub_doc_type, prefix=query.prefix, owner_query=owner_query) - if len(query) == 1: - name = next(iter(query)) - es_criteria_query = validate_criteria(name, query[name]) - if results_dict is not None: - results_dict[name] = es_criteria_query - return es_criteria_query + if query.prefix == 'entries': + sub_query &= owner_query - es_criteria_queries = [] - for name, value in query.items(): - es_criteria_query = validate_criteria(name, value) - es_criteria_queries.append(es_criteria_query) - if results_dict is not None: - results_dict[name] = es_criteria_query + return Q('nested', path=query.prefix, query=sub_query) - return Q('bool', must=es_criteria_queries) + if isinstance(query, models.Criteria): + return validate_criteria(query.name, query.value) - raise NotImplementedError() + if isinstance(query, models.Empty): + return Q() + + raise NotImplementedError(f'Query type {query.__class__} is not supported') def validate_pagination(pagination: Pagination, doc_type: DocumentType, loc: List[str] = None): @@ -621,7 +686,7 @@ def validate_pagination(pagination: Pagination, doc_type: DocumentType, loc: Lis def _api_to_es_aggregation( es_search: Search, name: str, agg: AggregationBase, doc_type: DocumentType, - post_agg_queries: Dict[str, EsQuery]) -> A: + post_agg_query: models.Query, create_es_query: Callable[[models.Query], EsQuery]) -> A: ''' Creates an ES aggregation based on the API's aggregation model. ''' @@ -629,11 +694,12 @@ def _api_to_es_aggregation( agg_name = f'agg:{name}' es_aggs = es_search.aggs - if post_agg_queries: - filter = post_agg_queries + if post_agg_query: if isinstance(agg, QuantityAggregation) and agg.exclude_from_search: - filter = {name: query for name, query in post_agg_queries.items() if name != agg.quantity} - es_aggs = es_aggs.bucket(f'{agg_name}:filtered', A('filter', filter=_create_es_must(filter))) + filter = create_es_query(remove_quantity_from_query(post_agg_query, agg.quantity)) + else: + filter = create_es_query(post_agg_query) + es_aggs = es_aggs.bucket(f'{agg_name}:filtered', A('filter', filter=filter)) if isinstance(agg, StatisticsAggregation): for metric_name in agg.metrics: @@ -670,6 +736,11 @@ def _api_to_es_aggregation( loc=['aggregation', name, 'terms', 'quantity']) if agg.pagination is not None: + if post_agg_query is not None: + raise QueryValidationError( + f'aggregation pagination cannot be used with exclude_from_search in the same request', + loc=['aggregations', name, 'terms', 'pagination']) + if agg.size is not None: raise QueryValidationError( f'You cannot paginate and provide an extra size parameter.', @@ -706,11 +777,6 @@ def _api_to_es_aggregation( } if page_after_value is not None: - if post_agg_queries: - raise QueryValidationError( - f'aggregation page_after_value cannot be used with exclude_from_search in the same request', - loc=['aggregations', name, 'terms', 'pagination', 'page_after_value']) - if order_quantity is None: composite['after'] = {name: page_after_value} else: @@ -965,17 +1031,19 @@ def search( if query is None: query = {} - es_query_dict: Dict[str, EsQuery] = {} + def create_es_query(query: Query): + return validate_api_query(cast(Query, query), doc_type=doc_type, owner_query=owner_query) + if isinstance(query, EsQuery): es_query = cast(EsQuery, query) else: - es_query = validate_api_query( - cast(Query, query), doc_type=doc_type, owner_query=owner_query, - results_dict=es_query_dict) + query = normalize_api_query(cast(Query, query), doc_type=doc_type) + es_query = create_es_query(cast(Query, query)) + nested_owner_query = owner_query if doc_type != entry_type: - owner_query = Q('nested', path='entries', query=owner_query) - es_query &= owner_query + nested_owner_query = Q('nested', path='entries', query=owner_query) + es_query &= nested_owner_query # pagination if pagination is None: @@ -1032,37 +1100,28 @@ def search( # aggregations aggs = [(name, _specific_agg(agg)) for name, agg in aggregations.items()] - post_agg_queries: Dict[str, EsQuery] = {} excluded_agg_quantities = { agg.quantity for _, agg in aggs if isinstance(agg, QuantityAggregation) and agg.exclude_from_search} if len(excluded_agg_quantities) > 0: - if not isinstance(query, dict): - # "exclude_from_search" only work for toplevel mapping queries - raise QueryValidationError( - f'the query has to be a dictionary if there is an aggregation with exclude_from_search', - loc=['query']) - - pre_agg_queries = { - quantity: es_query - for quantity, es_query in es_query_dict.items() - if quantity not in excluded_agg_quantities} - post_agg_queries = { - quantity: es_query - for quantity, es_query in es_query_dict.items() - if quantity in excluded_agg_quantities} + # TODO optimize and try to put a few criteria into pre_agg_query + pre_agg_es_query = Q() + post_agg_es_query = es_query + post_agg_query = query - search = search.post_filter(_create_es_must(post_agg_queries)) - search = search.query(_create_es_must(pre_agg_queries) & owner_query) + search = search.post_filter(post_agg_es_query) + search = search.query(pre_agg_es_query & nested_owner_query) else: search = search.query(es_query) # pylint: disable=no-member + post_agg_query = None for name, agg in aggs: _api_to_es_aggregation( - search, name, agg, doc_type=doc_type, post_agg_queries=post_agg_queries) + search, name, agg, doc_type=doc_type, + post_agg_query=post_agg_query, create_es_query=create_es_query) # execute try: diff --git a/tests/app/v1/routers/common.py b/tests/app/v1/routers/common.py index 0916c7b6fe3334dfdca91de16b4d47bb0021c85e..dbd528ee6d38f5a2d0aa013f9e18909af74f0b04 100644 --- a/tests/app/v1/routers/common.py +++ b/tests/app/v1/routers/common.py @@ -63,9 +63,9 @@ def post_query_test_parameters( pytest.param({'and': [{f'{entity_id}:any': ['id_01', 'id_02']}, {f'{entity_id}:any': ['id_02', 'id_03']}]}, 200, 1, id='and-nested-any'), pytest.param({'and': [{'not': {entity_id: 'id_01'}}, {'not': {entity_id: 'id_02'}}]}, 200, total - 2, id='not-nested-not'), pytest.param({method: {'simulation.program_name': 'VASP'}}, 200, total, id='inner-object'), - pytest.param({f'{properties}.electronic.dos_electronic.spin_polarized': True}, 200, 1, id='nested-implicit'), - pytest.param({f'{properties}.electronic.dos_electronic': {'spin_polarized': True}}, 200, 1, id='nested-explicit'), - pytest.param({properties: {'electronic.dos_electronic': {'spin_polarized': True}}}, 200, 1, id='nested-explicit-explicit'), + pytest.param({f'{properties}.electronic.dos_electronic.band_gap.type': 'direct'}, 200, 1, id='nested-implicit'), + pytest.param({f'{properties}.electronic.dos_electronic.band_gap': {'type': 'direct'}}, 200, 1, id='nested-explicit'), + pytest.param({properties: {'electronic.dos_electronic.band_gap': {'type': 'direct'}}}, 200, 1, id='nested-explicit-explicit'), pytest.param({f'{upload_create_time}:gt': '1970-01-01'}, 200, total, id='date-1'), pytest.param({f'{upload_create_time}:lt': '2099-01-01'}, 200, total, id='date-2'), pytest.param({f'{upload_create_time}:gt': '2099-01-01'}, 200, 0, id='date-3') @@ -339,6 +339,15 @@ def aggregation_exclude_from_search_test_parameters(entry_prefix: str, total_per program_name = f'{entry_prefix}results.method.simulation.program_name' return [ + pytest.param( + { + f'{entry_id}:any': ['id_01'], + upload_id: 'id_published', + program_name: 'VASP' + }, + [], [], 1, 200, + id='empty' + ), pytest.param( { f'{entry_id}:any': ['id_01'] @@ -423,40 +432,33 @@ def aggregation_exclude_from_search_test_parameters(entry_prefix: str, total_per } } ], - [20], total, 200, + [0], total, 422, id='with-pagination' ), pytest.param( - { - 'or': [{entry_id: 'id_01'}] - }, + {}, [ { 'exclude_from_search': True, - 'quantity': entry_id + 'quantity': entry_id, + 'size': 20 } ], - [0], 0, 422, - id='non-dict-query' + [20], total, 200, + id='with-size' ), pytest.param( { - f'{entry_id}:any': ['id_01'] + 'or': [{entry_id: 'id_01'}, {entry_id: 'id_05'}] }, [ { 'exclude_from_search': True, 'quantity': entry_id - }, - { - 'quantity': entry_id, - 'pagination': { - 'page_after_value': 'id_published' - } } ], - [0], 0, 422, - id='with-page-after-value' + [10], 2, 200, + id='non-dict-query' ) ] diff --git a/tests/app/v1/routers/test_entries.py b/tests/app/v1/routers/test_entries.py index b8dd813f4a4ad5b399e6d981b5d17d1cfa141065..6b13e9473b1be94a618c5799d96d9fdd7534b8ca 100644 --- a/tests/app/v1/routers/test_entries.py +++ b/tests/app/v1/routers/test_entries.py @@ -386,7 +386,7 @@ def test_entries_aggregations_exclude_from_search(client, data, query, aggs, agg assert response_json['pagination']['total'] == total for i, length in enumerate(agg_lengths): response_agg = response_json['aggregations'][f'agg_{i}']['terms'] - assert len(response_agg['data']) == length + assert len(response_agg['data']) == length @pytest.mark.parametrize('required, status_code', [ diff --git a/tests/app/v1/routers/test_materials.py b/tests/app/v1/routers/test_materials.py index 9014d311568a00e7200798da9a2e863eaa737690..3d8ec03644b59c8dab2d29a106e5a5dbb3ec3a62 100644 --- a/tests/app/v1/routers/test_materials.py +++ b/tests/app/v1/routers/test_materials.py @@ -97,7 +97,7 @@ def test_materials_aggregations_exclude_from_search(client, data, query, aggs, a assert response_json['pagination']['total'] == total for i, length in enumerate(agg_lengths): response_agg = response_json['aggregations'][f'agg_{i}']['terms'] - assert len(response_agg['data']) == length + assert len(response_agg['data']) == length @pytest.mark.parametrize('required, status_code', [ diff --git a/tests/utils.py b/tests/utils.py index 75d46380ba0146279aedc8902c9b4a7eeac34a45..294db89570b825fa427e7d4167347229dbca53f8 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -298,7 +298,12 @@ class ExampleData: 'n_calculations': 1, 'electronic': { 'dos_electronic': { - 'spin_polarized': entry_id.endswith('04') + 'spin_polarized': entry_id.endswith('04'), + 'band_gap': [ + { + 'type': 'direct' if entry_id.endswith('04') else 'indirect' + } + ] } } }