#
# Copyright The NOMAD Authors.
#
# This file is part of NOMAD. See https://nomad-lab.eu for further info.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

'''
This module provides an interface to elasticsearch. Other parts of NOMAD must not
interact with elasticsearch directly to maintain a clear, coherent interface and to
allow for change.

Currently NOMAD uses one entry index and two distinct materials indices. The entries
index is based on two different mappings, one used by the old flask API (v0) and one
used by the new fastapi API (v1). The mappings are used at the same time and the
documents are merged. Write operations (index, publish, edit, lift embargo, delete)
are common; they are defined here in the module ``__init__.py``. Read operations
differ and should be used per use-case directly from the ``v0`` and ``v1`` submodules.

Most common functions also take an ``update_materials`` keyword arg which allows one
to update the v1 materials index according to the performed changes. TODO this is
only partially implemented.
'''

from typing import Union, List, Iterable, Any, cast, Dict, Generator
import json
import elasticsearch
from elasticsearch.exceptions import TransportError, RequestError
from elasticsearch_dsl import Q, A, Search
from elasticsearch_dsl.query import Query as EsQuery
from pydantic.error_wrappers import ErrorWrapper

from nomad import config, infrastructure, utils
from nomad import datamodel
from nomad.datamodel import EntryArchive, EntryMetadata
from nomad.app.v1 import models as api_models
from nomad.app.v1.models import (
    AggregationPagination, MetadataPagination, Pagination, PaginationResponse,
    QuantityAggregation, Query, MetadataRequired,
    MetadataResponse, Aggregation, StatisticsAggregation, StatisticsAggregationResponse,
    Value, AggregationBase, TermsAggregation, BucketAggregation, HistogramAggregation,
    DateHistogramAggregation, MinMaxAggregation, Bucket,
    MinMaxAggregationResponse, TermsAggregationResponse, HistogramAggregationResponse,
    DateHistogramAggregationResponse, AggregationResponse)
from nomad.metainfo.elasticsearch_extension import (
    index_entries, entry_type, entry_index, DocumentType,
    material_type, material_entry_type, Index, SearchQuantity)


def update_by_query(
        update_script: str,
        query: Any = None,
        owner: str = None,
        user_id: str = None,
        index: str = None,
        refresh: bool = False,
        **kwargs):
    '''
    Uses the given painless script to update the entries matched by the given query.

    In most cases, the elasticsearch entry index should not be updated field by field;
    you should run ``index`` instead and fully replace documents from mongodb and
    archive files.

    This function provides a faster, direct way to update individual fields, e.g. to
    quickly update fields for editing operations.
    '''
    if query is None:
        query = {}
    es_query = _api_to_es_query(query)
    if owner is not None:
        es_query &= _owner_es_query(owner=owner, user_id=user_id)

    body = {
        'script': {
            'source': update_script,
            'lang': 'painless'
        },
        'query': es_query.to_dict()
    }

    body['script'].update(**kwargs)

    try:
        result = infrastructure.elastic_client.update_by_query(
            body=body, index=config.elastic.entries_index)
    except TransportError as e:
        utils.get_logger(__name__).error(
            'es update_by_query script error', exc_info=e,
            es_info=json.dumps(e.info, indent=2))
        raise SearchError(e)

    if refresh:
        _refresh()

    return result

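# A minimal usage sketch for ``update_by_query`` (ids and values are hypothetical).
# Extra kwargs are merged into the script object, so painless ``params`` can be
# passed directly:
#
#     update_by_query(
#         update_script='ctx._source.comment = params.comment;',
#         query=dict(upload_id='some_upload_id'),
#         owner='user', user_id='some_user_id',
#         params=dict(comment='a new comment'),
#         refresh=True)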

def delete_by_query(
        query: dict,
        owner: str = None,
        user_id: str = None,
        update_materials: bool = False,
        refresh: bool = False):
    '''
    Deletes all entries that match the given query.
    '''
    if query is None:
        query = {}
    es_query = _api_to_es_query(query)
    es_query &= _owner_es_query(owner=owner, user_id=user_id)

    body = {
        'query': es_query.to_dict()
    }

    try:
        result = infrastructure.elastic_client.delete_by_query(
            body=body, index=config.elastic.entries_index)
    except TransportError as e:
        utils.get_logger(__name__).error(
            'es delete_by_query error', exc_info=e,
            es_info=json.dumps(e.info, indent=2))
        raise SearchError(e)

    if refresh:
        _refresh()

    if update_materials:
        # TODO update the materials index at least for v1
        pass

    return result


def refresh():
    '''
    Refreshes the entries index.
    '''

    try:
        infrastructure.elastic_client.indices.refresh(index=config.elastic.entries_index)
    except TransportError as e:
        utils.get_logger(__name__).error(
            'es refresh error', exc_info=e,
            es_info=json.dumps(e.info, indent=2))
        raise SearchError(e)


_refresh = refresh


def index(
        entries: Union[EntryArchive, List[EntryArchive]],
        update_materials: bool = False,
        refresh: bool = True):
    '''
    Index the given entries based on their archive. Either creates or updates the underlying
    elasticsearch documents. If an underlying elasticsearch document already exists it
    will be fully replaced.
    '''
    if not isinstance(entries, list):
        entries = [entries]

    index_entries(entries=entries, update_materials=update_materials)

    if refresh:
        _refresh()

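# A minimal usage sketch for ``index`` (the archive contents are hypothetical):
#
#     archive = EntryArchive(metadata=EntryMetadata(
#         calc_id='some_calc_id', upload_id='some_upload_id'))
#     index(archive, update_materials=True, refresh=True)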

# TODO this depends on how we merge section metadata
def publish(entries: Iterable[EntryMetadata], index: str = None) -> int:
    '''
    Publishes the given entries based on their entry metadata. Sets published to true,
    and updates most user provided metadata with a partial update. Returns the number
    of failed updates.
    '''
    return update_metadata(
        entries, index=index, published=True, update_materials=True, refresh=True)


def update_metadata(
        entries: Iterable[EntryMetadata], index: str = None,
        update_materials: bool = False, refresh: bool = False,
        **kwargs) -> int:
    '''
    Update all given entries with their given metadata. Additionally apply kwargs.
    Returns the number of failed updates. This is doing a partial update on the underlying
    elasticsearch documents.
    '''

    def elastic_updates():
        for entry_metadata in entries:
            entry_archive = entry_metadata.m_parent
            if entry_archive is None:
                entry_archive = EntryArchive(metadata=entry_metadata)
            entry_doc = entry_type.create_index_doc(entry_archive)

            entry_doc.update(**kwargs)
            # TODO this is an exception that should be treated differently. None values
            # are not included in elasticsearch docs. However, when a user removes their
            # comment (the only case where a value is unset in an update?), the None
            # value needs to be transported.
            if 'comment' not in entry_doc:
                entry_doc['comment'] = None

            yield dict(
                doc=entry_doc,
                _id=entry_metadata.calc_id,
                _type=entry_index.doc_type.name,
                _index=entry_index.index_name,
                _op_type='update')

    updates = list(elastic_updates())
    _, failed = elasticsearch.helpers.bulk(
        infrastructure.elastic_client, updates, stats_only=True)

    if update_materials:
        # TODO update the materials index at least for v1
        pass

    if refresh:
        _refresh()

    return failed

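# A usage sketch for ``update_metadata`` (ids are hypothetical): partially update
# the comment of an already indexed entry, e.g. as part of an edit operation:
#
#     entry_metadata = EntryMetadata(calc_id='some_calc_id', comment='a new comment')
#     failed = update_metadata([entry_metadata], refresh=True)
#     assert failed == 0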

def delete_upload(upload_id: str, refresh: bool = False, **kwargs):
    '''
    Deletes the given upload.
    '''
    delete_by_query(query=dict(upload_id=upload_id), **kwargs)

    if refresh:
        _refresh()


def delete_entry(entry_id: str, index: str = None, refresh: bool = False, **kwargs):
    '''
    Deletes the given entry.
    '''
    delete_by_query(query=dict(calc_id=entry_id), **kwargs)

    if refresh:
        _refresh()


class SearchError(Exception): pass


class AuthenticationRequiredError(Exception): pass


_entry_metadata_defaults = {
    quantity.name: quantity.default
    for quantity in datamodel.EntryMetadata.m_def.quantities  # pylint: disable=not-an-iterable
    if quantity.default not in [None, [], False, 0]
}


def _es_to_entry_dict(hit, required: MetadataRequired = None) -> Dict[str, Any]:
    '''
    Elasticsearch entry metadata does not contain default values if a metadata field
    is not set. This adds default values to entry metadata in dict form obtained from
    elasticsearch.
    '''
    entry_dict = hit.to_dict()
    for key, value in _entry_metadata_defaults.items():
        if key not in entry_dict:
            if required is not None:
                if required.exclude and key in required.exclude:
                    continue
                if required.include and key not in required.include:
                    continue

            entry_dict[key] = value

    return entry_dict


def _api_to_es_query(query: api_models.Query) -> Q:
    '''
    Creates an ES query based on the API's query model. This needs to be a normalized
    query expression with explicit objects for logical, set, and comparison operators.
    Shorthand notations like ``quantity:operator`` are not supported here; this
    needs to be resolved via the respective pydantic validator. There is also no
    validation of quantities and types.
    '''
    def quantity_to_es(name: str, value: api_models.Value) -> Q:
        # TODO depends on keyword or not, value might need normalization, etc.
        quantity = entry_type.quantities[name]
        return Q('match', **{quantity.search_field: value})

    def parameter_to_es(name: str, value: api_models.QueryParameterValue) -> Q:

        if isinstance(value, api_models.All):
            return Q('bool', must=[
                quantity_to_es(name, item)
                for item in value.op])

        if isinstance(value, api_models.Any_):
            return Q('bool', should=[
                quantity_to_es(name, item)
                for item in value.op])

        if isinstance(value, api_models.None_):
            return Q('bool', must_not=[
                quantity_to_es(name, item)
                for item in value.op])

        if isinstance(value, api_models.Range):
            quantity = entry_type.quantities[name]
            return Q('range', **{quantity.search_field: value.dict(
                exclude_unset=True,
            )})

        # list of values is treated as an "all" over the items
        if isinstance(value, list):
            return Q('bool', must=[
                quantity_to_es(name, item)
                for item in value])

        return quantity_to_es(name, cast(api_models.Value, value))

    def query_to_es(query: api_models.Query) -> Q:
        if isinstance(query, api_models.LogicalOperator):
            if isinstance(query, api_models.And):
                return Q('bool', must=[query_to_es(operand) for operand in query.op])

            if isinstance(query, api_models.Or):
                return Q('bool', should=[query_to_es(operand) for operand in query.op])

            if isinstance(query, api_models.Not):
                return Q('bool', must_not=query_to_es(query.op))

            raise NotImplementedError()

        if not isinstance(query, dict):
            raise NotImplementedError()

        # dictionary is like an "and" of all items in the dict
        if len(query) == 0:
            return Q()

        if len(query) == 1:
            key = next(iter(query))
            return parameter_to_es(key, query[key])

        return Q('bool', must=[
            parameter_to_es(name, value) for name, value in query.items()])

    return query_to_es(query)

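# A sketch of the normalized query expressions that ``_api_to_es_query`` expects
# (the quantity names are examples and assumed to exist in the entry mapping):
#
#     _api_to_es_query({
#         'upload_id': 'some_upload_id',
#         'results.method.simulation.program_name': api_models.Any_(op=['VASP', 'exciting'])})
#
# This yields a ``bool`` query with a ``match`` clause for the upload_id and a
# ``should`` clause with one ``match`` per program name.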

def _owner_es_query(owner: str, user_id: str = None, doc_type: DocumentType = entry_type):
    def term_query(**kwargs):
        prefix = '' if doc_type == entry_type else 'entries.'
        return Q('term', **{
            (prefix + field): value for field, value in kwargs.items()})

    if owner == 'all':
        q = term_query(published=True)
        if user_id is not None:
            q = q | term_query(owners__user_id=user_id)
    elif owner == 'public':
        q = term_query(published=True) & term_query(with_embargo=False)
    elif owner == 'visible':
        q = term_query(published=True) & term_query(with_embargo=False)
        if user_id is not None:
            q = q | term_query(owners__user_id=user_id)
    elif owner == 'shared':
        if user_id is None:
            raise AuthenticationRequiredError('Authentication required for owner value shared.')

        q = term_query(owners__user_id=user_id)
    elif owner == 'user':
        if user_id is None:
            raise AuthenticationRequiredError('Authentication required for owner value user.')

        q = term_query(uploader__user_id=user_id)
    elif owner == 'staging':
        if user_id is None:
            raise AuthenticationRequiredError('Authentication required for owner value staging.')
        q = term_query(published=False) & term_query(owners__user_id=user_id)
    elif owner == 'admin':
        if user_id is None or not datamodel.User.get(user_id=user_id).is_admin:
            raise AuthenticationRequiredError('This can only be used by the admin user.')
        q = None
    elif owner is None:
        q = None
    else:
        raise KeyError('Unsupported owner value')

    if q is not None:
        return q
    return Q()

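# For illustration, the owner values above translate to ES queries roughly as
# follows (a sketch in terms of the term queries used above):
#
#     'public':  published AND NOT with_embargo
#     'visible': (published AND NOT with_embargo) OR owners.user_id == user_id
#     'shared':  owners.user_id == user_id
#     'user':    uploader.user_id == user_id
#     'staging': NOT published AND owners.user_id == user_id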

class QueryValidationError(Exception):
    def __init__(self, error, loc):
        self.errors = [ErrorWrapper(Exception(error), loc=loc)]


def validate_quantity(
        quantity_name: str, value: Value = None, doc_type: DocumentType = None,
        loc: List[str] = None) -> SearchQuantity:
    '''
    Validates the given quantity name and value against the given document type.

    Returns:
        A metainfo elasticsearch extension SearchQuantity object.

    Raises: QueryValidationError
    '''
    assert quantity_name is not None

    if doc_type == material_entry_type and not quantity_name.startswith('entries'):
        quantity_name = f'entries.{quantity_name}'

    if doc_type == material_type and quantity_name.startswith('entries'):
        doc_type = material_entry_type

    if doc_type is None:
        doc_type = entry_type

    quantity = doc_type.quantities.get(quantity_name)
    if quantity is None:
        raise QueryValidationError(
            f'{quantity_name} is not a {doc_type} quantity',
            loc=[quantity_name] if loc is None else loc)

    return quantity


def validate_api_query(
        query: Query, doc_type: DocumentType, owner_query: EsQuery,
        prefix: str = None) -> EsQuery:
    '''
    Creates an ES query based on the API's query model. This needs to be a normalized
    query expression with explicit objects for logical, set, and comparison operators.
    Shorthand notations like ``quantity:operator`` are not supported here; this
    needs to be resolved via the respective pydantic validator.

    However, this function performs validation of quantities and types and raises
    a QueryValidationError accordingly. This exception is populated with pydantic
    errors.

    Arguments:
        query: The API query object.
        doc_type:
            The elasticsearch metainfo extension document type that this query needs to
            be verified against.
        owner_query:
            A prebuilt ES query that is added to the nested entries query. Only used
            for materials queries.
        prefix:
            An optional prefix that is added to all quantity names. Used for recursion.

    Returns:
        An elasticsearch dsl query object.

    Raises: QueryValidationError
    '''

    def match(name: str, value: Value) -> EsQuery:
        if name == 'optimade_filter':
            value = str(value)
            from nomad.app.optimade import filterparser
            try:
                return filterparser.parse_filter(
                    value, nomad_properties='dft', without_prefix=True)
            except filterparser.FilterException as e:
                raise QueryValidationError(
                    f'Could not parse optimade filter: {e}',
                    loc=[name])

        # TODO non keyword quantities, quantities with value transformation, type checks
        quantity = validate_quantity(name, value, doc_type=doc_type)
        return Q('match', **{quantity.search_field: value})

    def validate_query(query: Query) -> EsQuery:
        return validate_api_query(
            query, doc_type=doc_type, owner_query=owner_query, prefix=prefix)

    def validate_criteria(name: str, value: Any):
        if prefix is not None:
            name = f'{prefix}.{name}'

        # handle prefix and nested queries
        for nested_key in doc_type.nested_object_keys:
            if len(name) < len(nested_key):
                break
            if not name.startswith(nested_key):
                continue
            if prefix is not None and prefix.startswith(nested_key):
                continue
            if nested_key == name and isinstance(value, api_models.Nested):
                continue

            value = api_models.Nested(query={name[len(nested_key) + 1:]: value})
            name = nested_key
            break

        if isinstance(value, api_models.All):
            return Q('bool', must=[match(name, item) for item in value.op])

        elif isinstance(value, api_models.Any_):
            return Q('bool', should=[match(name, item) for item in value.op])

        elif isinstance(value, api_models.None_):
            return Q('bool', must_not=[match(name, item) for item in value.op])

        elif isinstance(value, api_models.Range):
            quantity = validate_quantity(name, None, doc_type=doc_type)
            return Q('range', **{quantity.search_field: value.dict(
                exclude_unset=True,
            )})

        elif isinstance(value, (api_models.And, api_models.Or, api_models.Not)):
            return validate_query(value)

        elif isinstance(value, api_models.Nested):
            sub_doc_type = material_entry_type if name == 'entries' else doc_type

            sub_query = validate_api_query(
                value.query, doc_type=sub_doc_type, prefix=name, owner_query=owner_query)

            if name in doc_type.nested_object_keys:
                if name == 'entries':
                    sub_query &= owner_query
                return Q('nested', path=name, query=sub_query)
            else:
                return sub_query

        # list of values is treated as an "all" over the items
        elif isinstance(value, list):
            return Q('bool', must=[match(name, item) for item in value])

        elif isinstance(value, dict):
            assert False, (
                'Using dictionaries as criteria values directly is not supported. Use the '
                'Nested model.')

        else:
            return match(name, value)

    if isinstance(query, api_models.And):
        return Q('bool', must=[validate_query(operand) for operand in query.op])

    if isinstance(query, api_models.Or):
        return Q('bool', should=[validate_query(operand) for operand in query.op])

    if isinstance(query, api_models.Not):
        return Q('bool', must_not=validate_query(query.op))

    if isinstance(query, dict):
        # dictionary is like an "and" of all items in the dict
        if len(query) == 0:
            return Q()

        if len(query) == 1:
            key = next(iter(query))
            return validate_criteria(key, query[key])

        return Q('bool', must=[
            validate_criteria(name, value) for name, value in query.items()])

    raise NotImplementedError()

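# A sketch of how ``validate_api_query`` resolves nested quantities (the quantity
# name is an example): on the materials index, a criterion on an ``entries.*``
# quantity, e.g.
#
#     validate_api_query(
#         {'entries.upload_id': 'some_upload_id'},
#         doc_type=material_type, owner_query=owner_query)
#
# is rewritten into an ES ``nested`` query on path ``entries``, with the given
# owner query and-ed into the nested part.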

def validate_pagination(pagination: Pagination, doc_type: DocumentType, loc: List[str] = None):
    order_quantity = None
    if pagination.order_by is not None:
        order_quantity = validate_quantity(
            pagination.order_by, doc_type=doc_type, loc=['pagination', 'order_by'])
        if not order_quantity.definition.is_scalar:
            raise QueryValidationError(
                'the order_by quantity must be a scalar',
                loc=(loc if loc else []) + ['pagination', 'order_by'])

    page_after_value = pagination.page_after_value
    if page_after_value is not None and \
            pagination.order_by is not None and \
            pagination.order_by != doc_type.id_field and \
            ':' not in page_after_value:

        pagination.page_after_value = '%s:' % page_after_value

    return order_quantity, page_after_value

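# A note on the ``page_after_value`` convention established above (values are
# examples): when ordering by a quantity other than the id field, the value has the
# form '<order_value>:<id_value>'. A bare value like 'some_value' is normalized to
# 'some_value:' so that it can later be split into a ``search_after`` pair (see
# ``search`` below).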

def _api_to_es_aggregation(
        es_search: Search, name: str, agg: AggregationBase, doc_type: DocumentType) -> A:
    '''
    Creates an ES aggregation based on the API's aggregation model.
    '''

    agg_name = f'agg:{name}'
    es_aggs = es_search.aggs

    if isinstance(agg, StatisticsAggregation):
        for metric_name in agg.metrics:
            metrics = doc_type.metrics
            if metric_name not in metrics and doc_type == material_type:
                metrics = material_entry_type.metrics
            if metric_name not in metrics:
                raise QueryValidationError(
                    'metric must be the qualified name of a suitable search quantity',
                    loc=['statistic', 'metrics'])
            metric_aggregation, metric_quantity = metrics[metric_name]
            es_aggs.metric('statistics:%s' % metric_name, A(
                metric_aggregation,
                field=metric_quantity.qualified_field))

        return

    agg = cast(QuantityAggregation, agg)

    longest_nested_key = None
    quantity = validate_quantity(agg.quantity, doc_type=doc_type, loc=['aggregation', 'quantity'])

    for nested_key in doc_type.nested_object_keys:
        if agg.quantity.startswith(nested_key):
            es_aggs = es_search.aggs.bucket('nested_agg:%s' % name, 'nested', path=nested_key)
            longest_nested_key = nested_key

    es_agg = None

    if isinstance(agg, TermsAggregation):
        if not quantity.aggregateable:
            raise QueryValidationError(
                'The aggregation quantity cannot be used in a terms aggregation.',
                loc=['aggregation', name, 'terms', 'quantity'])

        if agg.pagination is not None:
            if agg.size is not None:
                raise QueryValidationError(
                    'You cannot paginate and provide an extra size parameter.',
                    loc=['aggregations', name, 'terms', 'pagination'])

            order_quantity, page_after_value = validate_pagination(
                agg.pagination, doc_type=doc_type, loc=['aggregation'])

            # We are using elasticsearch's 'composite aggregations' here. We do not
            # really compose aggregations; these pseudo composites just allow us to use
            # the 'after' feature to scan through all aggregation values.
            terms = A('terms', field=quantity.search_field, order=agg.pagination.order.value)

            if order_quantity is None:
                composite = {
                    'sources': {
                        name: terms
                    },
                    'size': agg.pagination.page_size
                }

            else:
                sort_terms = A(
                    'terms',
                    field=order_quantity.search_field,
                    order=agg.pagination.order.value)

                composite = {
                    'sources': [
                        {order_quantity.search_field: sort_terms},
                        {quantity.search_field: terms}
                    ],
                    'size': agg.pagination.page_size
                }

            if page_after_value is not None:
                if order_quantity is None:
                    composite['after'] = {name: page_after_value}
                else:
                    try:
                        order_value, quantity_value = page_after_value.split(':')
                        composite['after'] = {quantity.search_field: quantity_value, order_quantity.search_field: order_value}
                    except Exception:
                        raise QueryValidationError(
                            'The page_after_value does not have the right format.',
                            loc=['aggregations', name, 'terms', 'pagination', 'page_after_value'])

            es_agg = es_aggs.bucket(agg_name, 'composite', **composite)

            # additional cardinality to get total
            es_aggs.metric('agg:%s:total' % name, 'cardinality', field=quantity.search_field)
        else:
            if agg.size is None:
                if quantity.default_aggregation_size is not None:
                    agg.size = quantity.default_aggregation_size

                elif quantity.values is not None:
                    agg.size = len(quantity.values)

                else:
                    agg.size = 10

            terms_kwargs = {}
            if agg.value_filter is not None:
                terms_kwargs['include'] = '.*%s.*' % agg.value_filter

            terms = A('terms', field=quantity.search_field, size=agg.size, **terms_kwargs)
            es_agg = es_aggs.bucket(agg_name, terms)

        if agg.entries is not None and agg.entries.size > 0:
            kwargs: Dict[str, Any] = {}
            if agg.entries.required is not None:
                if agg.entries.required.include is not None:
                    kwargs.update(_source=dict(includes=agg.entries.required.include))
                else:
                    kwargs.update(_source=dict(excludes=agg.entries.required.exclude))

            es_agg.metric('entries', A('top_hits', size=agg.entries.size, **kwargs))

    elif isinstance(agg, DateHistogramAggregation):
        if not quantity.annotation.mapping['type'] in ['date']:
            raise QueryValidationError(
                f'The quantity {quantity} cannot be used in a date histogram aggregation',
                loc=['aggregations', name, 'histogram', 'quantity'])

        es_agg = es_aggs.bucket(agg_name, A(
            'date_histogram', field=quantity.search_field, interval=agg.interval,
            format='yyyy-MM-dd'))

    elif isinstance(agg, HistogramAggregation):
        if not quantity.annotation.mapping['type'] in ['integer', 'float', 'double', 'long']:
            raise QueryValidationError(
                f'The quantity {quantity} cannot be used in a histogram aggregation',
                loc=['aggregations', name, 'histogram', 'quantity'])

        es_agg = es_aggs.bucket(agg_name, A(
            'histogram', field=quantity.search_field, interval=agg.interval))

    elif isinstance(agg, MinMaxAggregation):
        if not quantity.annotation.mapping['type'] in ['integer', 'float', 'double', 'long', 'date']:
            raise QueryValidationError(
                f'The quantity {quantity} cannot be used in a min-max aggregation',
                loc=['aggregations', name, 'min_max', 'quantity'])

        es_aggs.metric(agg_name + ':min', A('min', field=quantity.search_field))
        es_aggs.metric(agg_name + ':max', A('max', field=quantity.search_field))

    else:
        raise NotImplementedError()

    if isinstance(agg, BucketAggregation):
        for metric_name in agg.metrics:
            metrics = doc_type.metrics
            if longest_nested_key == 'entries':
                metrics = material_entry_type.metrics
            if metric_name not in metrics:
                raise QueryValidationError(
                    'metric must be the qualified name of a suitable search quantity',
                    loc=['statistic', 'metrics'])
            metric_aggregation, metric_quantity = metrics[metric_name]
            es_agg.metric('metric:%s' % metric_name, A(
                metric_aggregation,
                field=metric_quantity.qualified_field))

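# A usage sketch for ``_api_to_es_aggregation`` (the quantity name is an example,
# assuming it exists in the entry mapping): add a terms aggregation to an ES search
#
#     es_search = Search(index=entry_index.index_name)
#     _api_to_es_aggregation(
#         es_search, 'elements',
#         TermsAggregation(quantity='results.material.elements', size=10),
#         doc_type=entry_type)
#
# This registers an ``agg:elements`` terms bucket aggregation on ``es_search``.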

def _es_to_api_aggregation(
        es_response, name: str, agg: AggregationBase, doc_type: DocumentType):
    '''
    Creates an AggregationResponse from the elasticsearch response on a request
    executed with the given aggregation.
    '''
    es_aggs = es_response.aggs
    aggregation_dict = agg.dict(by_alias=True)

    if isinstance(agg, StatisticsAggregation):
        metrics = {}
        for metric in agg.metrics:  # type: ignore
            metrics[metric] = es_aggs[f'statistics:{metric}'].value

        return AggregationResponse(
            statistics=StatisticsAggregationResponse(data=metrics, **aggregation_dict))

    agg = cast(QuantityAggregation, agg)
    quantity = validate_quantity(agg.quantity, doc_type=doc_type)
    longest_nested_key = None
    for nested_key in doc_type.nested_object_keys:
        if agg.quantity.startswith(nested_key):
            es_aggs = es_response.aggs[f'nested_agg:{name}']
            longest_nested_key = nested_key

    has_no_pagination = getattr(agg, 'pagination', None) is None

    if isinstance(agg, BucketAggregation):
        es_agg = es_aggs['agg:' + name]
        values = set()

        def get_bucket(es_bucket) -> Bucket:
            if has_no_pagination:
                if isinstance(agg, DateHistogramAggregation):
                    value = es_bucket['key_as_string']
                else:
                    value = es_bucket['key']
            elif agg.pagination.order_by is None:  # type: ignore
                value = es_bucket.key[name]
            else:
                value = es_bucket.key[quantity.search_field]

            count = es_bucket.doc_count
            metrics = {}
            for metric in agg.metrics:  # type: ignore
                metrics[metric] = es_bucket['metric:' + metric].value

            entries = None
            if 'entries' in es_bucket:
                if longest_nested_key:
                    entries = [
                        {longest_nested_key: item['_source']}
                        for item in es_bucket.entries.hits.hits]
                else:
                    entries = [item['_source'] for item in es_bucket.entries.hits.hits]

            values.add(value)
            if len(metrics) == 0:
                metrics = None
            return Bucket(value=value, entries=entries, count=count, metrics=metrics)

        data = [get_bucket(es_bucket) for es_bucket in es_agg.buckets]

        if has_no_pagination:
            # fill "empty" values
            if quantity.values is not None:
                for value in quantity.values:
                    if value not in values:
                        metrics = {metric: 0 for metric in agg.metrics}
                        if len(metrics) == 0:
                            metrics = None
                        data.append(Bucket(value=value, count=0, metrics=metrics))

        else:
            total = es_aggs['agg:%s:total' % name]['value']
            pagination = PaginationResponse(total=total, **aggregation_dict['pagination'])
            if pagination.page_after_value is not None and pagination.page_after_value.endswith(':'):
                pagination.page_after_value = pagination.page_after_value[0:-1]

            if 'after_key' in es_agg:
                after_key = es_agg['after_key']
                if pagination.order_by is None:
                    pagination.next_page_after_value = after_key[name]
                else:
                    str_values = [str(v) for v in after_key.to_dict().values()]
                    pagination.next_page_after_value = ':'.join(str_values)
            else:
                pagination.next_page_after_value = None

            aggregation_dict['pagination'] = pagination

        if isinstance(agg, TermsAggregation):
            return AggregationResponse(
                terms=TermsAggregationResponse(data=data, **aggregation_dict))
        elif isinstance(agg, HistogramAggregation):
            return AggregationResponse(
                histogram=HistogramAggregationResponse(data=data, **aggregation_dict))
        elif isinstance(agg, DateHistogramAggregation):
            return AggregationResponse(
                date_histogram=DateHistogramAggregationResponse(data=data, **aggregation_dict))
        else:
            raise NotImplementedError()

    if isinstance(agg, MinMaxAggregation):
        min_value = es_aggs['agg:%s:min' % name]['value']
        max_value = es_aggs['agg:%s:max' % name]['value']

        return AggregationResponse(
            min_max=MinMaxAggregationResponse(data=[min_value, max_value], **aggregation_dict))

    raise NotImplementedError()


def _specific_agg(agg: Aggregation) -> Union[
        TermsAggregation, DateHistogramAggregation, HistogramAggregation,
        MinMaxAggregation, StatisticsAggregation]:
    if agg.terms is not None:
        return agg.terms

    if agg.histogram is not None:
        return agg.histogram

    if agg.date_histogram is not None:
        return agg.date_histogram

    if agg.min_max is not None:
        return agg.min_max

    if agg.statistics is not None:
        return agg.statistics

    raise NotImplementedError()


def search(
        owner: str = 'public',
        query: Union[Query, EsQuery] = None,
        pagination: MetadataPagination = None,
        required: MetadataRequired = None,
        aggregations: Dict[str, Aggregation] = {},
        user_id: str = None,
        index: Index = entry_index) -> MetadataResponse:

    # The first half of this method creates the ES query. Then the query is run on ES.
    # The second half is about transforming the ES response to a MetadataResponse.

    doc_type = index.doc_type

    # owner and query
    owner_query = _owner_es_query(owner=owner, user_id=user_id, doc_type=doc_type)

    if query is None:
        query = {}

    if isinstance(query, EsQuery):
        es_query = cast(EsQuery, query)
    else:
        es_query = validate_api_query(
            cast(Query, query), doc_type=doc_type, owner_query=owner_query)

    if doc_type != entry_type:
        es_query &= Q('nested', path='entries', query=owner_query)
    else:
        es_query &= owner_query

    # pagination
    if pagination is None:
        pagination = MetadataPagination()

    if pagination.order_by is None:
        pagination.order_by = doc_type.id_field

    search = Search(index=index.index_name)

    search = search.query(es_query)
    # TODO this depends on doc_type
    order_quantity, page_after_value = validate_pagination(pagination, doc_type=doc_type)
    order_field = order_quantity.search_field
    sort = {order_field: pagination.order.value}
    if order_field != doc_type.id_field:
        sort[doc_type.id_field] = pagination.order.value
    search = search.sort(sort)
    search = search.extra(size=pagination.page_size)

    if pagination.page_offset:
        search = search.extra(**{'from': pagination.page_offset})
    elif pagination.page:
        search = search.extra(**{'from': (pagination.page - 1) * pagination.page_size})
    elif page_after_value:
        search = search.extra(search_after=page_after_value.rsplit(':', 1))

    # required
    if required:
        for list_ in [required.include, required.exclude]:
            for quantity in [] if list_ is None else list_:
                # TODO validate quantities with wildcards
                if '*' not in quantity:
                    validate_quantity(quantity, doc_type=doc_type, loc=['required'])

        if required.include is not None and pagination.order_by not in required.include:
            required.include.append(pagination.order_by)
        if required.exclude is not None and pagination.order_by in required.exclude:
            required.exclude.remove(pagination.order_by)

        if required.include is not None and doc_type.id_field not in required.include:
            required.include.append(doc_type.id_field)

        if required.exclude is not None and doc_type.id_field in required.exclude:
            required.exclude.remove(doc_type.id_field)

        search = search.source(includes=required.include, excludes=required.exclude)

    # aggregations
    for name, agg in aggregations.items():
        _api_to_es_aggregation(search, name, _specific_agg(agg), doc_type=doc_type)

    # execute
    try:
        es_response = search.execute()
    except RequestError as e:
        raise SearchError(e)

    more_response_data = {}

    # pagination
    next_page_after_value = None
    if 0 < len(es_response.hits) < es_response.hits.total:
        last = es_response.hits[-1]
        if order_field == doc_type.id_field:
            next_page_after_value = last[doc_type.id_field]