Commit 5baba261 authored by Markus Scheidgen's avatar Markus Scheidgen
Browse files

Implemented search in repo search endpoint. Added search for files, path, quantities, etc.

parent b75b8896
Pipeline #43896 passed with stages
in 22 minutes and 54 seconds
...@@ -22,9 +22,9 @@ from .app import api ...@@ -22,9 +22,9 @@ from .app import api
pagination_model = api.model('Pagination', { pagination_model = api.model('Pagination', {
'total': fields.Integer, 'total': fields.Integer(description='Number of total elements.'),
'page': fields.Integer, 'page': fields.Integer(description='Number of the current page, starting with 0.'),
'per_page': fields.Integer, 'per_page': fields.Integer(description='Number of elements per page.'),
}) })
""" Model used in responsed with pagination. """ """ Model used in responsed with pagination. """
......
...@@ -22,7 +22,7 @@ from flask import request, g ...@@ -22,7 +22,7 @@ from flask import request, g
from elasticsearch_dsl import Q from elasticsearch_dsl import Q
from nomad.files import UploadFiles, Restricted from nomad.files import UploadFiles, Restricted
from nomad.search import Entry from nomad import search
from .app import api from .app import api
from .auth import login_if_available, create_authorization_predicate from .auth import login_if_available, create_authorization_predicate
...@@ -61,7 +61,12 @@ class RepoCalcResource(Resource): ...@@ -61,7 +61,12 @@ class RepoCalcResource(Resource):
repo_calcs_model = api.model('RepoCalculations', { repo_calcs_model = api.model('RepoCalculations', {
'pagination': fields.Nested(pagination_model), 'pagination': fields.Nested(pagination_model),
'results': fields.List(fields.Raw) 'results': fields.List(fields.Raw, description=(
'A list of search results. Each result is a dict with quantitie names as key and '
'values as values')),
'aggregations': fields.Raw(description=(
'A dict with all aggregations. Each aggregation is dictionary with the amount as '
'value and quantity value as key.'))
}) })
repo_request_parser = pagination_request_parser.copy() repo_request_parser = pagination_request_parser.copy()
...@@ -69,28 +74,35 @@ repo_request_parser.add_argument( ...@@ -69,28 +74,35 @@ repo_request_parser.add_argument(
'owner', type=str, 'owner', type=str,
help='Specify which calcs to return: ``all``, ``user``, ``staging``, default is ``all``') help='Specify which calcs to return: ``all``, ``user``, ``staging``, default is ``all``')
for search_quantity in search.search_quantities.keys():
_, _, description = search.search_quantities[search_quantity]
repo_request_parser.add_argument(search_quantity, type=str, help=description)
@ns.route('/') @ns.route('/')
class RepoCalcsResource(Resource): class RepoCalcsResource(Resource):
@api.doc('get_calcs') @api.doc('get_calcs')
@api.response(400, 'Invalid requests, e.g. wrong owner type') @api.response(400, 'Invalid requests, e.g. wrong owner type or bad quantities')
@api.expect(repo_request_parser, validate=True) @api.expect(repo_request_parser, validate=True)
@api.marshal_with(repo_calcs_model, skip_none=True, code=200, description='Metadata sent') @api.marshal_with(repo_calcs_model, skip_none=True, code=200, description='Metadata sent')
@login_if_available @login_if_available
def get(self): def get(self):
""" """
Get *'all'* calculations in the repository, paginated. Search for calculations in the repository, paginated.
This is currently not implemented! The ``owner`` parameter determines the overall entries to search through.
You can use the various quantities to search/filter for. For some of the
indexed quantities this endpoint returns aggregation information. This means
you will be given a list of all possible values and the number of entries
that have the certain value. You can also use these aggregations on an empty
search to determine the possible values.
""" """
# return dict(pagination=dict(total=0, page=1, per_page=10), results=[]), 200 page = int(request.args.get('page', 0))
page = int(request.args.get('page', 1))
per_page = int(request.args.get('per_page', 10)) per_page = int(request.args.get('per_page', 10))
owner = request.args.get('owner', 'all') owner = request.args.get('owner', 'all')
try: try:
assert page >= 1 assert page >= 0
assert per_page > 0 assert per_page > 0
except AssertionError: except AssertionError:
abort(400, message='invalid pagination') abort(400, message='invalid pagination')
...@@ -112,13 +124,18 @@ class RepoCalcsResource(Resource): ...@@ -112,13 +124,18 @@ class RepoCalcsResource(Resource):
else: else:
abort(400, message='Invalid owner value. Valid values are all|user|staging, default is all') abort(400, message='Invalid owner value. Valid values are all|user|staging, default is all')
search = Entry.search().query(q) data = dict(**request.args)
search = search[(page - 1) * per_page: page * per_page] data.pop('owner', None)
return { data.pop('page', None)
'pagination': { data.pop('per_page', None)
'total': search.count(),
'page': page, try:
'per_page': per_page total, results, aggregations = search.aggregate_search(
}, page=page, per_page=per_page, q=q, **data)
'results': [hit.to_dict() for hit in search] except KeyError as e:
}, 200 abort(400, str(e))
return dict(
pagination=dict(total=total, page=page, per_page=per_page),
results=results,
aggregations=aggregations), 200
...@@ -18,12 +18,16 @@ This module represents calculations in elastic search. ...@@ -18,12 +18,16 @@ This module represents calculations in elastic search.
from typing import Iterable, Dict, Tuple, List from typing import Iterable, Dict, Tuple, List
from elasticsearch_dsl import Document, InnerDoc, Keyword, Text, Date, \ from elasticsearch_dsl import Document, InnerDoc, Keyword, Text, Date, \
Object, Boolean, Search, Integer, Q, A Object, Boolean, Search, Integer, Q, A, analyzer, tokenizer
import elasticsearch.helpers import elasticsearch.helpers
import ase.data import ase.data
from nomad import config, datamodel, infrastructure, datamodel, coe_repo, parsing from nomad import config, datamodel, infrastructure, datamodel, coe_repo, parsing
path_analyzer = analyzer(
'path_analyzer',
tokenizer=tokenizer('path_tokenizer', 'pattern', pattern='/'))
class AlreadyExists(Exception): pass class AlreadyExists(Exception): pass
...@@ -39,13 +43,11 @@ class User(InnerDoc): ...@@ -39,13 +43,11 @@ class User(InnerDoc):
name = '%s, %s' % (user['last_name'], user['first_name']) name = '%s, %s' % (user['last_name'], user['first_name'])
self.name = name self.name = name
self.name_keyword = name
return self return self
user_id = Keyword() user_id = Keyword()
name = Text() name = Text(fields={'keyword': Keyword()})
name_keyword = Keyword()
class Dataset(InnerDoc): class Dataset(InnerDoc):
...@@ -72,7 +74,7 @@ class Entry(Document): ...@@ -72,7 +74,7 @@ class Entry(Document):
calc_hash = Keyword() calc_hash = Keyword()
pid = Keyword() pid = Keyword()
mainfile = Keyword() mainfile = Keyword()
files = Keyword(multi=True) files = Text(multi=True, analyzer=path_analyzer, fields={'keyword': Keyword()})
uploader = Object(User) uploader = Object(User)
with_embargo = Boolean() with_embargo = Boolean()
...@@ -99,11 +101,6 @@ class Entry(Document): ...@@ -99,11 +101,6 @@ class Entry(Document):
geometries = Keyword(multi=True) geometries = Keyword(multi=True)
quantities = Keyword(multi=True) quantities = Keyword(multi=True)
# def __init__(self, *args, **kwargs):
# super().__init__(*args, **kwargs)
# self.authors = []
# self.owners = []
@classmethod @classmethod
def from_calc_with_metadata(cls, source: datamodel.CalcWithMetadata) -> 'Entry': def from_calc_with_metadata(cls, source: datamodel.CalcWithMetadata) -> 'Entry':
entry = Entry(meta=dict(id=source.calc_id)) entry = Entry(meta=dict(id=source.calc_id))
...@@ -116,8 +113,15 @@ class Entry(Document): ...@@ -116,8 +113,15 @@ class Entry(Document):
self.calc_id = source.calc_id self.calc_id = source.calc_id
self.calc_hash = source.calc_hash self.calc_hash = source.calc_hash
self.pid = str(source.pid) self.pid = str(source.pid)
self.mainfile = source.mainfile self.mainfile = source.mainfile
self.files = source.files if source.files is None:
self.files = [self.mainfile]
elif self.mainfile not in source.files:
self.files = [self.mainfile] + source.files
else:
self.files = source.files
self.uploader = User.from_user_popo(source.uploader) if source.uploader is not None else None self.uploader = User.from_user_popo(source.uploader) if source.uploader is not None else None
self.with_embargo = source.with_embargo self.with_embargo = source.with_embargo
...@@ -179,7 +183,7 @@ def publish(calcs: Iterable[datamodel.CalcWithMetadata]) -> None: ...@@ -179,7 +183,7 @@ def publish(calcs: Iterable[datamodel.CalcWithMetadata]) -> None:
elasticsearch.helpers.bulk(infrastructure.elastic_client, elastic_updates()) elasticsearch.helpers.bulk(infrastructure.elastic_client, elastic_updates())
default_aggregations = { aggregations = {
'atoms': len(ase.data.chemical_symbols), 'atoms': len(ase.data.chemical_symbols),
'system': 10, 'system': 10,
'crystal_system': 10, 'crystal_system': 10,
...@@ -187,12 +191,37 @@ default_aggregations = { ...@@ -187,12 +191,37 @@ default_aggregations = {
'xc_functional': 10, 'xc_functional': 10,
'authors': 10 'authors': 10
} }
""" The available aggregations in :func:`aggregate_search` and their maximum aggregation size """
search_quantities = {
'atoms': ('term', 'atoms', (
'Search the given atom. This quantity can be used multiple times to search for '
'results with all the given atoms. The atoms are given by their case sensitive '
'symbol, e.g. Fe.')),
'system': ('term', 'system', 'Search for the given system type.'),
'crystal_system': ('term', 'crystal_system', 'Search for the given crystal system.'),
'code_name': ('term', 'code_name', 'Search for the given code name.'),
'xc_functional': ('term', 'xc_functional', 'Search for the given xc functional treatment'),
'authors': ('term', 'authors.name.keyword', (
'Search for the given author. Exact keyword matches in the form "Lastname, Firstname".')),
'comment': ('match', 'comment', 'Search within the comments. This is a text search ala google.'),
'paths': ('match', 'files', (
'Search for elements in one of the file paths. The paths are split at all "/".')),
'files': ('term', 'files.keyword', 'Search for exact file name with full path.'),
'quantities': ('term', 'quantities', 'Search for the existence of a certain meta-info quantity')
}
"""
The available search quantities in :func:`aggregate_search` as tuples with *search type*,
elastic field and description.
"""
def aggregate_search( def aggregate_search(
page: int = 0, per_page: int = 10, q: Q = None, page: int = 0, per_page: int = 10, q: Q = None, **kwargs) -> Tuple[int, List[dict], Dict[str, Dict[str, int]]]:
aggregations: Dict[str, int] = default_aggregations,
**kwargs) -> Tuple[int, List[dict], Dict[str, Dict[str, int]]]:
""" """
Performs a search and returns paginated search results and aggregation bucket sizes Performs a search and returns paginated search results and aggregation bucket sizes
based on key quantities. based on key quantities.
...@@ -203,7 +232,7 @@ def aggregate_search( ...@@ -203,7 +232,7 @@ def aggregate_search(
q: An *elasticsearch_dsl* query used to further filter the results (via `and`) q: An *elasticsearch_dsl* query used to further filter the results (via `and`)
aggregations: A customized list of aggregations to perform. Keys are index fields, aggregations: A customized list of aggregations to perform. Keys are index fields,
and values the amount of buckets to return. Only works on *keyword* field. and values the amount of buckets to return. Only works on *keyword* field.
**kwargs: Field, value pairs to search for. **kwargs: Quantity, value pairs to search for.
Returns: A tuple with the total hits, an array with the results, a dictionary with Returns: A tuple with the total hits, an array with the results, a dictionary with
the aggregation data. the aggregation data.
...@@ -211,12 +240,18 @@ def aggregate_search( ...@@ -211,12 +240,18 @@ def aggregate_search(
search = Search() search = Search()
if q is not None: if q is not None:
search.query(q) search = search.query(q)
for key, value in kwargs.items(): for key, value in kwargs.items():
if key == 'comment': query_type, field, _ = search_quantities.get(key, (None, None, None))
search = search.query(Q('match', **{key: value})) if query_type is None:
raise KeyError('Unknown quantity %s' % key)
if isinstance(value, list):
for item in value:
search = search.query(Q(query_type, **{field: item}))
else: else:
search = search.query(Q('term', **{key: value})) search = search.query(Q(query_type, **{field: value}))
for aggregation, size in aggregations.items(): for aggregation, size in aggregations.items():
if aggregation == 'authors': if aggregation == 'authors':
...@@ -261,7 +296,7 @@ def authors(per_page: int = 10, after: str = None, prefix: str = None) -> Tuple[ ...@@ -261,7 +296,7 @@ def authors(per_page: int = 10, after: str = None, prefix: str = None) -> Tuple[
""" """
composite = dict( composite = dict(
size=per_page, size=per_page,
sources=dict(authors=dict(terms=dict(field='authors.name_keyword')))) sources=dict(authors=dict(terms=dict(field='authors.name.keyword'))))
if after is not None: if after is not None:
composite.update(after=dict(authors=after)) composite.update(after=dict(authors=after))
......
...@@ -545,6 +545,7 @@ class TestRepo(UploadFilesBasedTests): ...@@ -545,6 +545,7 @@ class TestRepo(UploadFilesBasedTests):
search.Entry.from_calc_with_metadata(calc_with_metadata).save(refresh=True) search.Entry.from_calc_with_metadata(calc_with_metadata).save(refresh=True)
calc_with_metadata.update(calc_id='2', uploader=other_test_user.to_popo(), published=True) calc_with_metadata.update(calc_id='2', uploader=other_test_user.to_popo(), published=True)
calc_with_metadata.update(atoms=['Fe'], comment='this is a specific word')
search.Entry.from_calc_with_metadata(calc_with_metadata).save(refresh=True) search.Entry.from_calc_with_metadata(calc_with_metadata).save(refresh=True)
calc_with_metadata.update(calc_id='3', uploader=other_test_user.to_popo(), published=False) calc_with_metadata.update(calc_id='3', uploader=other_test_user.to_popo(), published=False)
...@@ -566,9 +567,9 @@ class TestRepo(UploadFilesBasedTests): ...@@ -566,9 +567,9 @@ class TestRepo(UploadFilesBasedTests):
(1, 'user', 'test_user'), (1, 'user', 'test_user'),
(2, 'user', 'other_test_user'), (2, 'user', 'other_test_user'),
(0, 'staging', 'test_user'), (0, 'staging', 'test_user'),
(1, 'staging', 'other_test_user'), (1, 'staging', 'other_test_user')
]) ])
def test_search(self, client, example_elastic_calcs, no_warn, test_user_auth, other_test_user_auth, calcs, owner, auth): def test_search_owner(self, client, example_elastic_calcs, no_warn, test_user_auth, other_test_user_auth, calcs, owner, auth):
auth = dict(none=None, test_user=test_user_auth, other_test_user=other_test_user_auth).get(auth) auth = dict(none=None, test_user=test_user_auth, other_test_user=other_test_user_auth).get(auth)
rv = client.get('/repo/?owner=%s' % owner, headers=auth) rv = client.get('/repo/?owner=%s' % owner, headers=auth)
assert rv.status_code == 200 assert rv.status_code == 200
...@@ -581,7 +582,45 @@ class TestRepo(UploadFilesBasedTests): ...@@ -581,7 +582,45 @@ class TestRepo(UploadFilesBasedTests):
for key in ['uploader', 'calc_id', 'formula', 'upload_id']: for key in ['uploader', 'calc_id', 'formula', 'upload_id']:
assert key in results[0] assert key in results[0]
def test_calcs_pagination(self, client, example_elastic_calcs, no_warn): @pytest.mark.parametrize('calcs, quantity, value', [
(2, 'system', 'Bulk'),
(0, 'system', 'Atom'),
(1, 'atoms', 'Br'),
(1, 'atoms', 'Fe'),
(0, 'atoms', ['Fe', 'Br']),
(1, 'comment', 'specific'),
(1, 'authors', 'Hofstadter, Leonard'),
(2, 'files', 'test/mainfile.txt'),
(2, 'paths', 'mainfile.txt'),
(2, 'paths', 'test'),
(2, 'quantities', ['wyckoff_letters_primitive', 'hall_number']),
(0, 'quantities', 'dos')
])
def test_search_quantities(self, client, example_elastic_calcs, no_warn, test_user_auth, calcs, quantity, value):
if isinstance(value, list):
query_string = '&'.join('%s=%s' % (quantity, item) for item in value)
else:
query_string = '%s=%s' % (quantity, value)
rv = client.get('/repo/?%s' % query_string, headers=test_user_auth)
assert rv.status_code == 200
data = json.loads(rv.data)
results = data.get('results', None)
assert results is not None
assert isinstance(results, list)
assert len(results) == calcs
aggregations = data.get('aggregations', None)
assert aggregations is not None
if quantity == 'system' and calcs != 0:
# for simplicity we only assert on aggregations for this case
assert 'system' in aggregations
assert len(aggregations['system']) == 1
assert value in aggregations['system']
def test_search_pagination(self, client, example_elastic_calcs, no_warn):
rv = client.get('/repo/?page=1&per_page=1') rv = client.get('/repo/?page=1&per_page=1')
assert rv.status_code == 200 assert rv.status_code == 200
data = json.loads(rv.data) data = json.loads(rv.data)
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment