diff --git a/nomad/app/api/common.py b/nomad/app/api/common.py index fa6bfd70148ae4691fd1fd848a6598660510dab6..1471e69e5df3be4eb765111d5f40c6bedd22977c 100644 --- a/nomad/app/api/common.py +++ b/nomad/app/api/common.py @@ -25,6 +25,7 @@ import sys import os.path from nomad import search, config +from nomad.datamodel import Domain from nomad.app.optimade import filterparser from nomad.app.common import RFC3339DateTime, rfc3339DateTime from nomad.files import Restricted @@ -107,12 +108,16 @@ def add_search_parameters(request_parser): help='A yyyy-MM-ddTHH:mm:ss (RFC3339) maximum entry time (e.g. upload time)') # main search parameters - for quantity in search.quantities.values(): + for quantity in Domain.all_quantities(): request_parser.add_argument( - quantity.name, help=quantity.description, + quantity.qualified_name, help=quantity.description, action=quantity.argparse_action if quantity.multi else None) +_search_quantities = set([ + domain.qualified_name for domain in Domain.all_quantities()]) + + def apply_search_parameters(search_request: search.SearchRequest, args: Dict[str, Any]): """ Help that adds query relevant request args to the given SearchRequest. @@ -153,7 +158,7 @@ def apply_search_parameters(search_request: search.SearchRequest, args: Dict[str # search parameter search_request.search_parameters(**{ key: value for key, value in args.items() - if key not in ['optimade'] and key in search.quantities}) + if key in _search_quantities}) def calc_route(ns, prefix: str = ''): diff --git a/nomad/app/api/repo.py b/nomad/app/api/repo.py index b5f5df82248c0dcb5b64dc541088fe6ac9799835..e48d8d12e59404d2b9fe0c6363956caa6a43a4ce 100644 --- a/nomad/app/api/repo.py +++ b/nomad/app/api/repo.py @@ -34,7 +34,7 @@ from .api import api from .auth import authenticate from .common import search_model, calc_route, add_pagination_parameters,\ add_scroll_parameters, add_search_parameters, apply_search_parameters,\ - query_api_python, query_api_curl + query_api_python, query_api_curl, _search_quantities ns = api.namespace('repo', description='Access repository metadata.') @@ -264,7 +264,7 @@ _query_model_parameters = { 'until_time': RFC3339DateTime(description='A yyyy-MM-ddTHH:mm:ss (RFC3339) maximum entry time (e.g. upload time)') } -for quantity in search.quantities.values(): +for quantity in datamodel.Domain.all_quantities(): if quantity.multi and quantity.argparse_action is None: def field(**kwargs): return fields.List(fields.String(**kwargs)) @@ -379,9 +379,9 @@ class EditRepoCalcsResource(Resource): # preparing the query of entries that are edited parsed_query = {} - for quantity_name, quantity in search.quantities.items(): - if quantity_name in query: - value = query[quantity_name] + for quantity_name, value in query.items(): + if quantity_name in _search_quantities: + quantity = datamodel.Domain.get_quantity(quantity_name) if quantity.multi and quantity.argparse_action == 'split' and not isinstance(value, list): value = value.split(',') parsed_query[quantity_name] = value diff --git a/nomad/cli/client/statistics.py b/nomad/cli/client/statistics.py index 4189e6210b3d8a81d2ee4e813424380c385dc4a5..f903deeb4dbc6af973c81931cc5e75f17e8ef030 100644 --- a/nomad/cli/client/statistics.py +++ b/nomad/cli/client/statistics.py @@ -405,12 +405,12 @@ def statistics_table(html, geometries, public_path): # search calcs quantities=section_k_band band_structures = get_statistic( - client.repo.search(per_page=1, quantities=['section_k_band']).response().result, + client.repo.search(per_page=1, **{'dft.quantities': ['section_k_band']}).response().result, 'total', 'all', 'code_runs') # search calcs quantities=section_dos dos = get_statistic( - client.repo.search(per_page=1, quantities=['section_dos']).response().result, + client.repo.search(per_page=1, **{'dft.quantities': ['section_dos']}).response().result, 'total', 'all', 'code_runs') phonons = get_statistic( diff --git a/nomad/datamodel/base.py b/nomad/datamodel/base.py index 75918c6c0c4166904ac17381e955b0bfe4cbd8a8..35eed394692d08724572e7b55d764d756d061a33 100644 --- a/nomad/datamodel/base.py +++ b/nomad/datamodel/base.py @@ -244,10 +244,7 @@ class DomainQuantity: @property def name(self) -> str: - if self.domain is not None: - return '%s.%s' % (self.domain, self._name) - else: - return self._name + return self._name @name.setter def name(self, name: str) -> None: @@ -257,6 +254,20 @@ class DomainQuantity: if self.elastic_field is None: self.elastic_field = self.name + @property + def qualified_elastic_field(self) -> str: + if self.domain is None: + return self.elastic_field + else: + return '%s.%s' % (self.domain, self.elastic_field) + + @property + def qualified_name(self) -> str: + if self.domain is None: + return self.name + else: + return '%s.%s' % (self.domain, self.name) + class Domain: """ @@ -354,6 +365,25 @@ class Domain: datasets=('dataset_id', 'datasets'), uploads=('upload_id', 'uploads')) + @classmethod + def get_quantity(cls, name_spec) -> DomainQuantity: + """ + Returns the quantity definition for the given quantity name. The name can be the + qualified name (``domain.quantity``) or in Django-style (``domain__quantity``). + """ + qualified_name = name_spec.replace('__', '.') + split_name = qualified_name.split('.') + if len(split_name) == 1: + return cls.base_quantities[split_name[0]] + elif len(split_name) == 2: + return cls.instances[split_name[0]].quantities[split_name[1]] + else: + assert False, 'qualified quantity name depth must be 2 max' + + @classmethod + def all_quantities(cls) -> Iterable[DomainQuantity]: + return set([quantity for domain in cls.instances.values() for quantity in domain.quantities.values()]) + def __init__( self, name: str, domain_entry_class: Type[CalcWithMetadata], quantities: Dict[str, DomainQuantity], @@ -446,7 +476,7 @@ class Domain: def order_default_quantity(self) -> str: for quantity in self.quantities.values(): if quantity.order_default: - return quantity.name + return quantity.qualified_name assert False, 'each domain must defina an order_default quantity' diff --git a/nomad/search.py b/nomad/search.py index 5714f84801a47bae8dccbb8e7538c958850fce60..d0623c1181ca1e621d969322cb26f528cfcd951e 100644 --- a/nomad/search.py +++ b/nomad/search.py @@ -26,6 +26,7 @@ from datetime import datetime import json from nomad import config, datamodel, infrastructure, datamodel, utils, processing as proc +from nomad.datamodel import Domain path_analyzer = analyzer( @@ -69,12 +70,23 @@ class Dataset(InnerDoc): name = Keyword() +_domain_inner_doc_types: Dict[str, type] = {} + + class WithDomain(IndexMeta): """ Override elasticsearch_dsl metaclass to sneak in domain specific mappings """ def __new__(cls, name, bases, attrs): - for domain in datamodel.Domain.instances.values(): - for quantity in domain.domain_quantities.values(): - attrs[quantity.elastic_field] = quantity.elastic_mapping + for domain in Domain.instances.values(): + inner_doc_type = _domain_inner_doc_types.get(domain.name) + if inner_doc_type is None: + domain_attrs = { + quantity.elastic_field: quantity.elastic_mapping + for quantity in domain.domain_quantities.values()} + + inner_doc_type = type(domain.name, (InnerDoc,), domain_attrs) + _domain_inner_doc_types[domain.name] = inner_doc_type + + attrs[domain.name] = Object(inner_doc_type) return super(WithDomain, cls).__new__(cls, name, bases, attrs) @@ -162,9 +174,13 @@ class Entry(Document, metaclass=WithDomain): self.external_id = source.external_id if self.domain is not None: - for quantity in datamodel.Domain.instances[self.domain].domain_quantities.values(): + inner_doc_type = _domain_inner_doc_types[self.domain] + inner_doc = inner_doc_type() + for quantity in Domain.instances[self.domain].domain_quantities.values(): quantity_value = quantity.elastic_value(getattr(source, quantity.metadata_field)) - setattr(self, quantity.name, quantity_value) + setattr(inner_doc, quantity.elastic_field, quantity_value) + + setattr(self, self.domain, inner_doc) def delete_upload(upload_id): @@ -221,15 +237,9 @@ def refresh(): infrastructure.elastic_client.indices.refresh(config.elastic.index_name) -quantities = { - quantity_name: quantity - for domain in datamodel.Domain.instances.values() - for quantity_name, quantity in domain.quantities.items()} -"""The available search quantities """ - metrics = { metric_name: metric - for domain in datamodel.Domain.instances.values() + for domain in Domain.instances.values() for metric_name, metric in domain.metrics.items()} """ The available search metrics. Metrics are integer values given for each entry that can @@ -237,23 +247,23 @@ be used in statistics (aggregations), e.g. the sum of all total energy calculati all unique geometries. """ -metrics_names = [metric_name for domain in datamodel.Domain.instances.values() for metric_name in domain.metrics_names] +metrics_names = [metric_name for domain in Domain.instances.values() for metric_name in domain.metrics_names] """ Names of all available metrics """ groups = { key: value - for domain in datamodel.Domain.instances.values() + for domain in Domain.instances.values() for key, value in domain.groups.items()} """The available groupable quantities""" order_default_quantities = { domain_name: domain.order_default_quantity - for domain_name, domain in datamodel.Domain.instances.items() + for domain_name, domain in Domain.instances.items() } default_statistics = { domain_name: domain.default_statistics - for domain_name, domain in datamodel.Domain.instances.items() + for domain_name, domain in Domain.instances.items() } @@ -353,9 +363,7 @@ class SearchRequest: return self def search_parameter(self, name, value): - quantity = quantities.get(name, None) - if quantity is None: - raise KeyError('Unknown quantity %s' % name) + quantity = Domain.get_quantity(name) if quantity.multi and not isinstance(value, list): value = [value] @@ -365,7 +373,7 @@ class SearchRequest: if quantity.elastic_search_type == 'terms': if not isinstance(value, list): value = [value] - self.q &= Q('terms', **{quantity.elastic_field: value}) + self.q &= Q('terms', **{quantity.qualified_elastic_field: value}) return self @@ -375,7 +383,7 @@ class SearchRequest: values = [value] for item in values: - self.q &= Q(quantity.elastic_search_type, **{quantity.elastic_field: item}) + self.q &= Q(quantity.elastic_search_type, **{quantity.qualified_elastic_field: item}) return self @@ -429,7 +437,7 @@ class SearchRequest: for name in default_statistics[self._domain]: self.statistic( name, - quantities[name].aggregations, + Domain.get_quantity(name).aggregations, metrics_to_use=metrics_to_use) return self @@ -459,8 +467,8 @@ class SearchRequest: ``unique_code_runs``, ``datasets``, other domain specific metrics. The basic doc_count metric ``code_runs`` is always given. """ - quantity = quantities[quantity_name] - terms = A('terms', field=quantity.elastic_field, size=size, order=dict(_key='asc')) + quantity = Domain.get_quantity(quantity_name) + terms = A('terms', field=quantity.qualified_elastic_field, size=size, order=dict(_key='asc')) buckets = self._search.aggs.bucket('statistics:%s' % quantity_name, terms) self._add_metrics(buckets, metrics_to_use) @@ -532,8 +540,8 @@ class SearchRequest: if size is None: size = 100 - quantity = quantities[name] - terms = A('terms', field=quantity.elastic_field) + quantity = Domain.get_quantity(name) + terms = A('terms', field=quantity.qualified_elastic_field) # We are using elastic searchs 'composite aggregations' here. We do not really # compose aggregations, but only those pseudo composites allow us to use the @@ -585,15 +593,12 @@ class SearchRequest: search = self._search.query(self.q) if order_by is not None: - if order_by not in quantities: - raise KeyError('Unknown order quantity %s' % order_by) - - order_by_quantity = quantities[order_by] + order_by_quantity = Domain.get_quantity(order_by) if order == 1: - search = search.sort(order_by_quantity.elastic_field) + search = search.sort(order_by_quantity.qualified_elastic_field) else: - search = search.sort('-%s' % order_by_quantity.elastic_field) + search = search.sort('-%s' % order_by_quantity.qualified_elastic_field) search = search.params(preserve_order=True) @@ -617,15 +622,12 @@ class SearchRequest: search = self._search.query(self.q) - if order_by not in quantities: - raise KeyError('Unknown order quantity %s' % order_by) - - order_by_quantity = quantities[order_by] + order_by_quantity = Domain.get_quantity(order_by) if order == 1: - search = search.sort(order_by_quantity.elastic_field) + search = search.sort(order_by_quantity.qualified_elastic_field) else: - search = search.sort('-%s' % order_by_quantity.elastic_field) + search = search.sort('-%s' % order_by_quantity.qualified_elastic_field) search = search[(page - 1) * per_page: page * per_page] result = self._response(search.execute(), with_hits=True) @@ -778,3 +780,23 @@ def to_calc_with_metadata(results: List[Dict[str, Any]]): return [ datamodel.CalcWithMetadata(**calc.metadata) for calc in proc.Calc.objects(calc_id__in=ids)] + + +def flat(obj, prefix=None): + """ + Helper that translates nested result objects into flattened dicts with + ``domain.quantity`` as keys. + """ + if isinstance(obj, dict): + result = {} + for key, value in obj.items(): + if isinstance(value, dict): + for child_key, child_value in value.items(): + result['%s.%s' % (key, child_key)] = flat(child_value) + + else: + result[key] = value + + return result + else: + return obj diff --git a/tests/app/test_api.py b/tests/app/test_api.py index 25e8751ff9a3847f8f98bdd8ab3626d15ee98323..74cb9fe1f0eb3c4c7980e4cdaecabdf9e67b1df6 100644 --- a/tests/app/test_api.py +++ b/tests/app/test_api.py @@ -224,7 +224,7 @@ class TestUploads: assert calc['current_task'] == 'archiving' assert len(calc['tasks']) == 3 - assert 'dft.atoms' in calc['metadata'] + assert 'dft.atoms' in search.flat(calc['metadata']) assert api.get('/archive/logs/%s/%s' % (calc['upload_id'], calc['calc_id']), headers=test_user_auth).status_code == 200 if upload['calcs']['pagination']['total'] > 1: @@ -814,10 +814,11 @@ class TestRepo(): auth = dict(none=None, test_user=test_user_auth, other_test_user=other_test_user_auth).get(auth) rv = api.get('/repo/?owner=%s' % owner, headers=auth) data = self.assert_search(rv, calcs) - results = data.get('results', None) if calcs > 0: - for key in ['uploader', 'calc_id', 'dft.formula', 'upload_id']: - assert key in results[0] + results = data.get('results', None) + result = search.flat(results[0]) + for key in ['uploader.name', 'calc_id', 'dft.formula', 'upload_id']: + assert key in result @pytest.mark.parametrize('calcs, start, end', [ (2, today - datetime.timedelta(days=6), today), @@ -887,7 +888,7 @@ class TestRepo(): def test_search_exclude(self, api, example_elastic_calcs, no_warn): rv = api.get('/repo/?exclude=dft.atoms,dft.only_atoms') assert rv.status_code == 200 - result = json.loads(rv.data)['results'][0] + result = search.flat(json.loads(rv.data)['results'][0]) assert 'dft.atoms' not in result assert 'dft.only_atoms' not in result assert 'dft.basis_set' in result diff --git a/tests/app/test_optimade.py b/tests/app/test_optimade.py index ee2ce4ac56928304fc0fa7dd2d91359bb03c3157..f63a7d7bd0d035e50ea9be6408aa8f9a52b921d0 100644 --- a/tests/app/test_optimade.py +++ b/tests/app/test_optimade.py @@ -36,7 +36,7 @@ def test_get_entry(published: Upload): data = json.load(f) assert 'OptimadeEntry' in data search_result = search.SearchRequest().search_parameter('calc_id', calc_id).execute_paginated()['results'][0] - assert 'dft.optimade' in search_result + assert 'dft.optimade' in search.flat(search_result) def test_no_optimade(meta_info, elastic, api): diff --git a/tests/conftest.py b/tests/conftest.py index 44e976602838caabc225c10f427697d9f5e211b6..3abe3b7edb9d5e5fcd618a26be45d432ee12d132 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -617,7 +617,7 @@ def published_wo_user_metadata(non_empty_processed: processing.Upload) -> proces @pytest.fixture def reset_config(): - """ Fixture that resets the log-level after test. """ + """ Fixture that resets configuration. """ service = config.service log_level = config.console_log_level yield None @@ -626,6 +626,12 @@ def reset_config(): infrastructure.setup_logging() +@pytest.fixture +def reset_infra(mongo, elastic): + """ Fixture that resets infrastructure after deleting db or search index. """ + yield None + + def create_test_structure( meta_info, id: int, h: int, o: int, extra: List[str], periodicity: int, optimade: bool = True, metadata: dict = None): diff --git a/tests/test_cli.py b/tests/test_cli.py index 8ce39dc0c0dace10acdcbabe4823f93ee323cc63..04f7ce671b663b3445028747ba9f7acf7d4162ef 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -16,7 +16,6 @@ import pytest import click.testing import json -import mongoengine from nomad import utils, search, processing as proc, files from nomad.cli import cli @@ -29,26 +28,26 @@ from tests.app.test_app import BlueprintClient @pytest.mark.usefixtures('reset_config', 'no_warn') class TestAdmin: - def test_reset(self): + def test_reset(self, reset_infra): result = click.testing.CliRunner().invoke( cli, ['admin', 'reset', '--i-am-really-sure'], catch_exceptions=False, obj=utils.POPO()) assert result.exit_code == 0 # allow other test to re-establish a connection - mongoengine.disconnect_all() + # mongoengine.disconnect_all() def test_reset_not_sure(self): result = click.testing.CliRunner().invoke( cli, ['admin', 'reset'], catch_exceptions=False, obj=utils.POPO()) assert result.exit_code == 1 - def test_remove(self): - result = click.testing.CliRunner().invoke( - cli, ['admin', 'reset', '--remove', '--i-am-really-sure'], catch_exceptions=False, obj=utils.POPO()) - assert result.exit_code == 0 + # def test_remove(self, reset_infra): + # result = click.testing.CliRunner().invoke( + # cli, ['admin', 'reset', '--remove', '--i-am-really-sure'], catch_exceptions=False, obj=utils.POPO()) + # assert result.exit_code == 0 - # allow other test to re-establish a connection - mongoengine.disconnect_all() + # # allow other test to re-establish a connection + # mongoengine.disconnect_all() def test_clean(self, published): upload_id = published.upload_id @@ -213,7 +212,6 @@ class TestAdminUploads: assert upload.tasks_status == proc.PENDING assert calc.tasks_status == proc.PENDING - @pytest.mark.usefixtures('reset_config') class TestClient: @@ -254,7 +252,9 @@ class TestClient: @pytest.mark.parametrize('move, link', [(True, False), (False, True), (False, False)]) def test_mirror(self, published, admin_user_bravado_client, monkeypatch, move, link): - ref_search_results = search.SearchRequest().search_parameters(upload_id=published.upload_id).execute_paginated()['results'][0] + ref_search_results = search.flat( + search.SearchRequest().search_parameters( + upload_id=published.upload_id).execute_paginated()['results'][0]) monkeypatch.setattr('nomad.cli.client.mirror.__in_test', True) @@ -277,9 +277,9 @@ class TestClient: calcs_in_search = new_search['pagination']['total'] assert calcs_in_search == 1 - new_search_results = new_search['results'][0] + new_search_results = search.flat(new_search['results'][0]) for key in new_search_results.keys(): - if key not in ['upload_time', 'last_processing', 'dft.labels']: + if key not in ['upload_time', 'last_processing', 'dft.labels.label']: # There is a sub second change due to date conversions (?). # Labels have arbitrary order. assert json.dumps(new_search_results[key]) == json.dumps(ref_search_results[key]) @@ -288,7 +288,9 @@ class TestClient: proc.Upload.objects(upload_id=published.upload_id).first().upload_files.exists def test_mirror_staging(self, non_empty_processed, admin_user_bravado_client, monkeypatch): - ref_search_results = search.SearchRequest().search_parameters(upload_id=non_empty_processed.upload_id).execute_paginated()['results'][0] + ref_search_results = search.flat( + search.SearchRequest().search_parameters( + upload_id=non_empty_processed.upload_id).execute_paginated()['results'][0]) monkeypatch.setattr('nomad.cli.client.mirror.__in_test', True) @@ -304,7 +306,7 @@ class TestClient: calcs_in_search = new_search['pagination']['total'] assert calcs_in_search == 1 - new_search_results = new_search['results'][0] + new_search_results = search.flat(new_search['results'][0]) for key in new_search_results.keys(): if key not in ['upload_time', 'last_processing']: # There is a sub second change due to date conversions (?) assert json.dumps(new_search_results[key]) == json.dumps(ref_search_results[key]) diff --git a/tests/test_search.py b/tests/test_search.py index 9b1f2d0e29c9209c72bf1c6391282911b0c25d03..36e8b124ec9ab18dd89d65fa1032f108ff9bffcc 100644 --- a/tests/test_search.py +++ b/tests/test_search.py @@ -36,11 +36,11 @@ def test_index_normalized_calc(elastic, normalized: parsing.LocalBackend): domain='dft', upload_id='test upload id', calc_id='test id') calc_with_metadata.apply_domain_metadata(normalized) - entry = create_entry(calc_with_metadata) + entry = search.flat(create_entry(calc_with_metadata).to_dict()) - assert getattr(entry, 'calc_id') is not None - assert getattr(entry, 'dft.atoms') is not None - assert getattr(entry, 'dft.code_name') is not None + assert 'calc_id' in entry + assert 'dft.atoms' in entry + assert 'dft.code_name' in entry def test_index_normalized_calc_with_metadata( @@ -151,17 +151,18 @@ def test_search_totals(elastic, example_search_data): def test_search_exclude(elastic, example_search_data): for item in SearchRequest().execute_paginated()['results']: - assert 'dft.atoms' in item + assert 'dft.atoms' in search.flat(item) for item in SearchRequest().exclude('dft.atoms').execute_paginated()['results']: - assert 'dft.atoms' not in item + assert 'dft.atoms' not in search.flat(item) def test_search_include(elastic, example_search_data): for item in SearchRequest().execute_paginated()['results']: - assert 'dft.atoms' in item + assert 'dft.atoms' in search.flat(item) for item in SearchRequest().include('calc_id').execute_paginated()['results']: + item = search.flat(item) assert 'dft.atoms' not in item assert 'calc_id' in item @@ -220,11 +221,11 @@ def assert_entry(calc_id): def assert_search_upload(upload: datamodel.UploadWithMetadata, additional_keys: List[str] = [], **kwargs): keys = ['calc_id', 'upload_id', 'mainfile', 'calc_hash'] refresh_index() - search = Entry.search().query('match_all')[0:10] - assert search.count() == len(list(upload.calcs)) - if search.count() > 0: - for hit in search: - hit = hit.to_dict() + search_results = Entry.search().query('match_all')[0:10] + assert search_results.count() == len(list(upload.calcs)) + if search_results.count() > 0: + for hit in search_results: + hit = search.flat(hit.to_dict()) for key, value in kwargs.items(): assert hit.get(key, None) == value