diff --git a/nomad/app/api/archive.py b/nomad/app/api/archive.py index b62ce5f1c902a695d71e9ade9d48e85bd5840f23..00ecaa1fe6fdf05457332531d96bb93540d67186 100644 --- a/nomad/app/api/archive.py +++ b/nomad/app/api/archive.py @@ -284,7 +284,10 @@ class ArchiveQueryResource(Resource): 'calc_id', 'upload_id', 'with_embargo', 'published', 'parser_name') if query_expression: - search_request.query_expression(query_expression) + try: + search_request.query_expression(query_expression) + except AssertionError as e: + abort(400, str(e)) try: if aggregation: diff --git a/nomad/app/api/common.py b/nomad/app/api/common.py index 5e6ddc0f1e85d49d9e9e12856a734822893a1b42..be7ea18962d1ac843da27b2cfb33ef56ce6a18f2 100644 --- a/nomad/app/api/common.py +++ b/nomad/app/api/common.py @@ -342,7 +342,7 @@ def streamed_zipfile( return response -def query_api_url(*args, query_string: Dict[str, Any] = None): +def _query_api_url(*args, query: Dict[str, Any] = None): ''' Creates a API URL. Arguments: @@ -350,43 +350,61 @@ def query_api_url(*args, query_string: Dict[str, Any] = None): query_string: A dict with query string parameters ''' url = os.path.join(config.api_url(False), *args) - if query_string is not None: - url = '%s?%s' % (url, urlencode(query_string, doseq=True)) + if query is not None: + url = '%s?%s' % (url, urlencode(query, doseq=True)) return url -def query_api_python(*args, **kwargs): +def query_api_python(query): ''' Creates a string of python code to execute a search query to the repository using the requests library. ''' - url = query_api_url(*args, **kwargs) + query = _filter_api_query(query) + url = _query_api_url('archive', 'query') return '''import requests -response = requests.post("{}") -data = response.json()'''.format(url) - - -def query_api_clientlib(**kwargs): - ''' - Creates a string of python code to execute a search query on the archive using - the client library. - ''' +response = requests.post('{}', json={{ + 'query': {{ +{} + }} +}}) +data = response.json()'''.format( + url, + ',\n'.join([ + ' \'%s\': %s' % (key, pprint.pformat(value, compact=True)) + for key, value in query.items()]) + ) + + +def _filter_api_query(query): def normalize_value(key, value): quantity = search.search_quantities.get(key) if quantity.many and not isinstance(value, list): return [value] + elif isinstance(value, list) and len(value) == 1: + return value[0] return value - query = { - key: normalize_value(key, value) for key, value in kwargs.items() + result = { + key: normalize_value(key, value) for key, value in query.items() if key in search.search_quantities and (key != 'domain' or value != config.meta.default_domain) } for key in ['dft.optimade']: - if key in kwargs: - query[key] = kwargs[key] + if key in query: + result[key] = query[key] + + return result + + +def query_api_clientlib(**kwargs): + ''' + Creates a string of python code to execute a search query on the archive using + the client library. + ''' + query = _filter_api_query(kwargs) out = io.StringIO() out.write('from nomad import client, config\n') @@ -401,12 +419,13 @@ def query_api_clientlib(**kwargs): return out.getvalue() -def query_api_curl(*args, **kwargs): +def query_api_curl(query): ''' - Creates a string of curl command to execute a search query to the repository. + Creates a string of curl command to execute a search query and download the respective + archives in a .zip file. ''' - url = query_api_url(*args, **kwargs) - return 'curl -X POST %s -H "accept: application/json" --output "nomad.json"' % url + url = _query_api_url('archive', 'download', query=query) + return 'curl "%s" --output nomad.zip' % url def enable_gzip(level: int = 1, min_size: int = 1024): diff --git a/nomad/app/api/raw.py b/nomad/app/api/raw.py index c49cdf6fb11513cd45f3e61b0d9c1668c5632cb9..ef81074635f1e69601a420e732737adcc6ff47bc 100644 --- a/nomad/app/api/raw.py +++ b/nomad/app/api/raw.py @@ -578,7 +578,10 @@ class RawFileQueryResource(Resource): search_request.include('calc_id', 'upload_id', 'mainfile') if query_expression: - search_request.query_expression(query_expression) + try: + search_request.query_expression(query_expression) + except AssertionError as e: + abort(400, str(e)) def path(entry): return '%s/%s' % (entry['upload_id'], entry['mainfile']) diff --git a/nomad/app/api/repo.py b/nomad/app/api/repo.py index 4017a5d6bf7d91e95c46b7cb8e2f00bb6c44beed..323215150b07cc699cabdf753fda12fa5d5812cb 100644 --- a/nomad/app/api/repo.py +++ b/nomad/app/api/repo.py @@ -68,8 +68,8 @@ class RepoCalcResource(Resource): result = calc.to_dict() result['code'] = { - 'python': query_api_python('archive', upload_id, calc_id), - 'curl': query_api_curl('archive', upload_id, calc_id), + 'python': query_api_python(dict(upload_id=upload_id, calc_id=calc_id)), + 'curl': query_api_curl(dict(upload_id=upload_id, calc_id=calc_id)), 'clientlib': query_api_clientlib(upload_id=[upload_id], calc_id=[calc_id]) } @@ -256,12 +256,12 @@ class RepoCalcsResource(Resource): results[group_name] = quantities[group_quantity.qualified_name] # build python code/curl snippet - code_args = dict(request.args) + code_args = request.args.to_dict(flat=False) if 'statistics' in code_args: del(code_args['statistics']) results['code'] = { - 'curl': query_api_curl('archive', 'query', query_string=code_args), - 'python': query_api_python('archive', 'query', query_string=code_args), + 'curl': query_api_curl(code_args), + 'python': query_api_python(code_args), 'clientlib': query_api_clientlib(**code_args) } @@ -342,7 +342,10 @@ class RepoCalcsResource(Resource): search_request.date_histogram(interval=interval, metrics_to_use=metrics) if query_expression: - search_request.query_expression(query_expression) + try: + search_request.query_expression(query_expression) + except AssertionError as e: + abort(400, str(e)) try: assert page >= 1 @@ -405,8 +408,8 @@ class RepoCalcsResource(Resource): if 'statistics' in code_args: del(code_args['statistics']) results['code'] = { - 'curl': query_api_curl('archive', 'query', query_string=code_args), - 'python': query_api_python('archive', 'query', query_string=code_args), + 'curl': query_api_curl(code_args), + 'python': query_api_python(code_args), 'clientlib': query_api_clientlib(**code_args) } diff --git a/nomad/search.py b/nomad/search.py index a3dfa70b7c661d99d2a456be7f36d7b5846620c4..6b8ecd521e009eafd47b29a93cafe9dbc98753ff 100644 --- a/nomad/search.py +++ b/nomad/search.py @@ -215,15 +215,14 @@ class SearchRequest: return self - def search_parameter(self, name, value): + def _search_parameter_to_es(self, name, value): quantity = search_quantities[name] if quantity.many and not isinstance(value, list): value = [value] if quantity.many_or and isinstance(value, List): - self.q &= Q('terms', **{quantity.search_field: value}) - return self + return Q('terms', **{quantity.search_field: value}) if quantity.derived: if quantity.many and not isinstance(value, list): @@ -235,9 +234,12 @@ class SearchRequest: else: values = [value] - for item in values: - self.q &= Q('match', **{quantity.search_field: item}) + return Q('bool', must=[ + Q('match', **{quantity.search_field: item}) + for item in values]) + def search_parameter(self, name, value): + self.q &= self._search_parameter_to_es(name, value) return self def query(self, query): @@ -246,74 +248,41 @@ class SearchRequest: return self - def query_expression(self, expression): - - bool_operators = ['$and', '$or', '$not'] - comp_operators = ['$gt', '$lt', '$gte', '$lte'] - - def _gen_query(name, value, operator): - quantity = search_quantities.get(name) - if quantity is None: - raise InvalidQuery('Search quantity %s does not exist' % name) - - if operator in bool_operators: - value = value if isinstance(value, list) else [value] - value = quantity.derived(value) if quantity.derived else value - q = [Q('match', **{quantity.search_field: item}) for item in value] - q = _add_queries(q, '$and') - elif operator in comp_operators: - q = Q('range', **{quantity.search_field: {operator.lstrip('$'): value}}) - else: - raise ValueError('Invalid operator %s' % operator) - - return q - - def _add_queries(queries, operator): - if operator == '$and': - q = Q('bool', must=queries) - elif operator == '$or': - q = Q('bool', should=queries) - elif operator == '$not': - q = Q('bool', must_not=queries) - elif operator in comp_operators: - q = Q('bool', must=queries) - elif operator is None: - q = queries[0] - else: - raise ValueError('Invalid operator %s' % operator) - - return q + def query_expression(self, expression) -> 'SearchRequest': + + bool_operators = {'$and': 'must', '$or': 'should', '$not': 'must_not'} + comp_operators = {'$%s' % op: op for op in ['gt', 'gte', 'lt', 'lte']} + + def _to_es(key, values): + if key in bool_operators: + if isinstance(values, dict): + values = [values] + assert isinstance(values, list), 'bool operator requires a list of dicts or dict' + child_es_queries = [ + _to_es(child_key, child_value) + for child_query in values + for child_key, child_value in child_query.items()] + return Q('bool', **{bool_operators[key]: child_es_queries}) + + if key in comp_operators: + assert isinstance(values, dict), 'comparison operator requires a dict' + assert len(values) == 1, 'comparison operator requires exactly one quantity' + quantity_name, value = next(iter(values.items())) + quantity = search_quantities.get(quantity_name) + assert quantity is not None, 'quantity %s does not exist' % quantity_name + return Q('range', **{quantity.search_field: {comp_operators[key]: value}}) - def _query(exp_value, exp_op=None): - if isinstance( - exp_value, dict) and len(exp_value) == 1 and '$' not in list( - exp_value.keys())[-1]: - key, val = list(exp_value.items())[0] - query = _gen_query(key, val, exp_op) - - else: - q = [] - if isinstance(exp_value, dict): - for key, val in exp_value.items(): - q.append(_query(val, exp_op=key)) - - elif isinstance(exp_value, list): - for val in exp_value: - op = exp_op - - if isinstance(val, dict): - k, v = list(val.items())[0] - if k[0] == '$': - val, op = v, k - - q.append(_query(val, exp_op=op)) - - query = _add_queries(q, exp_op) + try: + return self._search_parameter_to_es(key, values) + except KeyError: + assert False, 'quantity %s does not exist' % key - return query + if len(expression) == 0: + self.q &= Q() + else: + self.q &= Q('bool', must=[_to_es(key, value) for key, value in expression.items()]) - q = _query(expression) - self.q &= q + return self def time_range(self, start: datetime, end: datetime): ''' Adds a time range to the query. ''' diff --git a/tests/app/test_api.py b/tests/app/test_api.py index 5927ddb803df6a849d8d575395ec6e0c7ff966d3..3294dcf5d652b4f35e772fa837ae8391b54d1938 100644 --- a/tests/app/test_api.py +++ b/tests/app/test_api.py @@ -761,25 +761,32 @@ class TestArchive(UploadFilesBasedTests): return results @pytest.mark.parametrize('query_expression, nresults', [ - ({ + pytest.param({}, 4, id='empty'), + pytest.param({'dft.system': 'bulk'}, 4, id='match'), + pytest.param({'$gte': {'n_atoms': 1}}, 4, id='comparison'), + pytest.param({ '$and': [ {'dft.system': 'bulk'}, {'$not': [{'dft.compound_type': 'ternary'}]} ] - }, 2), - ({ + }, 2, id="and-with-not"), + pytest.param({ '$or': [ {'upload_id': ['vasp_0']}, {'$gte': {'n_atoms': 1}} ] - }, 4), - ({ + }, 4, id="or-with-gte"), + pytest.param({ '$not': [{'dft.spacegroup': 221}, {'dft.spacegroup': 227}] - }, 0), + }, 0, id="not"), + pytest.param({ + '$and': [ + {'dft.code_name': 'VASP'}, + {'$gte': {'n_atoms': 3}}, + {'$lte': {'dft.workflow.section_relaxation.final_energy_difference': 1e-24}} + ]}, 0, id='client-example') ]) def test_post_archive_query(self, api, example_upload, query_expression, nresults): data = {'pagination': {'per_page': 5}, 'query': query_expression} - - uri = '/archive/query' - rv = api.post(uri, content_type='application/json', data=json.dumps(data)) + rv = api.post('/archive/query', content_type='application/json', data=json.dumps(data)) assert rv.status_code == 200 data = rv.get_json() @@ -788,6 +795,16 @@ class TestArchive(UploadFilesBasedTests): results = data.get('results', None) assert len(results) == nresults + @pytest.mark.parametrize('query', [ + pytest.param({'$bad_op': {'n_atoms': 1}}, id='bad-op') + ]) + def test_post_archive_bad_query(self, api, query): + rv = api.post( + '/archive/query', content_type='application/json', + data=json.dumps(dict(query=query))) + + assert rv.status_code == 400 + class TestMetainfo(): @pytest.mark.parametrize('package', ['common', 'vasp'])