Commit e01fe8c1 authored by Markus Scheidgen's avatar Markus Scheidgen
Browse files

Fixed issues with query expressions. Fixed issues with generated api code. #410

parent 92696cd7
Pipeline #81811 canceled with stages
in 13 minutes and 59 seconds
......@@ -284,7 +284,10 @@ class ArchiveQueryResource(Resource):
'calc_id', 'upload_id', 'with_embargo', 'published', 'parser_name')
if query_expression:
try:
search_request.query_expression(query_expression)
except AssertionError as e:
abort(400, str(e))
try:
if aggregation:
......
......@@ -342,7 +342,7 @@ def streamed_zipfile(
return response
def query_api_url(*args, query_string: Dict[str, Any] = None):
def _query_api_url(*args, query: Dict[str, Any] = None):
'''
Creates a API URL.
Arguments:
......@@ -350,43 +350,61 @@ def query_api_url(*args, query_string: Dict[str, Any] = None):
query_string: A dict with query string parameters
'''
url = os.path.join(config.api_url(False), *args)
if query_string is not None:
url = '%s?%s' % (url, urlencode(query_string, doseq=True))
if query is not None:
url = '%s?%s' % (url, urlencode(query, doseq=True))
return url
def query_api_python(*args, **kwargs):
def query_api_python(query):
'''
Creates a string of python code to execute a search query to the repository using
the requests library.
'''
url = query_api_url(*args, **kwargs)
query = _filter_api_query(query)
url = _query_api_url('archive', 'query')
return '''import requests
response = requests.post("{}")
data = response.json()'''.format(url)
response = requests.post('{}', json={{
'query': {{
{}
}}
}})
data = response.json()'''.format(
url,
',\n'.join([
' \'%s\': %s' % (key, pprint.pformat(value, compact=True))
for key, value in query.items()])
)
def query_api_clientlib(**kwargs):
'''
Creates a string of python code to execute a search query on the archive using
the client library.
'''
def _filter_api_query(query):
def normalize_value(key, value):
quantity = search.search_quantities.get(key)
if quantity.many and not isinstance(value, list):
return [value]
elif isinstance(value, list) and len(value) == 1:
return value[0]
return value
query = {
key: normalize_value(key, value) for key, value in kwargs.items()
result = {
key: normalize_value(key, value) for key, value in query.items()
if key in search.search_quantities and (key != 'domain' or value != config.meta.default_domain)
}
for key in ['dft.optimade']:
if key in kwargs:
query[key] = kwargs[key]
if key in query:
result[key] = query[key]
return result
def query_api_clientlib(**kwargs):
'''
Creates a string of python code to execute a search query on the archive using
the client library.
'''
query = _filter_api_query(kwargs)
out = io.StringIO()
out.write('from nomad import client, config\n')
......@@ -401,12 +419,13 @@ def query_api_clientlib(**kwargs):
return out.getvalue()
def query_api_curl(*args, **kwargs):
def query_api_curl(query):
'''
Creates a string of curl command to execute a search query to the repository.
Creates a string of curl command to execute a search query and download the respective
archives in a .zip file.
'''
url = query_api_url(*args, **kwargs)
return 'curl -X POST %s -H "accept: application/json" --output "nomad.json"' % url
url = _query_api_url('archive', 'download', query=query)
return 'curl "%s" --output nomad.zip' % url
def enable_gzip(level: int = 1, min_size: int = 1024):
......
......@@ -578,7 +578,10 @@ class RawFileQueryResource(Resource):
search_request.include('calc_id', 'upload_id', 'mainfile')
if query_expression:
try:
search_request.query_expression(query_expression)
except AssertionError as e:
abort(400, str(e))
def path(entry):
return '%s/%s' % (entry['upload_id'], entry['mainfile'])
......
......@@ -68,8 +68,8 @@ class RepoCalcResource(Resource):
result = calc.to_dict()
result['code'] = {
'python': query_api_python('archive', upload_id, calc_id),
'curl': query_api_curl('archive', upload_id, calc_id),
'python': query_api_python(dict(upload_id=upload_id, calc_id=calc_id)),
'curl': query_api_curl(dict(upload_id=upload_id, calc_id=calc_id)),
'clientlib': query_api_clientlib(upload_id=[upload_id], calc_id=[calc_id])
}
......@@ -256,12 +256,12 @@ class RepoCalcsResource(Resource):
results[group_name] = quantities[group_quantity.qualified_name]
# build python code/curl snippet
code_args = dict(request.args)
code_args = request.args.to_dict(flat=False)
if 'statistics' in code_args:
del(code_args['statistics'])
results['code'] = {
'curl': query_api_curl('archive', 'query', query_string=code_args),
'python': query_api_python('archive', 'query', query_string=code_args),
'curl': query_api_curl(code_args),
'python': query_api_python(code_args),
'clientlib': query_api_clientlib(**code_args)
}
......@@ -342,7 +342,10 @@ class RepoCalcsResource(Resource):
search_request.date_histogram(interval=interval, metrics_to_use=metrics)
if query_expression:
try:
search_request.query_expression(query_expression)
except AssertionError as e:
abort(400, str(e))
try:
assert page >= 1
......@@ -405,8 +408,8 @@ class RepoCalcsResource(Resource):
if 'statistics' in code_args:
del(code_args['statistics'])
results['code'] = {
'curl': query_api_curl('archive', 'query', query_string=code_args),
'python': query_api_python('archive', 'query', query_string=code_args),
'curl': query_api_curl(code_args),
'python': query_api_python(code_args),
'clientlib': query_api_clientlib(**code_args)
}
......
......@@ -215,15 +215,14 @@ class SearchRequest:
return self
def search_parameter(self, name, value):
def _search_parameter_to_es(self, name, value):
quantity = search_quantities[name]
if quantity.many and not isinstance(value, list):
value = [value]
if quantity.many_or and isinstance(value, List):
self.q &= Q('terms', **{quantity.search_field: value})
return self
return Q('terms', **{quantity.search_field: value})
if quantity.derived:
if quantity.many and not isinstance(value, list):
......@@ -235,9 +234,12 @@ class SearchRequest:
else:
values = [value]
for item in values:
self.q &= Q('match', **{quantity.search_field: item})
return Q('bool', must=[
Q('match', **{quantity.search_field: item})
for item in values])
def search_parameter(self, name, value):
self.q &= self._search_parameter_to_es(name, value)
return self
def query(self, query):
......@@ -246,74 +248,41 @@ class SearchRequest:
return self
def query_expression(self, expression):
bool_operators = ['$and', '$or', '$not']
comp_operators = ['$gt', '$lt', '$gte', '$lte']
def _gen_query(name, value, operator):
quantity = search_quantities.get(name)
if quantity is None:
raise InvalidQuery('Search quantity %s does not exist' % name)
if operator in bool_operators:
value = value if isinstance(value, list) else [value]
value = quantity.derived(value) if quantity.derived else value
q = [Q('match', **{quantity.search_field: item}) for item in value]
q = _add_queries(q, '$and')
elif operator in comp_operators:
q = Q('range', **{quantity.search_field: {operator.lstrip('$'): value}})
else:
raise ValueError('Invalid operator %s' % operator)
return q
def _add_queries(queries, operator):
if operator == '$and':
q = Q('bool', must=queries)
elif operator == '$or':
q = Q('bool', should=queries)
elif operator == '$not':
q = Q('bool', must_not=queries)
elif operator in comp_operators:
q = Q('bool', must=queries)
elif operator is None:
q = queries[0]
else:
raise ValueError('Invalid operator %s' % operator)
return q
def query_expression(self, expression) -> 'SearchRequest':
bool_operators = {'$and': 'must', '$or': 'should', '$not': 'must_not'}
comp_operators = {'$%s' % op: op for op in ['gt', 'gte', 'lt', 'lte']}
def _to_es(key, values):
if key in bool_operators:
if isinstance(values, dict):
values = [values]
assert isinstance(values, list), 'bool operator requires a list of dicts or dict'
child_es_queries = [
_to_es(child_key, child_value)
for child_query in values
for child_key, child_value in child_query.items()]
return Q('bool', **{bool_operators[key]: child_es_queries})
if key in comp_operators:
assert isinstance(values, dict), 'comparison operator requires a dict'
assert len(values) == 1, 'comparison operator requires exactly one quantity'
quantity_name, value = next(iter(values.items()))
quantity = search_quantities.get(quantity_name)
assert quantity is not None, 'quantity %s does not exist' % quantity_name
return Q('range', **{quantity.search_field: {comp_operators[key]: value}})
def _query(exp_value, exp_op=None):
if isinstance(
exp_value, dict) and len(exp_value) == 1 and '$' not in list(
exp_value.keys())[-1]:
key, val = list(exp_value.items())[0]
query = _gen_query(key, val, exp_op)
try:
return self._search_parameter_to_es(key, values)
except KeyError:
assert False, 'quantity %s does not exist' % key
if len(expression) == 0:
self.q &= Q()
else:
q = []
if isinstance(exp_value, dict):
for key, val in exp_value.items():
q.append(_query(val, exp_op=key))
elif isinstance(exp_value, list):
for val in exp_value:
op = exp_op
if isinstance(val, dict):
k, v = list(val.items())[0]
if k[0] == '$':
val, op = v, k
self.q &= Q('bool', must=[_to_es(key, value) for key, value in expression.items()])
q.append(_query(val, exp_op=op))
query = _add_queries(q, exp_op)
return query
q = _query(expression)
self.q &= q
return self
def time_range(self, start: datetime, end: datetime):
''' Adds a time range to the query. '''
......
......@@ -761,25 +761,32 @@ class TestArchive(UploadFilesBasedTests):
return results
@pytest.mark.parametrize('query_expression, nresults', [
({
pytest.param({}, 4, id='empty'),
pytest.param({'dft.system': 'bulk'}, 4, id='match'),
pytest.param({'$gte': {'n_atoms': 1}}, 4, id='comparison'),
pytest.param({
'$and': [
{'dft.system': 'bulk'}, {'$not': [{'dft.compound_type': 'ternary'}]}
]
}, 2),
({
}, 2, id="and-with-not"),
pytest.param({
'$or': [
{'upload_id': ['vasp_0']}, {'$gte': {'n_atoms': 1}}
]
}, 4),
({
}, 4, id="or-with-gte"),
pytest.param({
'$not': [{'dft.spacegroup': 221}, {'dft.spacegroup': 227}]
}, 0),
}, 0, id="not"),
pytest.param({
'$and': [
{'dft.code_name': 'VASP'},
{'$gte': {'n_atoms': 3}},
{'$lte': {'dft.workflow.section_relaxation.final_energy_difference': 1e-24}}
]}, 0, id='client-example')
])
def test_post_archive_query(self, api, example_upload, query_expression, nresults):
data = {'pagination': {'per_page': 5}, 'query': query_expression}
uri = '/archive/query'
rv = api.post(uri, content_type='application/json', data=json.dumps(data))
rv = api.post('/archive/query', content_type='application/json', data=json.dumps(data))
assert rv.status_code == 200
data = rv.get_json()
......@@ -788,6 +795,16 @@ class TestArchive(UploadFilesBasedTests):
results = data.get('results', None)
assert len(results) == nresults
@pytest.mark.parametrize('query', [
pytest.param({'$bad_op': {'n_atoms': 1}}, id='bad-op')
])
def test_post_archive_bad_query(self, api, query):
rv = api.post(
'/archive/query', content_type='application/json',
data=json.dumps(dict(query=query)))
assert rv.status_code == 400
class TestMetainfo():
@pytest.mark.parametrize('package', ['common', 'vasp'])
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment