Commit 507ad087 authored by Alvin Noe Ladines's avatar Alvin Noe Ladines

Implemented complex query

parent 49b5fa3e
Pipeline #80664 passed with stages
in 24 minutes and 43 seconds
@@ -32,7 +32,8 @@ from nomad.app import common
from .auth import authenticate, create_authorization_predicate
from .api import api
from .common import calc_route, streamed_zipfile, search_model, add_search_parameters, apply_search_parameters, query_model
from .common import calc_route, streamed_zipfile, search_model, add_search_parameters,\
apply_search_parameters
ns = api.namespace(
@@ -212,7 +213,6 @@ class ArchiveDownloadResource(Resource):
_archive_query_model = api.inherit('ArchiveSearch', search_model, {
'query': fields.Nested(query_model, description='The query used to find the requested entries.', skip_none=True),
'required': fields.Raw(description='A dictionary that defines what archive data to retrieve.'),
'query_schema': fields.Raw(description='Deprecated, use required instead.'),
'raise_errors': fields.Boolean(description='Return 404 on missing archives or 500 on other errors instead of skipping the entry.')
@@ -253,6 +253,8 @@ class ArchiveQueryResource(Resource):
query = data_in.get('query', {})
query_expression = {key: val for key, val in query.items() if '$' in key}
required: Dict[str, Any] = None
if 'required' in data_in:
required = data_in.get('required')
@@ -277,6 +279,9 @@ class ArchiveQueryResource(Resource):
if not aggregation:
search_request.include('calc_id', 'upload_id', 'with_embargo', 'published', 'parser_name')
if query_expression:
search_request.query_expression(query_expression)
try:
if aggregation:
results = search_request.execute_aggregated(
@@ -76,7 +76,51 @@ aggregation_model = api.model('Aggregation', {
'per_page': fields.Integer(default=0, help='The size of the requested page.', allow_null=True)})
''' Model used in responses with id aggregation. '''
query_model_fields = {
qualified_name: quantity.flask_field
for qualified_name, quantity in search.search_quantities.items()}
query_model_fields.update(**{
'owner': fields.String(description='The group the calculations belong to.', allow_null=True, skip_none=True),
'domain': fields.String(description='Specify the domain to search in: %s, default is ``%s``' % (
', '.join(['``%s``' % domain for domain in datamodel.domains]), config.meta.default_domain)),
'from_time': fields.Raw(description='The minimum entry time.', allow_null=True, skip_none=True),
'until_time': fields.Raw(description='The maximum entry time.', allow_null=True, skip_none=True)
})
query_model = api.model('Query0', query_model_fields)
for n in range(1, 3):
query_model = api.model('Query%d' % n, {**query_model_fields, **{
'$and': fields.List(fields.Nested(query_model, allow_null=True, skip_none=True), description=(
'List of queries which must be present in search results.'
)),
'$or': fields.List(fields.Nested(query_model, allow_null=True, skip_none=True), description=(
'List of queries of which at least one should be present in search results.'
)),
'$not': fields.List(fields.Nested(query_model, allow_null=True, skip_none=True), description=(
'List of queries which must not be present in search results.'
)),
'$lt': fields.Nested(api.model('Querylt', query_model_fields), allow_null=True, skip_none=True, description=(
'Dict of quantity name: value pairs; search results must have values '
'less than the given value.'
)),
'$lte': fields.Nested(api.model('Querylte', query_model_fields), allow_null=True, skip_none=True, description=(
'Dict of quantity name: value pairs; search results must have values '
'less than or equal to the given value.'
)),
'$gt': fields.Nested(api.model('Querygt', query_model_fields), allow_null=True, skip_none=True, description=(
'Dict of quantity name: value pairs; search results must have values '
'greater than the given value.'
)),
'$gte': fields.Nested(api.model('Querygte', query_model_fields), allow_null=True, skip_none=True, description=(
'Dict of quantity name: value pairs; search results must have values '
'greater than or equal to the given value.'
)),
}})
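As an illustration (a sketch only; quantity names such as dft.system, dft.compound_type and n_atoms are taken from the tests further below and assumed to be valid search quantities), a complex query document built from these operators could look like:
query = {
    '$and': [
        {'dft.system': 'bulk'},  # plain quantity match
        {'$not': [{'dft.compound_type': 'ternary'}]},  # negated match
        {'$gte': {'n_atoms': 2}}  # range condition on a numeric quantity
    ]
}
Note that the loop above only generates two levels of operator nesting (Query1, Query2) for the Swagger documentation.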
search_model_fields = {
'query': fields.Nested(query_model, allow_null=True, skip_none=True),
'pagination': fields.Nested(pagination_model, allow_null=True, skip_none=True),
'scroll': fields.Nested(scroll_model, allow_null=True, skip_none=True),
'aggregation': fields.Nested(aggregation_model, allow_null=True),
@@ -94,20 +138,6 @@ search_model_fields = {
search_model = api.model('Search', search_model_fields)
query_model_fields = {
qualified_name: quantity.flask_field
for qualified_name, quantity in search.search_quantities.items()}
query_model_fields.update(**{
'owner': fields.String(description='The group the calculations belong to.', allow_null=True, skip_none=True),
'domain': fields.String(description='Specify the domain to search in: %s, default is ``%s``' % (
', '.join(['``%s``' % domain for domain in datamodel.domains]), config.meta.default_domain)),
'from_time': fields.Raw(description='The minimum entry time.', allow_null=True, skip_none=True),
'until_time': fields.Raw(description='The maximum entry time.', allow_null=True, skip_none=True)
})
query_model = api.model('Query', query_model_fields)
def add_pagination_parameters(request_parser):
''' Add pagination parameters to Flask querystring parser. '''
@@ -35,7 +35,8 @@ from nomad.app import common
from .api import api
from .auth import authenticate, create_authorization_predicate
from .common import streamed_zipfile, add_search_parameters, apply_search_parameters
from .common import streamed_zipfile, add_search_parameters, apply_search_parameters,\
search_model
ns = api.namespace('raw', description='Downloading raw data files.')
@@ -391,6 +392,15 @@ _raw_file_from_query_parser.add_argument(
location='args', action='append')
_raw_file_from_query_model = api.inherit('RawFileQuery', search_model, {
'compress': fields.Boolean(description='Use compression on .zip files, default is no compression.', default=False),
'strip': fields.Boolean(description='Removes a potential common path prefix from all file paths.', default=False),
'file_pattern': fields.List(fields.String, description=(
'A wildcard pattern. Only filenames that match this pattern will be in the '
'download. Multiple patterns will be combined with a logical or.'), allow_null=True, skip_none=True)
})
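For example, a request against this model could combine a complex query with file patterns (a sketch; the quantity names, patterns and api_base are illustrative only):
data = {
    'query': {'$and': [{'dft.system': 'bulk'}, {'$gte': {'n_atoms': 2}}]},
    'file_pattern': ['*.xml', '*.json'],  # illustrative patterns
    'compress': True
}
# e.g. requests.post(api_base + '/raw/query', json=data, stream=True)  # api_base is assumed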
@ns.route('/query')
class RawFileQueryResource(Resource):
manifest_quantities = ['upload_id', 'calc_id', 'external_id', 'raw_id', 'pid', 'calc_hash']
@@ -527,6 +537,143 @@ class RawFileQueryResource(Resource):
return streamed_zipfile(
generator(), zipfile_name='nomad_raw_files.zip', compress=compress)
@api.doc('post_raw_files_from_query')
@api.expect(_raw_file_from_query_model)
@api.response(400, 'Invalid requests, e.g. wrong owner type or bad search parameters')
@api.response(200, 'File(s) sent', headers={'Content-Type': 'application/zip'})
@authenticate(signature_token=True)
def post(self):
''' Download a .zip file with all raw-files for all entries that match the given
search parameters.
See ``/repo`` endpoint for documentation on the search
parameters.
Zip files are streamed; instead of 401 errors, the zip file will just not contain
any files that the user is not authorized to access.
The zip file will contain a ``manifest.json`` with the repository meta data.
'''
patterns: List[str] = None
try:
data_in = request.get_json()
compress = data_in.get('compress', False)
strip = data_in.get('strip', False)
pattern = data_in.get('file_pattern', None)
if isinstance(pattern, str):
patterns = [pattern]
elif pattern is None:
patterns = []
else:
patterns = pattern
query = data_in.get('query', {})
query_expression = {key: val for key, val in query.items() if '$' in key}
except Exception:
abort(400, message='bad parameter types')
logger = common.logger.bind(query=urllib.parse.urlencode(query, doseq=True))
search_request = search.SearchRequest()
apply_search_parameters(search_request, query)
search_request.include('calc_id', 'upload_id', 'mainfile')
if query_expression:
search_request.query_expression(query_expression)
def path(entry):
return '%s/%s' % (entry['upload_id'], entry['mainfile'])
calcs = search_request.execute_scan(
order_by='upload_id',
size=config.services.download_scan_size,
scroll=config.services.download_scan_timeout)
if strip:
if search_request.execute()['total'] > config.raw_file_strip_cutoff:
abort(400, 'The requested download has too many files for using "strip".')
calcs = list(calcs)
paths = [path(entry) for entry in calcs]
common_prefix_len = len(utils.common_prefix(paths))
else:
common_prefix_len = 0
def generator():
try:
manifest = {}
directories = set()
upload_files = None
streamed, skipped = 0, 0
for entry in calcs:
upload_id = entry['upload_id']
mainfile = entry['mainfile']
if upload_files is None or upload_files.upload_id != upload_id:
logger.info('opening next upload for raw file streaming', upload_id=upload_id)
if upload_files is not None:
upload_files.close()
upload_files = UploadFiles.get(upload_id)
if upload_files is None:
logger.error('upload files do not exist', upload_id=upload_id)
continue
def open_file(upload_filename):
return upload_files.raw_file(upload_filename, 'rb')
upload_files._is_authorized = create_authorization_predicate(
upload_id=upload_id, calc_id=entry['calc_id'])
directory = os.path.dirname(mainfile)
directory_w_upload = os.path.join(upload_files.upload_id, directory)
if directory_w_upload not in directories:
streamed += 1
directories.add(directory_w_upload)
for filename, file_size in upload_files.raw_file_list(directory=directory):
filename = os.path.join(directory, filename)
filename_w_upload = os.path.join(upload_files.upload_id, filename)
filename_wo_prefix = filename_w_upload[common_prefix_len:]
if len(patterns) == 0 or any(
fnmatch.fnmatchcase(os.path.basename(filename_wo_prefix), pattern)
for pattern in patterns):
yield (
filename_wo_prefix, filename, open_file,
lambda *args, **kwargs: file_size)
else:
skipped += 1
if (streamed + skipped) % 10000 == 0:
logger.info('streaming raw files', streamed=streamed, skipped=skipped)
manifest[path(entry)] = {
key: entry[key]
for key in RawFileQueryResource.manifest_quantities
if entry.get(key) is not None
}
if upload_files is not None:
upload_files.close()
logger.info('streaming raw file manifest')
try:
manifest_contents = json.dumps(manifest).encode('utf-8')
except Exception as e:
manifest_contents = json.dumps(
dict(error='Could not create the manifest: %s' % (e))).encode('utf-8')
logger.error('could not create raw query manifest', exc_info=e)
yield (
'manifest.json', 'manifest',
lambda *args: BytesIO(manifest_contents),
lambda *args: len(manifest_contents))
except Exception as e:
logger.warning(
'unexpected error while streaming raw data from query', exc_info=e)
logger.info('start streaming raw files')
return streamed_zipfile(
generator(), zipfile_name='nomad_raw_files.zip', compress=compress)
def respond_to_get_raw_files(upload_id, files, compress=False, strip=False):
upload_files = UploadFiles.get(
@@ -117,6 +117,16 @@ for qualified_name, quantity in search_extension.search_quantities.items():
_repo_calcs_model_fields[qualified_name] = fields.Raw(
description=quantity.description, allow_null=True, skip_none=True)
_repo_calcs_model_fields.update(**{
'date_histogram': fields.Boolean(default=False, description='Add an additional aggregation over the upload time.', allow_null=True, skip_none=True),
'interval': fields.String(description='Interval to use for upload time aggregation.', allow_null=True, skip_none=True),
'metrics': fields.List(fields.String, description=(
'Metrics to aggregate over all quantities and their values, given as a list. '
'Possible values are %s.' % ', '.join(search_extension.metrics.keys())), allow_null=True, skip_none=True),
'statistics_required': fields.List(fields.String, description='Quantities for which to aggregate values and their metrics.', allow_null=True, skip_none=True),
'exclude': fields.List(fields.String, description='Excludes the given keys in the returned data.', allow_null=True, skip_none=True)
})
_repo_calcs_model = api.inherit('RepoCalculations', search_model, _repo_calcs_model_fields)
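For example, a POST body for this model could combine a complex query with required statistics (a sketch; the quantity names are illustrative and assumed to be indexed):
data = {
    'query': {'$not': [{'dft.system': 'bulk'}]},  # illustrative quantity name
    'pagination': {'page': 1, 'per_page': 10, 'order_by': 'upload_time', 'order': -1},
    'statistics_required': ['dft.system']
}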
@@ -263,6 +273,151 @@ class RepoCalcsResource(Resource):
traceback.print_exc()
abort(400, str(e))
@api.doc('post_search')
@api.response(400, 'Invalid requests, e.g. wrong owner type or bad search parameters')
@api.expect(_repo_calcs_model)
@api.marshal_with(_repo_calcs_model, skip_none=True, code=200, description='Search results sent')
@authenticate()
def post(self):
'''
Search for calculations in the repository, paginated.
The ``owner`` parameter determines the overall entries to search through.
Possible values are: ``all`` (show all entries visible to the current user), ``public``
(show all publicly visible entries), ``user`` (show all user entries, requires login),
``staging`` (show all user entries in staging area, requires login).
You can use the various quantities to search/filter for. For some of the
indexed quantities this endpoint returns aggregation information. This means
you will be given a list of all possible values and the number of entries
that have a certain value. You can also use these aggregations on an empty
search to determine the possible values.
The pagination parameters allow you to determine which page to return via the
``page`` and ``per_page`` parameters. Pagination, however, is limited to the first
100k (depending on ES configuration) hits.
An alternative to pagination is to use ``scroll`` and ``scroll_id``. With ``scroll``
you will get a ``scroll_id`` on the first request. Each call with ``scroll`` and
the respective ``scroll_id`` will return the next ``per_page`` (here the default is 1000)
results. Scroll, however, ignores ordering and does not return aggregations.
The scroll view used in the background will stay alive for 1 minute between requests.
If the given ``scroll_id`` is not available anymore, an HTTP 400 is raised.
The search will return aggregations on a predefined set of quantities. Aggregations
will tell you what quantity values exist and how many entries match those values.
Ordering is determined by ``order_by`` and ``order`` parameters. Default is
``upload_time`` in descending order.
'''
try:
data_in = request.get_json()
scroll_args = data_in.get('scroll', {})
scroll = scroll_args.get('scroll', False)
scroll_id = scroll_args.get('scroll_id', None)
pagination = data_in.get('pagination', {})
page = pagination.get('page', 1)
per_page = pagination.get('per_page', 10 if not scroll else 1000)
order = pagination.get('order', -1)
order_by = pagination.get('order_by', 'upload_time')
date_histogram = data_in.get('date_histogram', False)
interval = data_in.get('interval', '1M')
metrics: List[str] = data_in.get('metrics', [])
statistics = data_in.get('statistics_required', [])
query = data_in.get('query', {})
query_expression = {key: val for key, val in query.items() if '$' in key}
except Exception as e:
abort(400, message='bad parameters: %s' % str(e))
for metric in metrics:
if metric not in search_extension.metrics:
abort(400, message='there is no metric %s' % metric)
search_request = search.SearchRequest()
apply_search_parameters(search_request, query)
if date_histogram:
search_request.date_histogram(interval=interval, metrics_to_use=metrics)
if query_expression:
search_request.query_expression(query_expression)
try:
assert page >= 1
assert per_page >= 0
except AssertionError:
abort(400, message='invalid pagination')
if order not in [-1, 1]:
abort(400, message='invalid pagination')
if len(statistics) > 0:
search_request.statistics(statistics, metrics_to_use=metrics)
group_metrics = [
group_quantity.metric_name
for group_name, group_quantity in search_extension.groups.items()
if data_in.get(group_name, False)]
total_metrics = metrics + group_metrics
if len(total_metrics) > 0:
search_request.totals(metrics_to_use=total_metrics)
if 'exclude' in data_in:
excludes = data_in['exclude']
if excludes is not None:
search_request.exclude(*excludes)
try:
if scroll:
results = search_request.execute_scrolled(scroll_id=scroll_id, size=per_page)
else:
for group_name, group_quantity in search_extension.groups.items():
if data_in.get(group_name, False):
kwargs: Dict[str, Any] = {}
if group_name == 'uploads_grouped':
kwargs.update(order_by='upload_time', order='desc')
search_request.quantity(
group_quantity.qualified_name, size=per_page, examples=1,
after=data_in.get('%s_after' % group_name, None),
**kwargs)
results = search_request.execute_paginated(
per_page=per_page, page=page, order=order, order_by=order_by)
# TODO just a work around to make things prettier
if 'statistics' in results:
statistics = results['statistics']
if 'code_name' in statistics and 'currupted mainfile' in statistics['code_name']:
del statistics['code_name']['currupted mainfile']
if 'quantities' in results:
quantities = results.pop('quantities')
for group_name, group_quantity in search_extension.groups.items():
if data_in.get(group_name, False):
results[group_name] = quantities[group_quantity.qualified_name]
# build python code/curl snippet
code_args = dict(data_in)
if 'statistics' in code_args:
del code_args['statistics']
results['code'] = {
'curl': query_api_curl('archive', 'query', query_string=code_args),
'python': query_api_python('archive', 'query', query_string=code_args),
'clientlib': query_api_clientlib(**code_args)
}
return results, 200
except search.ScrollIdNotFound:
abort(400, 'The given scroll_id does not exist.')
except KeyError as e:
import traceback
traceback.print_exc()
abort(400, str(e))
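A sketch of the scroll behaviour described in the docstring above (client code only; session and api_base are assumed, not part of this endpoint):
# first request opens the scroll and returns a scroll_id
rv = session.post(api_base + '/repo/', json={'scroll': {'scroll': True}, 'pagination': {'per_page': 1000}})
scroll_id = rv.json().get('scroll', {}).get('scroll_id')
# follow-up requests pass the scroll_id back until no more results are returned
rv = session.post(api_base + '/repo/', json={'scroll': {'scroll': True, 'scroll_id': scroll_id}})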
_query_model_parameters = {
'owner': fields.String(description='Specify which calcs to return: ``all``, ``public``, ``user``, ``staging``, default is ``all``'),
@@ -243,6 +243,73 @@ class SearchRequest:
return self
def query_expression(self, expression):
bool_operators = ['$and', '$or', '$not']
comp_operators = ['$gt', '$lt', '$gte', '$lte']
def _gen_query(name, value, operator):
quantity = search_quantities[name]
if operator in bool_operators:
value = value if isinstance(value, list) else [value]
value = quantity.derived(value) if quantity.derived else value
q = [Q('match', **{quantity.search_field: item}) for item in value]
q = _add_queries(q, '$and')
elif operator in comp_operators:
q = Q('range', **{quantity.search_field: {operator.lstrip('$'): value}})
else:
raise ValueError('Invalid operator %s' % operator)
return q
def _add_queries(queries, operator):
if operator == '$and':
q = Q('bool', must=queries)
elif operator == '$or':
q = Q('bool', should=queries)
elif operator == '$not':
q = Q('bool', must_not=queries)
elif operator in comp_operators:
q = Q('bool', must=queries)
elif operator is None:
q = queries[0]
else:
raise ValueError('Invalid operator %s' % operator)
return q
def _query(exp_value, exp_op=None):
if isinstance(exp_value, dict) and len(exp_value) == 1 and '$' not in list(exp_value.keys())[-1]:
key, val = list(exp_value.items())[0]
query = _gen_query(key, val, exp_op)
else:
q = []
if isinstance(exp_value, dict):
for key, val in exp_value.items():
q.append(_query(val, exp_op=key))
elif isinstance(exp_value, list):
for val in exp_value:
op = exp_op
if isinstance(val, dict):
k, v = list(val.items())[0]
if k[0] == '$':
val, op = v, k
q.append(_query(val, exp_op=op))
query = _add_queries(q, exp_op)
return query
q = _query(expression)
self.q &= q
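For example (a rough sketch, assuming each quantity's search_field equals its name), the expression {'$or': [{'upload_id': ['vasp_0']}, {'$gte': {'n_atoms': 1}}]} is translated into roughly:
Q('bool', should=[
    Q('bool', must=[Q('match', upload_id='vasp_0')]),  # match queries are and-combined per quantity
    Q('range', n_atoms={'gte': 1})  # comparison operators become range queries
])
which is then and-combined with the other terms of the search request via self.q &= q.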
def time_range(self, start: datetime, end: datetime):
''' Adds a time range to the query. '''
if start is None and end is None:
@@ -740,6 +740,53 @@ class TestArchive(UploadFilesBasedTests):
assert count > 0
@pytest.mark.timeout(config.tests.default_timeout)
@pytest.fixture(scope='function')
def example_upload(self, proc_infra, test_user):
path = 'tests/data/proc/example_vasp_with_binary.zip'
results = []
for uid in range(2):
upload_id = 'vasp_%d' % uid
processed = test_processing.run_processing((upload_id, path), test_user)
processed.publish_upload()
try:
processed.block_until_complete(interval=.01)
except Exception:
pass
results.append(processed)
return results
@pytest.mark.parametrize('query_expression, nresults', [
({
'$and': [
{'dft.system': 'bulk'}, {'$not': [{'dft.compound_type': 'ternary'}]}
]
}, 2),
({
'$or': [
{'upload_id': ['vasp_0']}, {'$gte': {'n_atoms': 1}}
]
}, 4),
({
'$not': [{'dft.spacegroup': 221}, {'dft.spacegroup': 227}]
}, 0),
])
def test_post_archive_query(self, api, example_upload, query_expression, nresults):
data = {'pagination': {'per_page': 5}, 'query': query_expression}
uri = '/archive/query'
rv = api.post(uri, content_type='application/json', data=json.dumps(data))
assert rv.status_code == 200
data = rv.get_json()
assert data
results = data.get('results', None)
assert len(results) == nresults
class TestMetainfo():
@pytest.mark.parametrize('package', ['common', 'vasp', 'general.experimental', 'eels'])
@@ -1085,6 +1132,17 @@ class TestRepo():
assert results is not None
assert len(results) == n_results
def test_post_search_query(self, api, example_elastic_calcs, no_warn):
query_expression = {'$not': [{'dft.system': 'bulk'}]}
data = {
'pagination': {'page': 1, 'per_page': 5}, 'query': query_expression,
'statistics_required': ['dft.system']}
rv = api.post('/repo/', content_type='application/json', data=json.dumps(data))
assert rv.status_code == 200
data = json.loads(rv.data)
results = data.get('results', None)
assert len(results) == 0
@pytest.mark.parametrize('first, order_by, order', [
('1', 'formula', -1), ('2', 'formula', 1),
('2', 'dft.basis_set', -1), ('1', 'dft.basis_set', 1),
@@ -1730,6 +1788,18 @@ class TestRaw(UploadFilesBasedTests):
assert rv.status_code == 200