#
# Copyright The NOMAD Authors.
#
# This file is part of NOMAD. See https://nomad-lab.eu for further info.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

'''
This module provides an interface to elasticsearch. Other parts of NOMAD must not
interact with elasticsearch directly, to maintain a clear, coherent interface and to
allow for change.

Currently NOMAD uses one entry index and two distinct materials indices. The entries
index is based on two different mappings, one used by the old flask api (v0) and one
used by the new fastapi api (v1). The mappings are used at the same time and the
documents are merged. Write operations (index, publish, edit, lift embargo, delete)
are common; they are defined here in the module ``__init__.py``. Read operations
differ and should be used as per use-case directly from the ``v0`` and ``v1``
submodules.

Most of the common functions also take an ``update_materials`` keyword arg that
allows updating the v1 materials index according to the performed changes. TODO this
is only partially implemented.
'''

from typing import Union, List, Iterable, Any, cast, Dict, Generator
import json
import elasticsearch
from elasticsearch.exceptions import TransportError, RequestError
from elasticsearch_dsl import Q, A, Search
from elasticsearch_dsl.query import Query as EsQuery
from pydantic.error_wrappers import ErrorWrapper

from nomad import config, infrastructure, utils, datamodel
from nomad.datamodel import EntryArchive, EntryMetadata
from nomad.app.v1 import models as api_models
from nomad.app.v1.models import (
    AggregationPagination, MetadataPagination, Pagination, PaginationResponse,
    QuantityAggregation, Query, MetadataRequired,
    MetadataResponse, Aggregation, StatisticsAggregation, StatisticsAggregationResponse,
    Value, AggregationBase, TermsAggregation, BucketAggregation, HistogramAggregation,
    DateHistogramAggregation, MinMaxAggregation, Bucket,
    MinMaxAggregationResponse, TermsAggregationResponse, HistogramAggregationResponse,
    DateHistogramAggregationResponse, AggregationResponse)
from nomad.metainfo.elasticsearch_extension import (
    index_entries, entry_type, entry_index, material_type, material_entry_type,
    DocumentType, Index, SearchQuantity)


def update_by_query(
        update_script: str,
        query: Any = None,
        owner: str = None,
        user_id: str = None,
        index: str = None,
        refresh: bool = False,
        **kwargs):
    '''
    Uses the given painless script to update the entries that match the given query.

    In most cases, the elasticsearch entry index should not be updated field by field;
    you should run `index` instead and fully replace documents from mongodb and
    archive files.

    This method provides a faster, direct way to update individual fields, e.g. to
    quickly update fields for editing operations.
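
    Example (a sketch; the upload id is hypothetical, extra kwargs like ``params``
    are merged into the script body)::

        update_by_query(
            'ctx._source.comment = params.value',
            query=dict(upload_id='some_upload_id'),
            params={'value': 'a new comment'},
            refresh=True)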
    '''
    if query is None:
        query = {}
    es_query = _api_to_es_query(query)
    if owner is not None:
        es_query &= _owner_es_query(owner=owner, user_id=user_id)

    body = {
        'script': {
            'source': update_script,
            'lang': 'painless'
        },
        'query': es_query.to_dict()
    }

    body['script'].update(**kwargs)

    try:
        result = infrastructure.elastic_client.update_by_query(
            body=body, index=config.elastic.entries_index)
    except TransportError as e:
        utils.get_logger(__name__).error(
            'es update_by_query script error', exc_info=e,
            es_info=json.dumps(e.info, indent=2))
        raise SearchError(e)

    if refresh:
        _refresh()

    return result


def delete_by_query(
        query: dict,
        owner: str = None,
        user_id: str = None,
        update_materials: bool = False,
        refresh: bool = False):
    '''
    Deletes all entries that match the given query.
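
    Example (a sketch; the ids are hypothetical)::

        delete_by_query(
            query=dict(upload_id='some_upload_id'),
            owner='user', user_id='some_user_id',
            refresh=True)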
    '''
    if query is None:
        query = {}
    es_query = _api_to_es_query(query)
    es_query &= _owner_es_query(owner=owner, user_id=user_id)

    body = {
        'query': es_query.to_dict()
    }

    try:
        result = infrastructure.elastic_client.delete_by_query(
            body=body, index=config.elastic.entries_index)
    except TransportError as e:
        utils.get_logger(__name__).error(
            'es delete_by_query error', exc_info=e,
            es_info=json.dumps(e.info, indent=2))
        raise SearchError(e)

    if refresh:
        _refresh()

    if update_materials:
        # TODO update the materials index at least for v1
        pass

    return result


def refresh():
    '''
    Refreshes the entries index.
    '''

    try:
        infrastructure.elastic_client.indices.refresh(index=config.elastic.entries_index)
    except TransportError as e:
        utils.get_logger(__name__).error(
            'es refresh error', exc_info=e,
            es_info=json.dumps(e.info, indent=2))
        raise SearchError(e)


_refresh = refresh


def index(
        entries: Union[EntryArchive, List[EntryArchive]],
        update_materials: bool = False,
        refresh: bool = True):
    '''
    Index the given entries based on their archive. Either creates or updates the underlying
    elasticsearch documents. If an underlying elasticsearch document already exists it
    will be fully replaced.
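
    Example (a sketch; the archive contents are hypothetical)::

        archive = EntryArchive(metadata=EntryMetadata(calc_id='some_calc_id'))
        index(archive, update_materials=True, refresh=True)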
    '''
    if not isinstance(entries, list):
        entries = [entries]

    index_entries(entries=entries, update_materials=update_materials)

    if refresh:
        _refresh()


# TODO this depends on how we merge section metadata
def publish(entries: Iterable[EntryMetadata], index: str = None) -> int:
    '''
    Publishes the given entries based on their entry metadata. Sets published to true
    and updates most user provided metadata with a partial update. Returns the number
    of failed updates.
    '''
    return update_metadata(
        entries, index=index, published=True, update_materials=True, refresh=True)


def update_metadata(
        entries: Iterable[EntryMetadata], index: str = None,
        update_materials: bool = False, refresh: bool = False,
        **kwargs) -> int:
    '''
    Updates all given entries with their given metadata. Additionally applies the
    given kwargs to the updated documents. Returns the number of failed updates.
    This performs a partial update on the underlying elasticsearch documents.
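
    Example (a sketch; the id and the extra field are hypothetical)::

        failed = update_metadata(
            [EntryMetadata(calc_id='some_calc_id')],
            comment='a new comment', refresh=True)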
    '''

    def elastic_updates():
        for entry_metadata in entries:
            entry_archive = entry_metadata.m_parent
            if entry_archive is None:
                entry_archive = EntryArchive(metadata=entry_metadata)
            entry_doc = entry_type.create_index_doc(entry_archive)

            entry_doc.update(**kwargs)

            yield dict(
                doc=entry_doc,
                _id=entry_metadata.calc_id,
                _type=entry_index.doc_type.name,
                _index=entry_index.index_name,
                _op_type='update')

    updates = list(elastic_updates())
    _, failed = elasticsearch.helpers.bulk(
        infrastructure.elastic_client, updates, stats_only=True)

    if update_materials:
        # TODO update the materials index at least for v1
        pass

    if refresh:
        _refresh()

    return failed


def delete_upload(upload_id: str, refresh: bool = False, **kwargs):
    '''
    Deletes all entries of the given upload.
    '''
    delete_by_query(query=dict(upload_id=upload_id), **kwargs)

    if refresh:
        _refresh()


def delete_entry(entry_id: str, index: str = None, refresh: bool = False, **kwargs):
    '''
    Deletes the given entry.
    '''
    delete_by_query(query=dict(calc_id=entry_id), **kwargs)

    if refresh:
        _refresh()


class SearchError(Exception): pass


class AuthenticationRequiredError(Exception): pass


_entry_metadata_defaults = {
    quantity.name: quantity.default
    for quantity in datamodel.EntryMetadata.m_def.quantities  # pylint: disable=not-an-iterable
    if quantity.default not in [None, [], False, 0]
}


def _es_to_entry_dict(hit, required: MetadataRequired = None) -> Dict[str, Any]:
    '''
    Elasticsearch entry documents do not contain default values for metadata that
    is not set. This adds the default values to entry metadata in dict form as
    obtained from elasticsearch.
    '''
    entry_dict = hit.to_dict()
    for key, value in _entry_metadata_defaults.items():
        if key not in entry_dict:
            if required is not None:
                if required.exclude and key in required.exclude:
                    continue
                if required.include and key not in required.include:
                    continue

            entry_dict[key] = value
    return entry_dict


def _api_to_es_query(query: api_models.Query) -> Q:
    '''
    Creates an ES query based on the API's query model. This needs to be a normalized
    query expression with explicit objects for logical, set, and comparison operators.
    Shorthand notations like ``quantity:operator`` are not supported here; they
    need to be resolved via the respective pydantic validator. There is also no
    validation of quantities and types.
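
    Example (a sketch; the quantity names are hypothetical, a dict acts as an
    "and", a list of values as an "all")::

        es_query = _api_to_es_query({
            'upload_id': 'some_upload_id',
            'results.material.elements': ['Si', 'O']})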
    '''
    def quantity_to_es(name: str, value: api_models.Value) -> Q:
        # TODO depends on keyword or not, value might need normalization, etc.
        quantity = entry_type.quantities[name]
        return Q('match', **{quantity.search_field: value})

    def parameter_to_es(name: str, value: api_models.QueryParameterValue) -> Q:

        if isinstance(value, api_models.All):
            return Q('bool', must=[
                quantity_to_es(name, item)
                for item in value.op])

        if isinstance(value, api_models.Any_):
            return Q('bool', should=[
                quantity_to_es(name, item)
                for item in value.op])

        if isinstance(value, api_models.None_):
            return Q('bool', must_not=[
                quantity_to_es(name, item)
                for item in value.op])

        if isinstance(value, api_models.Range):
            quantity = entry_type.quantities[name]
            return Q('range', **{quantity.search_field: value.dict(
                exclude_unset=True,
            )})

        # list of values is treated as an "all" over the items
        if isinstance(value, list):
            return Q('bool', must=[
                quantity_to_es(name, item)
                for item in value])

        return quantity_to_es(name, cast(api_models.Value, value))

    def query_to_es(query: api_models.Query) -> Q:
        if isinstance(query, api_models.LogicalOperator):
            if isinstance(query, api_models.And):
                return Q('bool', must=[query_to_es(operand) for operand in query.op])

            if isinstance(query, api_models.Or):
                return Q('bool', should=[query_to_es(operand) for operand in query.op])

            if isinstance(query, api_models.Not):
                return Q('bool', must_not=query_to_es(query.op))

            raise NotImplementedError()

        if not isinstance(query, dict):
            raise NotImplementedError()

        # dictionary is like an "and" of all items in the dict
        if len(query) == 0:
            return Q()

        if len(query) == 1:
            key = next(iter(query))
            return parameter_to_es(key, query[key])

        return Q('bool', must=[
            parameter_to_es(name, value) for name, value in query.items()])

    return query_to_es(query)


def _owner_es_query(owner: str, user_id: str = None, doc_type: DocumentType = entry_type):
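    '''
    Creates an ES query that implements the given owner value (e.g. ``public``,
    ``visible``, ``user``, ``staging``) for the given user and document type.
    Raises AuthenticationRequiredError, if the owner value requires a user, but
    no user_id is given.
    '''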
    def term_query(**kwargs):
        prefix = '' if doc_type == entry_type else 'entries.'
        return Q('term', **{
            (prefix + field): value for field, value in kwargs.items()})

    if owner == 'all':
        q = term_query(published=True)
        if user_id is not None:
            q = q | term_query(owners__user_id=user_id)
    elif owner == 'public':
        q = term_query(published=True) & term_query(with_embargo=False)
    elif owner == 'visible':
        q = term_query(published=True) & term_query(with_embargo=False)
        if user_id is not None:
            q = q | term_query(owners__user_id=user_id)
    elif owner == 'shared':
        if user_id is None:
            raise AuthenticationRequiredError('Authentication required for owner value shared.')

        q = term_query(owners__user_id=user_id)
    elif owner == 'user':
        if user_id is None:
            raise AuthenticationRequiredError('Authentication required for owner value user.')

        q = term_query(uploader__user_id=user_id)
    elif owner == 'staging':
        if user_id is None:
            raise AuthenticationRequiredError('Authentication required for owner value staging.')
        q = term_query(published=False) & term_query(owners__user_id=user_id)
    elif owner == 'admin':
        if user_id is None or not datamodel.User.get(user_id=user_id).is_admin:
            raise AuthenticationRequiredError('This can only be used by the admin user.')
        q = None
    elif owner is None:
        q = None
    else:
        raise KeyError('Unsupported owner value')

    if q is not None:
        return q
    return Q()


class QueryValidationError(Exception):
    def __init__(self, error, loc):
        self.errors = [ErrorWrapper(Exception(error), loc=loc)]


def validate_quantity(
        quantity_name: str, value: Value = None, doc_type: DocumentType = None,
        loc: List[str] = None) -> SearchQuantity:
    '''
    Validates the given quantity name and value against the given document type.

    Returns:
        A metainfo elasticsearch extension SearchQuantity object.

    Raises: QueryValidationError
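
    Example (a sketch; the quantity name is hypothetical)::

        quantity = validate_quantity(
            'results.material.elements', doc_type=entry_type, loc=['query'])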
    '''
    assert quantity_name is not None

    if doc_type == material_entry_type and not quantity_name.startswith('entries'):
        quantity_name = f'entries.{quantity_name}'

    if doc_type == material_type and quantity_name.startswith('entries'):
        doc_type = material_entry_type

    if doc_type is None:
        doc_type = entry_type

    quantity = doc_type.quantities.get(quantity_name)
    if quantity is None:
        raise QueryValidationError(
            f'{quantity_name} is not a {doc_type} quantity',
            loc=[quantity_name] if loc is None else loc)

    return quantity


def validate_api_query(
        query: Query, doc_type: DocumentType, owner_query: EsQuery,
        prefix: str = None) -> EsQuery:
    '''
    Creates an ES query based on the API's query model. This needs to be a normalized
    query expression with explicit objects for logical, set, and comparison operators.
    Shorthand notations like ``quantity:operator`` are not supported here; they
    need to be resolved via the respective pydantic validator.

    However, this function performs validation of quantities and types and raises
    a QueryValidationError accordingly. This exception is populated with pydantic
    errors.

    Arguments:
        query: The api query object.
        doc_type:
            The elasticsearch metainfo extension document type that this query needs to
            be verified against.
        owner_query:
            A prebuilt ES query that is added to nested entries queries. Only used
            for materials queries.
        prefix:
            An optional prefix that is added to all quantity names. Used for recursion.

    Returns:
        An elasticsearch dsl query object.

    Raises: QueryValidationError
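
    Example (a sketch; the upload id is hypothetical)::

        es_query = validate_api_query(
            {'upload_id': 'some_upload_id'}, doc_type=entry_type,
            owner_query=_owner_es_query(owner='public'))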
    '''

    def match(name: str, value: Value) -> EsQuery:
        if name == 'optimade_filter':
            value = str(value)
            from nomad.app.optimade import filterparser
            try:
                return filterparser.parse_filter(value, without_prefix=True)
            except filterparser.FilterException as e:
                raise QueryValidationError(
                    f'Could not parse optimade filter: {e}',
                    loc=[name])

        # TODO non keyword quantities, quantities with value transformation, type checks
        quantity = validate_quantity(name, value, doc_type=doc_type)
        return Q('match', **{quantity.search_field: value})

    def validate_query(query: Query) -> EsQuery:
        return validate_api_query(
            query, doc_type=doc_type, owner_query=owner_query, prefix=prefix)
    def validate_criteria(name: str, value: Any):
        if prefix is not None:
            name = f'{prefix}.{name}'

        # handle prefix and nested queries
        for nested_key in doc_type.nested_object_keys:
            if len(name) < len(nested_key):
                break
            if not name.startswith(nested_key):
                continue
            if prefix is not None and prefix.startswith(nested_key):
                continue
            if nested_key == name and isinstance(value, api_models.Nested):
                continue

            value = api_models.Nested(query={name[len(nested_key) + 1:]: value})
            name = nested_key
            break

        if isinstance(value, api_models.All):
            return Q('bool', must=[match(name, item) for item in value.op])

        elif isinstance(value, api_models.Any_):
            return Q('bool', should=[match(name, item) for item in value.op])

        elif isinstance(value, api_models.None_):
            return Q('bool', must_not=[match(name, item) for item in value.op])

        elif isinstance(value, api_models.Range):
            quantity = validate_quantity(name, None, doc_type=doc_type)
            return Q('range', **{quantity.search_field: value.dict(
                exclude_unset=True,
            )})

        elif isinstance(value, (api_models.And, api_models.Or, api_models.Not)):
            return validate_query(value)

        elif isinstance(value, api_models.Nested):
            sub_doc_type = material_entry_type if name == 'entries' else doc_type

            sub_query = validate_api_query(
                value.query, doc_type=sub_doc_type, prefix=name, owner_query=owner_query)

            if name in doc_type.nested_object_keys:
                if name == 'entries':
                    sub_query &= owner_query
                return Q('nested', path=name, query=sub_query)
            else:
                return sub_query

        # list of values is treated as an "all" over the items
        elif isinstance(value, list):
            return Q('bool', must=[match(name, item) for item in value])

        elif isinstance(value, dict):
            assert False, (
                'Using dictionaries as criteria values directly is not supported. Use the '
                'Nested model.')

        else:
            return match(name, value)

    if isinstance(query, api_models.And):
        return Q('bool', must=[validate_query(operand) for operand in query.op])

    if isinstance(query, api_models.Or):
        return Q('bool', should=[validate_query(operand) for operand in query.op])

    if isinstance(query, api_models.Not):
        return Q('bool', must_not=validate_query(query.op))

    if isinstance(query, dict):
        # dictionary is like an "and" of all items in the dict
        if len(query) == 0:
            return Q()

        if len(query) == 1:
            key = next(iter(query))
            return validate_criteria(key, query[key])

        return Q('bool', must=[
            validate_criteria(name, value) for name, value in query.items()])

    raise NotImplementedError()


def validate_pagination(pagination: Pagination, doc_type: DocumentType, loc: List[str] = None):
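    '''
    Validates the pagination against the given document type, e.g. that the
    order_by quantity exists and is scalar. May normalize
    pagination.page_after_value in place. Returns the order quantity and the
    original page_after_value.
    '''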
    order_quantity = None
    if pagination.order_by is not None:
        order_quantity = validate_quantity(
            pagination.order_by, doc_type=doc_type, loc=['pagination', 'order_by'])
        if not order_quantity.definition.is_scalar:
            raise QueryValidationError(
                'the order_by quantity must be a scalar',
                loc=(loc if loc else []) + ['pagination', 'order_by'])

    page_after_value = pagination.page_after_value
    if page_after_value is not None and \
            pagination.order_by is not None and \
            pagination.order_by != doc_type.id_field and \
            ':' not in page_after_value:

        pagination.page_after_value = '%s:' % page_after_value

    return order_quantity, page_after_value


def _api_to_es_aggregation(
        es_search: Search, name: str, agg: AggregationBase, doc_type: DocumentType) -> A:
    '''
    Creates an ES aggregation based on the API's aggregation model.
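
    Example (a sketch; the quantity name is hypothetical)::

        _api_to_es_aggregation(
            es_search, 'elements',
            TermsAggregation(quantity='results.material.elements'),
            doc_type=entry_type)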
    '''

    agg_name = f'agg:{name}'
    es_aggs = es_search.aggs

    if isinstance(agg, StatisticsAggregation):
        for metric_name in agg.metrics:
            metrics = doc_type.metrics
            if metric_name not in metrics and doc_type == material_type:
                metrics = material_entry_type.metrics
            if metric_name not in metrics:
                raise QueryValidationError(
                    'metric must be the qualified name of a suitable search quantity',
                    loc=['statistic', 'metrics'])
            metric_aggregation, metric_quantity = metrics[metric_name]
            es_aggs.metric('statistics:%s' % metric_name, A(
                metric_aggregation,
                field=metric_quantity.qualified_field))

        return

    agg = cast(QuantityAggregation, agg)

    longest_nested_key = None
    quantity = validate_quantity(agg.quantity, doc_type=doc_type, loc=['aggregation', 'quantity'])

    for nested_key in doc_type.nested_object_keys:
        if agg.quantity.startswith(nested_key):
            es_aggs = es_search.aggs.bucket('nested_agg:%s' % name, 'nested', path=nested_key)
            longest_nested_key = nested_key

    es_agg = None

    if isinstance(agg, TermsAggregation):
        if not quantity.aggregateable:
            raise QueryValidationError(
                'The aggregation quantity cannot be used in a terms aggregation.',
                loc=['aggregation', name, 'terms', 'quantity'])

        if agg.pagination is not None:
            if agg.size is not None:
                raise QueryValidationError(
                    'You cannot paginate and provide an extra size parameter.',
                    loc=['aggregations', name, 'terms', 'pagination'])

            order_quantity, page_after_value = validate_pagination(
                agg.pagination, doc_type=doc_type, loc=['aggregation'])

            # We are using Elasticsearch's 'composite aggregations' here. We do not really
            # compose aggregations, but only these pseudo composites allow us to use the
            # 'after' feature for scanning through all aggregation values.
            terms = A('terms', field=quantity.search_field, order=agg.pagination.order.value)

            if order_quantity is None:
                composite = {
                    'sources': {
                        name: terms
                    },
                    'size': agg.pagination.page_size
                }
            else:
                sort_terms = A(
                    'terms',
                    field=order_quantity.search_field,
                    order=agg.pagination.order.value)

                composite = {
                    'sources': [
                        {order_quantity.search_field: sort_terms},
                        {quantity.search_field: terms}
                    ],
                    'size': agg.pagination.page_size
                }

            if page_after_value is not None:
                if order_quantity is None:
                    composite['after'] = {name: page_after_value}
                else:
                    try:
                        order_value, quantity_value = page_after_value.split(':')
                        composite['after'] = {quantity.search_field: quantity_value, order_quantity.search_field: order_value}
                    except Exception:
                        raise QueryValidationError(
                            'The page_after_value does not have the right format.',
                            loc=['aggregations', name, 'terms', 'pagination', 'page_after_value'])

            es_agg = es_aggs.bucket(agg_name, 'composite', **composite)

            # additional cardinality to get total
            es_aggs.metric('agg:%s:total' % name, 'cardinality', field=quantity.search_field)
        else:
            if agg.size is None:
                if quantity.default_aggregation_size is not None:
                    agg.size = quantity.default_aggregation_size

                elif quantity.values is not None:
                    agg.size = len(quantity.values)

                else:
                    agg.size = 10

            terms_kwargs = {}
            if agg.value_filter is not None:
                terms_kwargs['include'] = '.*%s.*' % agg.value_filter

            terms = A('terms', field=quantity.search_field, size=agg.size, **terms_kwargs)
            es_agg = es_aggs.bucket(agg_name, terms)

        if agg.entries is not None and agg.entries.size > 0:
            kwargs: Dict[str, Any] = {}
            if agg.entries.required is not None:
                if agg.entries.required.include is not None:
                    kwargs.update(_source=dict(includes=agg.entries.required.include))
                else:
                    kwargs.update(_source=dict(excludes=agg.entries.required.exclude))

            es_agg.metric('entries', A('top_hits', size=agg.entries.size, **kwargs))

    elif isinstance(agg, DateHistogramAggregation):
        if quantity.annotation.mapping['type'] not in ['date']:
            raise QueryValidationError(
                f'The quantity {quantity} cannot be used in a date histogram aggregation',
                loc=['aggregations', name, 'date_histogram', 'quantity'])

        es_agg = es_aggs.bucket(agg_name, A(
            'date_histogram', field=quantity.search_field, interval=agg.interval,
            format='yyyy-MM-dd'))

    elif isinstance(agg, HistogramAggregation):
        if quantity.annotation.mapping['type'] not in ['integer', 'float', 'double', 'long']:
            raise QueryValidationError(
                f'The quantity {quantity} cannot be used in a histogram aggregation',
                loc=['aggregations', name, 'histogram', 'quantity'])

        es_agg = es_aggs.bucket(agg_name, A(
            'histogram', field=quantity.search_field, interval=agg.interval))

    elif isinstance(agg, MinMaxAggregation):
        if quantity.annotation.mapping['type'] not in ['integer', 'float', 'double', 'long', 'date']:
            raise QueryValidationError(
                f'The quantity {quantity} cannot be used in a min-max aggregation',
                loc=['aggregations', name, 'min_max', 'quantity'])

        es_aggs.metric(agg_name + ':min', A('min', field=quantity.search_field))
        es_aggs.metric(agg_name + ':max', A('max', field=quantity.search_field))

    else:
        raise NotImplementedError()

    if isinstance(agg, BucketAggregation):
        for metric_name in agg.metrics:
            metrics = doc_type.metrics
            if longest_nested_key == 'entries':
                metrics = material_entry_type.metrics
            if metric_name not in metrics:
                raise QueryValidationError(
                    'metric must be the qualified name of a suitable search quantity',
                    loc=['statistic', 'metrics'])
            metric_aggregation, metric_quantity = metrics[metric_name]
            es_agg.metric('metric:%s' % metric_name, A(
                metric_aggregation,
                field=metric_quantity.qualified_field))


def _es_to_api_aggregation(
        es_response, name: str, agg: AggregationBase, doc_type: DocumentType):
    '''
    Creates an AggregationResponse from the elasticsearch response on a request
    executed with the given aggregation.
    '''
    es_aggs = es_response.aggs
    aggregation_dict = agg.dict(by_alias=True)

    if isinstance(agg, StatisticsAggregation):
        metrics = {}
        for metric in agg.metrics:  # type: ignore
            metrics[metric] = es_aggs[f'statistics:{metric}'].value

        return AggregationResponse(
            statistics=StatisticsAggregationResponse(data=metrics, **aggregation_dict))

    agg = cast(QuantityAggregation, agg)
    quantity = validate_quantity(agg.quantity, doc_type=doc_type)
    longest_nested_key = None
    for nested_key in doc_type.nested_object_keys:
        if agg.quantity.startswith(nested_key):
            es_aggs = es_response.aggs[f'nested_agg:{name}']
            longest_nested_key = nested_key

    has_no_pagination = getattr(agg, 'pagination', None) is None

    if isinstance(agg, BucketAggregation):
        es_agg = es_aggs['agg:' + name]
        values = set()

        def get_bucket(es_bucket) -> Bucket:
            if has_no_pagination:
                if isinstance(agg, DateHistogramAggregation):
                    value = es_bucket['key_as_string']
                else:
                    value = es_bucket['key']
            elif agg.pagination.order_by is None:  # type: ignore
                value = es_bucket.key[name]
            else:
                value = es_bucket.key[quantity.search_field]

            count = es_bucket.doc_count
            metrics = {}
            for metric in agg.metrics:  # type: ignore
                metrics[metric] = es_bucket['metric:' + metric].value

            entries = None
            if 'entries' in es_bucket:
                if longest_nested_key:
                    entries = [{longest_nested_key: item['_source']} for item in es_bucket.entries.hits.hits]
                else:
                    entries = [item['_source'] for item in es_bucket.entries.hits.hits]

            values.add(value)
            if len(metrics) == 0:
                metrics = None
            return Bucket(value=value, entries=entries, count=count, metrics=metrics)

        data = [get_bucket(es_bucket) for es_bucket in es_agg.buckets]

        if has_no_pagination:
            # fill "empty" values
            if quantity.values is not None:
                for value in quantity.values:
                    if value not in values:
                        metrics = {metric: 0 for metric in agg.metrics}
                        if len(metrics) == 0:
                            metrics = None
                        data.append(Bucket(value=value, count=0, metrics=metrics))

        else:
            total = es_aggs['agg:%s:total' % name]['value']
            pagination = PaginationResponse(total=total, **aggregation_dict['pagination'])
            if pagination.page_after_value is not None and pagination.page_after_value.endswith(':'):
                pagination.page_after_value = pagination.page_after_value[0:-1]

            if 'after_key' in es_agg:
                after_key = es_agg['after_key']
                if pagination.order_by is None:
                    pagination.next_page_after_value = after_key[name]
                else:
                    str_values = [str(v) for v in after_key.to_dict().values()]
                    pagination.next_page_after_value = ':'.join(str_values)
            else:
                pagination.next_page_after_value = None

            aggregation_dict['pagination'] = pagination

        if isinstance(agg, TermsAggregation):
            return AggregationResponse(
                terms=TermsAggregationResponse(data=data, **aggregation_dict))
        elif isinstance(agg, HistogramAggregation):
            return AggregationResponse(
                histogram=HistogramAggregationResponse(data=data, **aggregation_dict))
        elif isinstance(agg, DateHistogramAggregation):
            return AggregationResponse(
                date_histogram=DateHistogramAggregationResponse(data=data, **aggregation_dict))
        else:
            raise NotImplementedError()

    if isinstance(agg, MinMaxAggregation):
        min_value = es_aggs['agg:%s:min' % name]['value']
        max_value = es_aggs['agg:%s:max' % name]['value']

        return AggregationResponse(
            min_max=MinMaxAggregationResponse(data=[min_value, max_value], **aggregation_dict))

    raise NotImplementedError()

def _specific_agg(agg: Aggregation) -> Union[
        TermsAggregation, DateHistogramAggregation, HistogramAggregation,
        MinMaxAggregation, StatisticsAggregation]:
    if agg.terms is not None:
        return agg.terms

    if agg.histogram is not None:
        return agg.histogram

    if agg.date_histogram is not None:
        return agg.date_histogram

    if agg.min_max is not None:
        return agg.min_max

    if agg.statistics is not None:
        return agg.statistics

    raise NotImplementedError()


def search(
        owner: str = 'public',
        query: Union[Query, EsQuery] = None,
        pagination: MetadataPagination = None,
        required: MetadataRequired = None,
        aggregations: Dict[str, Aggregation] = {},
        user_id: str = None,
        index: Index = entry_index) -> MetadataResponse:
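    '''
    Runs the given query against the given elasticsearch index, applying the
    given owner filter, pagination, field requirements, and aggregations.
    Returns a MetadataResponse with the hits, pagination info, and aggregation
    results.

    Raises: QueryValidationError, AuthenticationRequiredError, SearchError
    '''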

    # The first half of this method creates the ES query. Then the query is run on ES.
    # The second half is about transforming the ES response to a MetadataResponse.

    doc_type = index.doc_type

    # owner and query
    owner_query = _owner_es_query(owner=owner, user_id=user_id, doc_type=doc_type)

    if query is None:
        query = {}

    if isinstance(query, EsQuery):
        es_query = cast(EsQuery, query)
    else:
        es_query = validate_api_query(
            cast(Query, query), doc_type=doc_type, owner_query=owner_query)

    if doc_type != entry_type:
        es_query &= Q('nested', path='entries', query=owner_query)
    else:
        es_query &= owner_query

    # pagination
    if pagination is None:
        pagination = MetadataPagination()

    # TODO this depends on doc_type
    if pagination.order_by is None:
        pagination.order_by = doc_type.id_field

    search = Search(index=index.index_name)
    search = search.query(es_query)

    order_quantity, page_after_value = validate_pagination(pagination, doc_type=doc_type)
    order_field = order_quantity.search_field
    sort = {order_field: pagination.order.value}
    if order_field != doc_type.id_field:
        sort[doc_type.id_field] = pagination.order.value
    search = search.sort(sort)
    search = search.extra(size=pagination.page_size)

    if pagination.page_offset:
        search = search.extra(**{'from': pagination.page_offset})
    elif pagination.page:
        search = search.extra(**{'from': (pagination.page - 1) * pagination.page_size})
    elif page_after_value:
        search = search.extra(search_after=page_after_value.rsplit(':', 1))

    # required
    if required:
        for list_ in [required.include, required.exclude]:
            for quantity in [] if list_ is None else list_:
                # TODO validate quantities with wildcards
                if '*' not in quantity:
                    validate_quantity(quantity, doc_type=doc_type, loc=['required'])

        if required.include is not None and pagination.order_by not in required.include:
            required.include.append(pagination.order_by)
        if required.exclude is not None and pagination.order_by in required.exclude:
            required.exclude.remove(pagination.order_by)

        if required.include is not None and doc_type.id_field not in required.include:
            required.include.append(doc_type.id_field)

        if required.exclude is not None and doc_type.id_field in required.exclude:
            required.exclude.remove(doc_type.id_field)

        search = search.source(includes=required.include, excludes=required.exclude)

    # aggregations
    for name, agg in aggregations.items():
        _api_to_es_aggregation(search, name, _specific_agg(agg), doc_type=doc_type)

    # execute
    try:
        es_response = search.execute()
    except RequestError as e:
        raise SearchError(e)

    more_response_data = {}

    # pagination
    next_page_after_value = None
    if 0 < len(es_response.hits) < es_response.hits.total and len(es_response.hits) >= pagination.page_size:
        last = es_response.hits[-1]
        if order_field == doc_type.id_field:
            next_page_after_value = last[doc_type.id_field]
        else:
            # after_value is not necessarily the value stored in the field
            # itself: internally ES can perform the sorting on a different
            # value which is reported under meta.sort.
            after_value = last.meta.sort[0]
            next_page_after_value = '%s:%s' % (after_value, last[doc_type.id_field])
    pagination_response = PaginationResponse(