Commit 78b30041 authored by Markus Scheidgen

Refactored search index to use nested objects for domain data.

parent e3d0351c
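
The core of the change: domain quantities were previously indexed as flat, dot-prefixed fields on the search ``Entry`` document; they are now grouped into one object field per domain. A minimal sketch of the two document shapes, assuming a ``dft`` domain with ``atoms`` and ``code_name`` quantities (field names taken from the tests below):

# before this commit: flat fields whose names carry the domain prefix
entry_before = {'calc_id': '42', 'dft.atoms': ['Si'], 'dft.code_name': 'VASP'}

# after this commit: one object per domain holding that domain's quantities
entry_after = {'calc_id': '42', 'dft': {'atoms': ['Si'], 'code_name': 'VASP'}}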
......@@ -25,6 +25,7 @@ import sys
import os.path
from nomad import search, config
from nomad.datamodel import Domain
from nomad.app.optimade import filterparser
from nomad.app.common import RFC3339DateTime, rfc3339DateTime
from nomad.files import Restricted
......@@ -107,12 +108,16 @@ def add_search_parameters(request_parser):
help='A yyyy-MM-ddTHH:mm:ss (RFC3339) maximum entry time (e.g. upload time)')
# main search parameters
for quantity in search.quantities.values():
for quantity in Domain.all_quantities():
request_parser.add_argument(
quantity.name, help=quantity.description,
quantity.qualified_name, help=quantity.description,
action=quantity.argparse_action if quantity.multi else None)
_search_quantities = set([
    quantity.qualified_name for quantity in Domain.all_quantities()])
def apply_search_parameters(search_request: search.SearchRequest, args: Dict[str, Any]):
"""
Helper that adds query-relevant request args to the given SearchRequest.
......@@ -153,7 +158,7 @@ def apply_search_parameters(search_request: search.SearchRequest, args: Dict[str
# search parameter
search_request.search_parameters(**{
key: value for key, value in args.items()
if key not in ['optimade'] and key in search.quantities})
if key in _search_quantities})
def calc_route(ns, prefix: str = ''):
......
......@@ -34,7 +34,7 @@ from .api import api
from .auth import authenticate
from .common import search_model, calc_route, add_pagination_parameters,\
add_scroll_parameters, add_search_parameters, apply_search_parameters,\
query_api_python, query_api_curl
query_api_python, query_api_curl, _search_quantities
ns = api.namespace('repo', description='Access repository metadata.')
......@@ -264,7 +264,7 @@ _query_model_parameters = {
'until_time': RFC3339DateTime(description='A yyyy-MM-ddTHH:mm:ss (RFC3339) maximum entry time (e.g. upload time)')
}
for quantity in search.quantities.values():
for quantity in datamodel.Domain.all_quantities():
if quantity.multi and quantity.argparse_action is None:
def field(**kwargs):
return fields.List(fields.String(**kwargs))
......@@ -379,9 +379,9 @@ class EditRepoCalcsResource(Resource):
# preparing the query of entries that are edited
parsed_query = {}
for quantity_name, quantity in search.quantities.items():
if quantity_name in query:
value = query[quantity_name]
for quantity_name, value in query.items():
if quantity_name in _search_quantities:
quantity = datamodel.Domain.get_quantity(quantity_name)
if quantity.multi and quantity.argparse_action == 'split' and not isinstance(value, list):
value = value.split(',')
parsed_query[quantity_name] = value
......
......@@ -405,12 +405,12 @@ def statistics_table(html, geometries, public_path):
# search calcs quantities=section_k_band
band_structures = get_statistic(
client.repo.search(per_page=1, quantities=['section_k_band']).response().result,
client.repo.search(per_page=1, **{'dft.quantities': ['section_k_band']}).response().result,
'total', 'all', 'code_runs')
# search calcs quantities=section_dos
dos = get_statistic(
client.repo.search(per_page=1, quantities=['section_dos']).response().result,
client.repo.search(per_page=1, **{'dft.quantities': ['section_dos']}).response().result,
'total', 'all', 'code_runs')
phonons = get_statistic(
......
......@@ -244,10 +244,7 @@ class DomainQuantity:
@property
def name(self) -> str:
if self.domain is not None:
return '%s.%s' % (self.domain, self._name)
else:
return self._name
return self._name
@name.setter
def name(self, name: str) -> None:
......@@ -257,6 +254,20 @@ class DomainQuantity:
if self.elastic_field is None:
self.elastic_field = self.name
@property
def qualified_elastic_field(self) -> str:
if self.domain is None:
return self.elastic_field
else:
return '%s.%s' % (self.domain, self.elastic_field)
@property
def qualified_name(self) -> str:
if self.domain is None:
return self.name
else:
return '%s.%s' % (self.domain, self.name)
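
For illustration, assuming a quantity ``atoms`` registered under the ``dft`` domain, the two properties resolve as follows (a sketch, not part of the commit):

quantity = Domain.instances['dft'].quantities['atoms']
assert quantity.name == 'atoms'  # plain name, no longer domain-prefixed
assert quantity.qualified_name == 'dft.atoms'  # used in the API
assert quantity.qualified_elastic_field == 'dft.atoms'  # path into the object field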
class Domain:
"""
......@@ -354,6 +365,25 @@ class Domain:
datasets=('dataset_id', 'datasets'),
uploads=('upload_id', 'uploads'))
@classmethod
def get_quantity(cls, name_spec) -> DomainQuantity:
"""
Returns the quantity definition for the given quantity name. The name can be a
qualified name (``domain.quantity``), Django-style (``domain__quantity``), or an
unqualified name that refers to a domain-independent base quantity.
"""
qualified_name = name_spec.replace('__', '.')
split_name = qualified_name.split('.')
if len(split_name) == 1:
return cls.base_quantities[split_name[0]]
elif len(split_name) == 2:
return cls.instances[split_name[0]].quantities[split_name[1]]
else:
assert False, 'quantity names can have at most two levels (domain.quantity)'
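
Both spellings resolve to the same definition; a hypothetical usage sketch:

assert Domain.get_quantity('dft.atoms') is Domain.get_quantity('dft__atoms')
# unqualified names are looked up among the domain-independent base quantities
assert Domain.get_quantity('calc_id') is Domain.base_quantities['calc_id']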
@classmethod
def all_quantities(cls) -> Iterable[DomainQuantity]:
return set(
    quantity
    for domain in cls.instances.values()
    for quantity in domain.quantities.values())
def __init__(
self, name: str, domain_entry_class: Type[CalcWithMetadata],
quantities: Dict[str, DomainQuantity],
......@@ -446,7 +476,7 @@ class Domain:
def order_default_quantity(self) -> str:
for quantity in self.quantities.values():
if quantity.order_default:
return quantity.name
return quantity.qualified_name
assert False, 'each domain must define an order_default quantity'
......
......@@ -26,6 +26,7 @@ from datetime import datetime
import json
from nomad import config, datamodel, infrastructure, utils, processing as proc
from nomad.datamodel import Domain
path_analyzer = analyzer(
......@@ -69,12 +70,23 @@ class Dataset(InnerDoc):
name = Keyword()
_domain_inner_doc_types: Dict[str, type] = {}
class WithDomain(IndexMeta):
""" Override elasticsearch_dsl metaclass to sneak in domain specific mappings """
def __new__(cls, name, bases, attrs):
for domain in datamodel.Domain.instances.values():
for quantity in domain.domain_quantities.values():
attrs[quantity.elastic_field] = quantity.elastic_mapping
for domain in Domain.instances.values():
inner_doc_type = _domain_inner_doc_types.get(domain.name)
if inner_doc_type is None:
domain_attrs = {
quantity.elastic_field: quantity.elastic_mapping
for quantity in domain.domain_quantities.values()}
inner_doc_type = type(domain.name, (InnerDoc,), domain_attrs)
_domain_inner_doc_types[domain.name] = inner_doc_type
attrs[domain.name] = Object(inner_doc_type)
return super(WithDomain, cls).__new__(cls, name, bases, attrs)
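
The effect is that ``Entry`` gains one ``Object`` field per registered domain, backed by a dynamically created ``InnerDoc`` subclass (cached in ``_domain_inner_doc_types`` so repeated index definitions reuse the same type). Roughly equivalent hand-written elasticsearch_dsl code, sketched for a single ``dft`` domain with illustrative fields:

from elasticsearch_dsl import Document, InnerDoc, Keyword, Object

class Dft(InnerDoc):
    # one attribute per domain quantity, using the quantity's elastic mapping
    atoms = Keyword()
    code_name = Keyword()

class Entry(Document):
    calc_id = Keyword()
    dft = Object(Dft)  # what WithDomain adds for each domain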
......@@ -162,9 +174,13 @@ class Entry(Document, metaclass=WithDomain):
self.external_id = source.external_id
if self.domain is not None:
for quantity in datamodel.Domain.instances[self.domain].domain_quantities.values():
inner_doc_type = _domain_inner_doc_types[self.domain]
inner_doc = inner_doc_type()
for quantity in Domain.instances[self.domain].domain_quantities.values():
quantity_value = quantity.elastic_value(getattr(source, quantity.metadata_field))
setattr(self, quantity.name, quantity_value)
setattr(inner_doc, quantity.elastic_field, quantity_value)
setattr(self, self.domain, inner_doc)
def delete_upload(upload_id):
......@@ -221,15 +237,9 @@ def refresh():
infrastructure.elastic_client.indices.refresh(config.elastic.index_name)
quantities = {
quantity_name: quantity
for domain in datamodel.Domain.instances.values()
for quantity_name, quantity in domain.quantities.items()}
"""The available search quantities """
metrics = {
metric_name: metric
for domain in datamodel.Domain.instances.values()
for domain in Domain.instances.values()
for metric_name, metric in domain.metrics.items()}
"""
The available search metrics. Metrics are integer values given for each entry that can
......@@ -237,23 +247,23 @@ be used in statistics (aggregations), e.g. the sum of all total energy calculati
all unique geometries.
"""
metrics_names = [metric_name for domain in datamodel.Domain.instances.values() for metric_name in domain.metrics_names]
metrics_names = [metric_name for domain in Domain.instances.values() for metric_name in domain.metrics_names]
""" Names of all available metrics """
groups = {
key: value
for domain in datamodel.Domain.instances.values()
for domain in Domain.instances.values()
for key, value in domain.groups.items()}
"""The available groupable quantities"""
order_default_quantities = {
domain_name: domain.order_default_quantity
for domain_name, domain in datamodel.Domain.instances.items()
for domain_name, domain in Domain.instances.items()
}
default_statistics = {
domain_name: domain.default_statistics
for domain_name, domain in datamodel.Domain.instances.items()
for domain_name, domain in Domain.instances.items()
}
......@@ -353,9 +363,7 @@ class SearchRequest:
return self
def search_parameter(self, name, value):
quantity = quantities.get(name, None)
if quantity is None:
raise KeyError('Unknown quantity %s' % name)
quantity = Domain.get_quantity(name)
if quantity.multi and not isinstance(value, list):
value = [value]
......@@ -365,7 +373,7 @@ class SearchRequest:
if quantity.elastic_search_type == 'terms':
if not isinstance(value, list):
value = [value]
self.q &= Q('terms', **{quantity.elastic_field: value})
self.q &= Q('terms', **{quantity.qualified_elastic_field: value})
return self
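
Since domain fields now live inside an object, queries address them by their dotted path, which is what ``qualified_elastic_field`` yields; e.g. a terms filter on ``dft.atoms`` (sketch):

from elasticsearch_dsl import Q

q = Q('terms', **{'dft.atoms': ['Si', 'O']})
# serializes to {"terms": {"dft.atoms": ["Si", "O"]}}; plain object fields
# (unlike 'nested' mappings) are queryable by dotted path directly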
......@@ -375,7 +383,7 @@ class SearchRequest:
values = [value]
for item in values:
self.q &= Q(quantity.elastic_search_type, **{quantity.elastic_field: item})
self.q &= Q(quantity.elastic_search_type, **{quantity.qualified_elastic_field: item})
return self
......@@ -429,7 +437,7 @@ class SearchRequest:
for name in default_statistics[self._domain]:
self.statistic(
name,
quantities[name].aggregations,
Domain.get_quantity(name).aggregations,
metrics_to_use=metrics_to_use)
return self
......@@ -459,8 +467,8 @@ class SearchRequest:
``unique_code_runs``, ``datasets``, other domain specific metrics.
The basic doc_count metric ``code_runs`` is always given.
"""
quantity = quantities[quantity_name]
terms = A('terms', field=quantity.elastic_field, size=size, order=dict(_key='asc'))
quantity = Domain.get_quantity(quantity_name)
terms = A('terms', field=quantity.qualified_elastic_field, size=size, order=dict(_key='asc'))
buckets = self._search.aggs.bucket('statistics:%s' % quantity_name, terms)
self._add_metrics(buckets, metrics_to_use)
......@@ -532,8 +540,8 @@ class SearchRequest:
if size is None:
size = 100
quantity = quantities[name]
terms = A('terms', field=quantity.elastic_field)
quantity = Domain.get_quantity(name)
terms = A('terms', field=quantity.qualified_elastic_field)
# We are using Elasticsearch's 'composite aggregations' here. We do not really
# compose aggregations, but only those pseudo composites allow us to use the
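
For context, a composite aggregation with a single terms source looks like this in elasticsearch_dsl (a generic sketch; the field name is illustrative):

from elasticsearch_dsl import A

composite = A(
    'composite',
    sources=[{'code_names': A('terms', field='dft.code_name')}],
    size=100)
# the response's 'after_key' can be fed back via 'after' to page through buckets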
......@@ -585,15 +593,12 @@ class SearchRequest:
search = self._search.query(self.q)
if order_by is not None:
if order_by not in quantities:
raise KeyError('Unknown order quantity %s' % order_by)
order_by_quantity = quantities[order_by]
order_by_quantity = Domain.get_quantity(order_by)
if order == 1:
search = search.sort(order_by_quantity.elastic_field)
search = search.sort(order_by_quantity.qualified_elastic_field)
else:
search = search.sort('-%s' % order_by_quantity.elastic_field)
search = search.sort('-%s' % order_by_quantity.qualified_elastic_field)
search = search.params(preserve_order=True)
......@@ -617,15 +622,12 @@ class SearchRequest:
search = self._search.query(self.q)
if order_by not in quantities:
raise KeyError('Unknown order quantity %s' % order_by)
order_by_quantity = quantities[order_by]
order_by_quantity = Domain.get_quantity(order_by)
if order == 1:
search = search.sort(order_by_quantity.elastic_field)
search = search.sort(order_by_quantity.qualified_elastic_field)
else:
search = search.sort('-%s' % order_by_quantity.elastic_field)
search = search.sort('-%s' % order_by_quantity.qualified_elastic_field)
search = search[(page - 1) * per_page: page * per_page]
result = self._response(search.execute(), with_hits=True)
......@@ -778,3 +780,23 @@ def to_calc_with_metadata(results: List[Dict[str, Any]]):
return [
datamodel.CalcWithMetadata(**calc.metadata)
for calc in proc.Calc.objects(calc_id__in=ids)]
def flat(obj, prefix=None):
    """
    Helper that translates nested result objects into flattened dicts with
    ``domain.quantity`` as keys.
    """
    if isinstance(obj, dict):
        result = {}
        for key, value in obj.items():
            qualified_key = key if prefix is None else '%s.%s' % (prefix, key)
            if isinstance(value, dict):
                # recurse with the accumulated prefix so deeper levels
                # also end up under dotted keys
                result.update(flat(value, prefix=qualified_key))
            else:
                result[qualified_key] = value
        return result
    else:
        return obj
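
A usage sketch, mirroring the tests below:

hit = {'calc_id': '42', 'dft': {'atoms': ['Si'], 'code_name': 'VASP'}}
assert flat(hit) == {
    'calc_id': '42', 'dft.atoms': ['Si'], 'dft.code_name': 'VASP'}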
......@@ -224,7 +224,7 @@ class TestUploads:
assert calc['current_task'] == 'archiving'
assert len(calc['tasks']) == 3
assert 'dft.atoms' in calc['metadata']
assert 'dft.atoms' in search.flat(calc['metadata'])
assert api.get('/archive/logs/%s/%s' % (calc['upload_id'], calc['calc_id']), headers=test_user_auth).status_code == 200
if upload['calcs']['pagination']['total'] > 1:
......@@ -814,10 +814,11 @@ class TestRepo():
auth = dict(none=None, test_user=test_user_auth, other_test_user=other_test_user_auth).get(auth)
rv = api.get('/repo/?owner=%s' % owner, headers=auth)
data = self.assert_search(rv, calcs)
results = data.get('results', None)
if calcs > 0:
for key in ['uploader', 'calc_id', 'dft.formula', 'upload_id']:
assert key in results[0]
results = data.get('results', None)
result = search.flat(results[0])
for key in ['uploader.name', 'calc_id', 'dft.formula', 'upload_id']:
assert key in result
@pytest.mark.parametrize('calcs, start, end', [
(2, today - datetime.timedelta(days=6), today),
......@@ -887,7 +888,7 @@ class TestRepo():
def test_search_exclude(self, api, example_elastic_calcs, no_warn):
rv = api.get('/repo/?exclude=dft.atoms,dft.only_atoms')
assert rv.status_code == 200
result = json.loads(rv.data)['results'][0]
result = search.flat(json.loads(rv.data)['results'][0])
assert 'dft.atoms' not in result
assert 'dft.only_atoms' not in result
assert 'dft.basis_set' in result
......
......@@ -36,7 +36,7 @@ def test_get_entry(published: Upload):
data = json.load(f)
assert 'OptimadeEntry' in data
search_result = search.SearchRequest().search_parameter('calc_id', calc_id).execute_paginated()['results'][0]
assert 'dft.optimade' in search_result
assert 'dft.optimade' in search.flat(search_result)
def test_no_optimade(meta_info, elastic, api):
......
......@@ -617,7 +617,7 @@ def published_wo_user_metadata(non_empty_processed: processing.Upload) -> proces
@pytest.fixture
def reset_config():
""" Fixture that resets the log-level after test. """
""" Fixture that resets configuration. """
service = config.service
log_level = config.console_log_level
yield None
......@@ -626,6 +626,12 @@ def reset_config():
infrastructure.setup_logging()
@pytest.fixture
def reset_infra(mongo, elastic):
""" Fixture that resets infrastructure after deleting db or search index. """
yield None
def create_test_structure(
meta_info, id: int, h: int, o: int, extra: List[str], periodicity: int,
optimade: bool = True, metadata: dict = None):
......
......@@ -16,7 +16,6 @@
import pytest
import click.testing
import json
import mongoengine
from nomad import utils, search, processing as proc, files
from nomad.cli import cli
......@@ -29,26 +28,26 @@ from tests.app.test_app import BlueprintClient
@pytest.mark.usefixtures('reset_config', 'no_warn')
class TestAdmin:
def test_reset(self):
def test_reset(self, reset_infra):
result = click.testing.CliRunner().invoke(
cli, ['admin', 'reset', '--i-am-really-sure'], catch_exceptions=False, obj=utils.POPO())
assert result.exit_code == 0
# allow other tests to re-establish a connection
mongoengine.disconnect_all()
# mongoengine.disconnect_all()
def test_reset_not_sure(self):
result = click.testing.CliRunner().invoke(
cli, ['admin', 'reset'], catch_exceptions=False, obj=utils.POPO())
assert result.exit_code == 1
def test_remove(self):
result = click.testing.CliRunner().invoke(
cli, ['admin', 'reset', '--remove', '--i-am-really-sure'], catch_exceptions=False, obj=utils.POPO())
assert result.exit_code == 0
# def test_remove(self, reset_infra):
# result = click.testing.CliRunner().invoke(
# cli, ['admin', 'reset', '--remove', '--i-am-really-sure'], catch_exceptions=False, obj=utils.POPO())
# assert result.exit_code == 0
# allow other test to re-establish a connection
mongoengine.disconnect_all()
# # allow other tests to re-establish a connection
# mongoengine.disconnect_all()
def test_clean(self, published):
upload_id = published.upload_id
......@@ -213,7 +212,6 @@ class TestAdminUploads:
assert upload.tasks_status == proc.PENDING
assert calc.tasks_status == proc.PENDING
@pytest.mark.usefixtures('reset_config')
class TestClient:
......@@ -254,7 +252,9 @@ class TestClient:
@pytest.mark.parametrize('move, link', [(True, False), (False, True), (False, False)])
def test_mirror(self, published, admin_user_bravado_client, monkeypatch, move, link):
ref_search_results = search.SearchRequest().search_parameters(upload_id=published.upload_id).execute_paginated()['results'][0]
ref_search_results = search.flat(
search.SearchRequest().search_parameters(
upload_id=published.upload_id).execute_paginated()['results'][0])
monkeypatch.setattr('nomad.cli.client.mirror.__in_test', True)
......@@ -277,9 +277,9 @@ class TestClient:
calcs_in_search = new_search['pagination']['total']
assert calcs_in_search == 1
new_search_results = new_search['results'][0]
new_search_results = search.flat(new_search['results'][0])
for key in new_search_results.keys():
if key not in ['upload_time', 'last_processing', 'dft.labels']:
if key not in ['upload_time', 'last_processing', 'dft.labels.label']:
# There is a sub second change due to date conversions (?).
# Labels have arbitrary order.
assert json.dumps(new_search_results[key]) == json.dumps(ref_search_results[key])
......@@ -288,7 +288,9 @@ class TestClient:
proc.Upload.objects(upload_id=published.upload_id).first().upload_files.exists
def test_mirror_staging(self, non_empty_processed, admin_user_bravado_client, monkeypatch):
ref_search_results = search.SearchRequest().search_parameters(upload_id=non_empty_processed.upload_id).execute_paginated()['results'][0]
ref_search_results = search.flat(
search.SearchRequest().search_parameters(
upload_id=non_empty_processed.upload_id).execute_paginated()['results'][0])
monkeypatch.setattr('nomad.cli.client.mirror.__in_test', True)
......@@ -304,7 +306,7 @@ class TestClient:
calcs_in_search = new_search['pagination']['total']
assert calcs_in_search == 1
new_search_results = new_search['results'][0]
new_search_results = search.flat(new_search['results'][0])
for key in new_search_results.keys():
if key not in ['upload_time', 'last_processing']: # There is a sub second change due to date conversions (?)
assert json.dumps(new_search_results[key]) == json.dumps(ref_search_results[key])
......
......@@ -36,11 +36,11 @@ def test_index_normalized_calc(elastic, normalized: parsing.LocalBackend):
domain='dft', upload_id='test upload id', calc_id='test id')
calc_with_metadata.apply_domain_metadata(normalized)
entry = create_entry(calc_with_metadata)
entry = search.flat(create_entry(calc_with_metadata).to_dict())
assert getattr(entry, 'calc_id') is not None
assert getattr(entry, 'dft.atoms') is not None
assert getattr(entry, 'dft.code_name') is not None
assert 'calc_id' in entry
assert 'dft.atoms' in entry
assert 'dft.code_name' in entry
def test_index_normalized_calc_with_metadata(
......@@ -151,17 +151,18 @@ def test_search_totals(elastic, example_search_data):
def test_search_exclude(elastic, example_search_data):
for item in SearchRequest().execute_paginated()['results']:
assert 'dft.atoms' in item
assert 'dft.atoms' in search.flat(item)
for item in SearchRequest().exclude('dft.atoms').execute_paginated()['results']:
assert 'dft.atoms' not in item
assert 'dft.atoms' not in search.flat(item)
def test_search_include(elastic, example_search_data):
for item in SearchRequest().execute_paginated()['results']:
assert 'dft.atoms' in item
assert 'dft.atoms' in search.flat(item)
for item in SearchRequest().include('calc_id').execute_paginated()['results']:
item = search.flat(item)
assert 'dft.atoms' not in item
assert 'calc_id' in item
......@@ -220,11 +221,11 @@ def assert_entry(calc_id):
def assert_search_upload(upload: datamodel.UploadWithMetadata, additional_keys: List[str] = [], **kwargs):
keys = ['calc_id', 'upload_id', 'mainfile', 'calc_hash']
refresh_index()
search = Entry.search().query('match_all')[0:10]
assert search.count() == len(list(upload.calcs))
if search.count() > 0:
for hit in search:
hit = hit.to_dict()
search_results = Entry.search().query('match_all')[0:10]
assert search_results.count() == len(list(upload.calcs))
if search_results.count() > 0:
for hit in search_results:
hit = search.flat(hit.to_dict())
for key, value in kwargs.items():
assert hit.get(key, None) == value
......