Commit e01fe8c1 authored by Markus Scheidgen
Browse files

Fixed issues with query expressions. Fixed issues with generated api code. #410

parent 92696cd7
Pipeline #81811 canceled with stages
in 13 minutes and 59 seconds
...@@ -284,7 +284,10 @@ class ArchiveQueryResource(Resource): ...@@ -284,7 +284,10 @@ class ArchiveQueryResource(Resource):
'calc_id', 'upload_id', 'with_embargo', 'published', 'parser_name') 'calc_id', 'upload_id', 'with_embargo', 'published', 'parser_name')
if query_expression: if query_expression:
search_request.query_expression(query_expression) try:
search_request.query_expression(query_expression)
except AssertionError as e:
abort(400, str(e))
try: try:
if aggregation: if aggregation:
......
...@@ -342,7 +342,7 @@ def streamed_zipfile( ...@@ -342,7 +342,7 @@ def streamed_zipfile(
return response return response
def query_api_url(*args, query_string: Dict[str, Any] = None): def _query_api_url(*args, query: Dict[str, Any] = None):
''' '''
Creates a API URL. Creates a API URL.
Arguments: Arguments:
...@@ -350,43 +350,61 @@ def query_api_url(*args, query_string: Dict[str, Any] = None): ...@@ -350,43 +350,61 @@ def query_api_url(*args, query_string: Dict[str, Any] = None):
query_string: A dict with query string parameters query_string: A dict with query string parameters
''' '''
url = os.path.join(config.api_url(False), *args) url = os.path.join(config.api_url(False), *args)
if query_string is not None: if query is not None:
url = '%s?%s' % (url, urlencode(query_string, doseq=True)) url = '%s?%s' % (url, urlencode(query, doseq=True))
return url return url
def query_api_python(*args, **kwargs): def query_api_python(query):
''' '''
Creates a string of python code to execute a search query to the repository using Creates a string of python code to execute a search query to the repository using
the requests library. the requests library.
''' '''
url = query_api_url(*args, **kwargs) query = _filter_api_query(query)
url = _query_api_url('archive', 'query')
return '''import requests return '''import requests
response = requests.post("{}") response = requests.post('{}', json={{
data = response.json()'''.format(url) 'query': {{
{}
}}
def query_api_clientlib(**kwargs): }})
''' data = response.json()'''.format(
Creates a string of python code to execute a search query on the archive using url,
the client library. ',\n'.join([
''' ' \'%s\': %s' % (key, pprint.pformat(value, compact=True))
for key, value in query.items()])
)
def _filter_api_query(query):
def normalize_value(key, value): def normalize_value(key, value):
quantity = search.search_quantities.get(key) quantity = search.search_quantities.get(key)
if quantity.many and not isinstance(value, list): if quantity.many and not isinstance(value, list):
return [value] return [value]
elif isinstance(value, list) and len(value) == 1:
return value[0]
return value return value
query = { result = {
key: normalize_value(key, value) for key, value in kwargs.items() key: normalize_value(key, value) for key, value in query.items()
if key in search.search_quantities and (key != 'domain' or value != config.meta.default_domain) if key in search.search_quantities and (key != 'domain' or value != config.meta.default_domain)
} }
for key in ['dft.optimade']: for key in ['dft.optimade']:
if key in kwargs: if key in query:
query[key] = kwargs[key] result[key] = query[key]
return result
def query_api_clientlib(**kwargs):
'''
Creates a string of python code to execute a search query on the archive using
the client library.
'''
query = _filter_api_query(kwargs)
out = io.StringIO() out = io.StringIO()
out.write('from nomad import client, config\n') out.write('from nomad import client, config\n')
...@@ -401,12 +419,13 @@ def query_api_clientlib(**kwargs): ...@@ -401,12 +419,13 @@ def query_api_clientlib(**kwargs):
return out.getvalue() return out.getvalue()
def query_api_curl(*args, **kwargs): def query_api_curl(query):
''' '''
Creates a string of curl command to execute a search query to the repository. Creates a string of curl command to execute a search query and download the respective
archives in a .zip file.
''' '''
url = query_api_url(*args, **kwargs) url = _query_api_url('archive', 'download', query=query)
return 'curl -X POST %s -H "accept: application/json" --output "nomad.json"' % url return 'curl "%s" --output nomad.zip' % url
def enable_gzip(level: int = 1, min_size: int = 1024): def enable_gzip(level: int = 1, min_size: int = 1024):
......
...@@ -578,7 +578,10 @@ class RawFileQueryResource(Resource): ...@@ -578,7 +578,10 @@ class RawFileQueryResource(Resource):
search_request.include('calc_id', 'upload_id', 'mainfile') search_request.include('calc_id', 'upload_id', 'mainfile')
if query_expression: if query_expression:
search_request.query_expression(query_expression) try:
search_request.query_expression(query_expression)
except AssertionError as e:
abort(400, str(e))
def path(entry): def path(entry):
return '%s/%s' % (entry['upload_id'], entry['mainfile']) return '%s/%s' % (entry['upload_id'], entry['mainfile'])
......
...@@ -68,8 +68,8 @@ class RepoCalcResource(Resource): ...@@ -68,8 +68,8 @@ class RepoCalcResource(Resource):
result = calc.to_dict() result = calc.to_dict()
result['code'] = { result['code'] = {
'python': query_api_python('archive', upload_id, calc_id), 'python': query_api_python(dict(upload_id=upload_id, calc_id=calc_id)),
'curl': query_api_curl('archive', upload_id, calc_id), 'curl': query_api_curl(dict(upload_id=upload_id, calc_id=calc_id)),
'clientlib': query_api_clientlib(upload_id=[upload_id], calc_id=[calc_id]) 'clientlib': query_api_clientlib(upload_id=[upload_id], calc_id=[calc_id])
} }
...@@ -256,12 +256,12 @@ class RepoCalcsResource(Resource): ...@@ -256,12 +256,12 @@ class RepoCalcsResource(Resource):
results[group_name] = quantities[group_quantity.qualified_name] results[group_name] = quantities[group_quantity.qualified_name]
# build python code/curl snippet # build python code/curl snippet
code_args = dict(request.args) code_args = request.args.to_dict(flat=False)
if 'statistics' in code_args: if 'statistics' in code_args:
del(code_args['statistics']) del(code_args['statistics'])
results['code'] = { results['code'] = {
'curl': query_api_curl('archive', 'query', query_string=code_args), 'curl': query_api_curl(code_args),
'python': query_api_python('archive', 'query', query_string=code_args), 'python': query_api_python(code_args),
'clientlib': query_api_clientlib(**code_args) 'clientlib': query_api_clientlib(**code_args)
} }
...@@ -342,7 +342,10 @@ class RepoCalcsResource(Resource): ...@@ -342,7 +342,10 @@ class RepoCalcsResource(Resource):
search_request.date_histogram(interval=interval, metrics_to_use=metrics) search_request.date_histogram(interval=interval, metrics_to_use=metrics)
if query_expression: if query_expression:
search_request.query_expression(query_expression) try:
search_request.query_expression(query_expression)
except AssertionError as e:
abort(400, str(e))
try: try:
assert page >= 1 assert page >= 1
...@@ -405,8 +408,8 @@ class RepoCalcsResource(Resource): ...@@ -405,8 +408,8 @@ class RepoCalcsResource(Resource):
if 'statistics' in code_args: if 'statistics' in code_args:
del(code_args['statistics']) del(code_args['statistics'])
results['code'] = { results['code'] = {
'curl': query_api_curl('archive', 'query', query_string=code_args), 'curl': query_api_curl(code_args),
'python': query_api_python('archive', 'query', query_string=code_args), 'python': query_api_python(code_args),
'clientlib': query_api_clientlib(**code_args) 'clientlib': query_api_clientlib(**code_args)
} }
......
...@@ -215,15 +215,14 @@ class SearchRequest: ...@@ -215,15 +215,14 @@ class SearchRequest:
return self return self
def search_parameter(self, name, value): def _search_parameter_to_es(self, name, value):
quantity = search_quantities[name] quantity = search_quantities[name]
if quantity.many and not isinstance(value, list): if quantity.many and not isinstance(value, list):
value = [value] value = [value]
if quantity.many_or and isinstance(value, List): if quantity.many_or and isinstance(value, List):
self.q &= Q('terms', **{quantity.search_field: value}) return Q('terms', **{quantity.search_field: value})
return self
if quantity.derived: if quantity.derived:
if quantity.many and not isinstance(value, list): if quantity.many and not isinstance(value, list):
...@@ -235,9 +234,12 @@ class SearchRequest: ...@@ -235,9 +234,12 @@ class SearchRequest:
else: else:
values = [value] values = [value]
for item in values: return Q('bool', must=[
self.q &= Q('match', **{quantity.search_field: item}) Q('match', **{quantity.search_field: item})
for item in values])
def search_parameter(self, name, value):
self.q &= self._search_parameter_to_es(name, value)
return self return self
def query(self, query): def query(self, query):
...@@ -246,74 +248,41 @@ class SearchRequest: ...@@ -246,74 +248,41 @@ class SearchRequest:
return self return self
def query_expression(self, expression): def query_expression(self, expression) -> 'SearchRequest':
bool_operators = ['$and', '$or', '$not'] bool_operators = {'$and': 'must', '$or': 'should', '$not': 'must_not'}
comp_operators = ['$gt', '$lt', '$gte', '$lte'] comp_operators = {'$%s' % op: op for op in ['gt', 'gte', 'lt', 'lte']}
def _gen_query(name, value, operator): def _to_es(key, values):
quantity = search_quantities.get(name) if key in bool_operators:
if quantity is None: if isinstance(values, dict):
raise InvalidQuery('Search quantity %s does not exist' % name) values = [values]
assert isinstance(values, list), 'bool operator requires a list of dicts or dict'
if operator in bool_operators: child_es_queries = [
value = value if isinstance(value, list) else [value] _to_es(child_key, child_value)
value = quantity.derived(value) if quantity.derived else value for child_query in values
q = [Q('match', **{quantity.search_field: item}) for item in value] for child_key, child_value in child_query.items()]
q = _add_queries(q, '$and') return Q('bool', **{bool_operators[key]: child_es_queries})
elif operator in comp_operators:
q = Q('range', **{quantity.search_field: {operator.lstrip('$'): value}}) if key in comp_operators:
else: assert isinstance(values, dict), 'comparison operator requires a dict'
raise ValueError('Invalid operator %s' % operator) assert len(values) == 1, 'comparison operator requires exactly one quantity'
quantity_name, value = next(iter(values.items()))
return q quantity = search_quantities.get(quantity_name)
assert quantity is not None, 'quantity %s does not exist' % quantity_name
def _add_queries(queries, operator): return Q('range', **{quantity.search_field: {comp_operators[key]: value}})
if operator == '$and':
q = Q('bool', must=queries)
elif operator == '$or':
q = Q('bool', should=queries)
elif operator == '$not':
q = Q('bool', must_not=queries)
elif operator in comp_operators:
q = Q('bool', must=queries)
elif operator is None:
q = queries[0]
else:
raise ValueError('Invalid operator %s' % operator)
return q
def _query(exp_value, exp_op=None): try:
if isinstance( return self._search_parameter_to_es(key, values)
exp_value, dict) and len(exp_value) == 1 and '$' not in list( except KeyError:
exp_value.keys())[-1]: assert False, 'quantity %s does not exist' % key
key, val = list(exp_value.items())[0]
query = _gen_query(key, val, exp_op)
else:
q = []
if isinstance(exp_value, dict):
for key, val in exp_value.items():
q.append(_query(val, exp_op=key))
elif isinstance(exp_value, list):
for val in exp_value:
op = exp_op
if isinstance(val, dict):
k, v = list(val.items())[0]
if k[0] == '$':
val, op = v, k
q.append(_query(val, exp_op=op))
query = _add_queries(q, exp_op)
return query if len(expression) == 0:
self.q &= Q()
else:
self.q &= Q('bool', must=[_to_es(key, value) for key, value in expression.items()])
q = _query(expression) return self
self.q &= q
def time_range(self, start: datetime, end: datetime): def time_range(self, start: datetime, end: datetime):
''' Adds a time range to the query. ''' ''' Adds a time range to the query. '''
......
...@@ -761,25 +761,32 @@ class TestArchive(UploadFilesBasedTests): ...@@ -761,25 +761,32 @@ class TestArchive(UploadFilesBasedTests):
return results return results
@pytest.mark.parametrize('query_expression, nresults', [ @pytest.mark.parametrize('query_expression, nresults', [
({ pytest.param({}, 4, id='empty'),
pytest.param({'dft.system': 'bulk'}, 4, id='match'),
pytest.param({'$gte': {'n_atoms': 1}}, 4, id='comparison'),
pytest.param({
'$and': [ '$and': [
{'dft.system': 'bulk'}, {'$not': [{'dft.compound_type': 'ternary'}]} {'dft.system': 'bulk'}, {'$not': [{'dft.compound_type': 'ternary'}]}
] ]
}, 2), }, 2, id="and-with-not"),
({ pytest.param({
'$or': [ '$or': [
{'upload_id': ['vasp_0']}, {'$gte': {'n_atoms': 1}} {'upload_id': ['vasp_0']}, {'$gte': {'n_atoms': 1}}
] ]
}, 4), }, 4, id="or-with-gte"),
({ pytest.param({
'$not': [{'dft.spacegroup': 221}, {'dft.spacegroup': 227}] '$not': [{'dft.spacegroup': 221}, {'dft.spacegroup': 227}]
}, 0), }, 0, id="not"),
pytest.param({
'$and': [
{'dft.code_name': 'VASP'},
{'$gte': {'n_atoms': 3}},
{'$lte': {'dft.workflow.section_relaxation.final_energy_difference': 1e-24}}
]}, 0, id='client-example')
]) ])
def test_post_archive_query(self, api, example_upload, query_expression, nresults): def test_post_archive_query(self, api, example_upload, query_expression, nresults):
data = {'pagination': {'per_page': 5}, 'query': query_expression} data = {'pagination': {'per_page': 5}, 'query': query_expression}
rv = api.post('/archive/query', content_type='application/json', data=json.dumps(data))
uri = '/archive/query'
rv = api.post(uri, content_type='application/json', data=json.dumps(data))
assert rv.status_code == 200 assert rv.status_code == 200
data = rv.get_json() data = rv.get_json()
...@@ -788,6 +795,16 @@ class TestArchive(UploadFilesBasedTests): ...@@ -788,6 +795,16 @@ class TestArchive(UploadFilesBasedTests):
results = data.get('results', None) results = data.get('results', None)
assert len(results) == nresults assert len(results) == nresults
@pytest.mark.parametrize('query', [
pytest.param({'$bad_op': {'n_atoms': 1}}, id='bad-op')
])
def test_post_archive_bad_query(self, api, query):
rv = api.post(
'/archive/query', content_type='application/json',
data=json.dumps(dict(query=query)))
assert rv.status_code == 400
class TestMetainfo(): class TestMetainfo():
@pytest.mark.parametrize('package', ['common', 'vasp']) @pytest.mark.parametrize('package', ['common', 'vasp'])
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment