#
# Copyright The NOMAD Authors.
#
# This file is part of NOMAD. See https://nomad-lab.eu for further info.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

'''
This module represents calculations in Elasticsearch.
'''

23
24
from typing import Iterable, Dict, List, Any
from elasticsearch_dsl import Search, Q, A, analyzer, tokenizer
25
import elasticsearch.helpers
26
from elasticsearch.exceptions import NotFoundError, RequestError
Markus Scheidgen's avatar
Markus Scheidgen committed
27
from datetime import datetime
28
import json
29

30
from nomad.datamodel.material import Material
31
from nomad import config, datamodel, infrastructure, utils
32
33
34
35
36
37
from nomad.metainfo.search_extension import (  # pylint: disable=unused-import
    search_quantities, metrics, order_default_quantities, groups)
from nomad.app_fastapi import models as api_models
from nomad.app_fastapi.models import (
    Pagination, PaginationResponse, Query, MetadataRequired, SearchResponse, Aggregation,
    Statistic, StatisticResponse, AggregationOrderType, AggregationResponse, AggregationDataItem)
38


path_analyzer = analyzer(
    'path_analyzer',
    tokenizer=tokenizer('path_tokenizer', 'pattern', pattern='/'))


class AlreadyExists(Exception): pass


class ElasticSearchError(Exception): pass


class AuthenticationRequiredError(Exception): pass


class ScrollIdNotFound(Exception): pass


class InvalidQuery(Exception): pass


entry_document = datamodel.EntryMetadata.m_def.a_elastic.document
material_document = Material.m_def.a_elastic.document

for domain in datamodel.domains:
    order_default_quantities.setdefault(domain, order_default_quantities.get('__all__'))


def delete_upload(upload_id):
    ''' Delete all entries with the given ``upload_id`` from the index. '''
    index = entry_document._default_index()

    Search(index=index).query('match', upload_id=upload_id).delete()


def delete_entry(calc_id):
    ''' Delete the entry with the given ``calc_id`` from the index. '''
    index = entry_document._default_index()

    Search(index=index).query('match', calc_id=calc_id).delete()


def publish(calcs: Iterable[datamodel.EntryMetadata]) -> None:
    ''' Update all given calcs with their metadata and set ``published = True``. '''
    def elastic_updates():
        for calc in calcs:
            entry = calc.a_elastic.create_index_entry()
            entry.published = True
            entry = entry.to_dict(include_meta=True)
            source = entry.pop('_source')
            entry['doc'] = source
            entry['_op_type'] = 'update'
            yield entry

    elasticsearch.helpers.bulk(infrastructure.elastic_client, elastic_updates())
    refresh()
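
# Shape of a single bulk 'update' action produced by :func:`publish` (an
# illustrative sketch; the meta keys come from elasticsearch_dsl's
# ``to_dict(include_meta=True)``, the id value is an assumption):
#
#     {'_op_type': 'update', '_index': '<entry index>', '_id': '<entry id>',
#      'doc': {..., 'published': True}}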


def index_all(calcs: Iterable[datamodel.EntryMetadata], do_refresh=True) -> int:
    '''
    Adds all given calcs with their metadata to the index.

    Returns:
        Number of failed entries.
    '''
    def elastic_updates():
        for calc in calcs:
            entry = calc.a_elastic.create_index_entry()
            entry = entry.to_dict(include_meta=True)
            entry['_op_type'] = 'index'
            yield entry

    _, failed = elasticsearch.helpers.bulk(
        infrastructure.elastic_client, elastic_updates(), stats_only=True)

    if do_refresh:
        refresh()

    return failed


def refresh():
    infrastructure.elastic_client.indices.refresh(config.elastic.index_name)


def _owner_es_query(owner: str, user_id: str = None):
    if owner == 'all':
        q = Q('term', published=True)
        if user_id is not None:
            q = q | Q('term', owners__user_id=user_id)
    elif owner == 'public':
        q = Q('term', published=True) & Q('term', with_embargo=False)
    elif owner == 'visible':
        q = Q('term', published=True) & Q('term', with_embargo=False)
        if user_id is not None:
            q = q | Q('term', owners__user_id=user_id)
    elif owner == 'shared':
        if user_id is None:
            raise AuthenticationRequiredError('Authentication required for owner value shared.')

        q = Q('term', owners__user_id=user_id)
    elif owner == 'user':
        if user_id is None:
            raise AuthenticationRequiredError('Authentication required for owner value user.')

        q = Q('term', uploader__user_id=user_id)
    elif owner == 'staging':
        if user_id is None:
            raise AuthenticationRequiredError('Authentication required for owner value staging.')
        q = Q('term', published=False) & Q('term', owners__user_id=user_id)
    elif owner == 'admin':
        if user_id is None or not datamodel.User.get(user_id=user_id).is_admin:
            raise AuthenticationRequiredError('This can only be used by the admin user.')
        q = None
    else:
        raise KeyError('Unsupported owner value')

    if q is not None:
        return q
    return Q()
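
# Examples of the produced owner queries (a sketch for illustration):
#
#     _owner_es_query('public')
#     # -> Q('term', published=True) & Q('term', with_embargo=False)
#
#     _owner_es_query('user', user_id=some_id)
#     # -> Q('term', uploader__user_id=some_id)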


class SearchRequest:
    '''
    Represents a search request and allows to execute that request.
    It allows to compose the following features: a query;
    statistics (metrics and aggregations); quantity values; scrolling, pagination for entries;
    scrolling for quantity values.

    The query part filters NOMAD data before the other features come into effect. There
    are specialized methods for configuring the :func:`owner` and :func:`time_range` queries.
    Quantities can be searched for by setting them as attributes.

    The aggregations for statistics can be requested for pre-configured quantities. These
    bucket aggregations come with a metric calculated for each possible
    quantity value.

    The other possible form of aggregations allows to get quantity values as results
    (e.g. get all datasets, get all users, etc.). Each value can be accompanied by metrics
    (over all entries with that value) and an example value.

    Of course, searches can return a set of search results. Search objects can be
    configured with pagination or scrolling for these results. Pagination is the default
    and also allows ordering of results. Scrolling can be used if all entries need to be
    'scrolled through'. This might be necessary, since Elasticsearch has limits on
    possible pages (e.g. 'from' must be smaller than 10000). On the downside, there is no
    ordering on scrolling.

    There is also scrolling for quantities to go through all quantity values. There is no
    paging for aggregations.
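
    Example (an illustrative sketch; the used quantity names are hypothetical
    and depend on the domain's datamodel):

    .. code-block:: python

        request = SearchRequest(domain='dft').owner('public')
        request.search_parameter('atoms', 'Si')
        request.statistic('dft.code_name', size=10, metrics_to_use=['datasets'])
        results = request.execute_paginated(page=1, per_page=10)
    '''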
    def __init__(self, domain: str = config.meta.default_domain, query=None):
        self._domain = domain
        self._query = query
        self._search = Search(index=config.elastic.index_name)

    def domain(self, domain: str = None):
        '''
        Applies the domain of this request to the query. Allows to optionally update
        the domain of this request.
        '''
        if domain is not None:
            self._domain = domain

        self.q = self.q & Q('term', domain=self._domain)
        return self

    def owner(self, owner_type: str = 'all', user_id: str = None):
        '''
        Uses the query part of the search to restrict the results based on the owner.
        The possible types are: ``all`` for all calculations; ``public`` for
        calculations visible by everyone, excluding embargo-ed entries and entries only visible
        to the given user; ``visible`` for all data that is visible by the user, excluding
        embargo-ed entries from other users; ``shared`` for all calculations the given user
        is an owner of; ``user`` for all calculations of the given user; ``staging`` for all
        calculations in staging of the given user; ``admin`` for all calculations
        (admin users only).

        Arguments:
            owner_type: The type of the owner query, see above.
            user_id: The 'owner' given as the user's unique id.

        Raises:
            KeyError: If the given owner_type is not supported.
            AuthenticationRequiredError: If the owner_type requires a user but none is
                given, or the given user is not allowed to use the given owner_type.
        '''
        self.q &= _owner_es_query(owner=owner_type, user_id=user_id)
        return self

    def search_parameters(self, **kwargs):
        '''
        Configures the existing query with additional search parameters. Kwargs are
        interpreted as key value pairs. Keys have to correspond to valid entry quantities
        in the domain's (e.g. DFT calculations) datamodel. Alternatively, search parameters
        can be set via attributes.
        '''
        for name, value in kwargs.items():
            self.search_parameter(name, value)

        return self

    def _search_parameter_to_es(self, name, value):
        quantity = search_quantities[name]

        if quantity.many and not isinstance(value, list):
            value = [value]

        if quantity.many_or and isinstance(value, list):
            return Q('terms', **{quantity.search_field: value})

        if quantity.derived:
            if quantity.many and not isinstance(value, list):
                value = [value]
            value = quantity.derived(value)

        if isinstance(value, list):
            values = value
        else:
            values = [value]

        return Q('bool', must=[
            Q('match', **{quantity.search_field: item})
            for item in values])

    def search_parameter(self, name, value):
        self.q &= self._search_parameter_to_es(name, value)
        return self

    def query(self, query):
        ''' Adds the given query as an 'and' (i.e. 'must') clause to the request. '''
        self.q &= query
        return self

    def query_expression(self, expression) -> 'SearchRequest':
        '''
        Adds a query expression to the request. The expression is a dict that uses
        boolean operators (``$and``, ``$or``, ``$not``), comparison operators
        (``$gt``, ``$gte``, ``$lt``, ``$lte``), and quantity name/value pairs.
        '''
        bool_operators = {'$and': 'must', '$or': 'should', '$not': 'must_not'}
        comp_operators = {'$%s' % op: op for op in ['gt', 'gte', 'lt', 'lte']}

        def _to_es(key, values):
            if key in bool_operators:
                if isinstance(values, dict):
                    values = [values]
                assert isinstance(values, list), 'bool operator requires a list of dicts or dict'
                child_es_queries = [
                    _to_es(child_key, child_value)
                    for child_query in values
                    for child_key, child_value in child_query.items()]
                return Q('bool', **{bool_operators[key]: child_es_queries})

            if key in comp_operators:
                assert isinstance(values, dict), 'comparison operator requires a dict'
                assert len(values) == 1, 'comparison operator requires exactly one quantity'
                quantity_name, value = next(iter(values.items()))
                quantity = search_quantities.get(quantity_name)
                assert quantity is not None, 'quantity %s does not exist' % quantity_name
                return Q('range', **{quantity.search_field: {comp_operators[key]: value}})

            try:
                return self._search_parameter_to_es(key, values)
            except KeyError:
                assert False, 'quantity %s does not exist' % key

        if len(expression) == 0:
            self.q &= Q()
        else:
            self.q &= Q('bool', must=[_to_es(key, value) for key, value in expression.items()])

        return self
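
    # An example expression accepted by :func:`query_expression` (a sketch;
    # the quantity names are illustrative):
    #
    #     {'$and': [
    #         {'dft.code_name': 'VASP'},
    #         {'$gte': {'dft.n_calculations': 2}},
    #         {'$not': {'atoms': 'Ti'}}]}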

    def time_range(self, start: datetime, end: datetime):
        ''' Adds a time range to the query. '''
        if start is None and end is None:
            return self

        if start is None:
            start = datetime.fromtimestamp(0)
        if end is None:
            end = datetime.utcnow()

        self.q &= Q('range', upload_time=dict(gte=start, lte=end))

        return self

    @property
    def q(self):
        ''' The underlying elasticsearch_dsl query object. '''
        if self._query is None:
            return Q('match_all')
        else:
            return self._query

    @q.setter
    def q(self, q):
        self._query = q

    def totals(self, metrics_to_use: List[str] = []):
        '''
        Configure the request to return overall totals for the given metrics.

        The statistics are returned with the other quantity statistics under the pseudo
        quantity name 'total'. 'total' contains the pseudo value 'all'. It is used to
        store the metrics aggregated over all entries in the search results.
        '''
        self._add_metrics(self._search.aggs, metrics_to_use)
        return self

    def statistics(self, statistics: List[str], metrics_to_use: List[str] = []):
        '''
        Configures statistics for the given quantities, using the domain's default
        sizes and orders.
        '''
        for statistic in statistics:
            search_quantity = search_quantities[statistic]
            statistic_order = search_quantity.statistic_order
            self.statistic(
                search_quantity.qualified_name,
                search_quantity.statistic_size,
                metrics_to_use=metrics_to_use,
                order={statistic_order: 'asc' if statistic_order == '_key' else 'desc'})

        return self

357
358
    def statistic(
            self, quantity_name: str, size: int, metrics_to_use: List[str] = [],
359
            order: Dict[str, str] = dict(_key='asc'), include: str = None):
360
        '''
361
362
363
364
365
366
367
368
369
370
371
372
373
        This can be used to display statistics over the searched entries and allows to
        implement faceted search on the top values for each quantity.

        The metrics contain overall and per quantity value sums of code runs (calcs),
        unique code runs, datasets, and additional domain specific metrics
        (e.g. total energies, and unique geometries for DFTcalculations). The quantities
        that can be aggregated to metrics are defined in module:`datamodel`. Aggregations
        and respective metrics are calculated for aggregations given in ``aggregations``
        and metrics in ``aggregation_metrics``. As a pseudo aggregation ``total_metrics``
        are calculation over all search results. The ``aggregations`` gives tuples of
        quantities and default aggregation sizes.

        The search results will contain a dictionary ``statistics``. This has a key
374
        for each configured quantity. Each quantity key will hold a dict
375
376
        with a key for each quantity value. Each quantity value key will hold a dict
        with a key for each metric. The values will be the actual aggregated metric values.
377

378
        Arguments:
379
            quantity_name: The quantity to aggregate statistics for. Only works on *keyword* field.
380
381
382
            metrics_to_use: The metrics calculated over the aggregations. Can be
                ``unique_code_runs``, ``datasets``, other domain specific metrics.
                The basic doc_count metric ``code_runs`` is always given.
383
            order: The order dictionary is passed to the elastic search aggregation.
384
385
386
            include:
                Uses an regular expression in ES to only return values that include
                the given substring.
387
388
        '''
        quantity = search_quantities[quantity_name]
389
390
391
392
        terms_kwargs = {}
        if include is not None:
            terms_kwargs['include'] = '.*%s.*' % include
        terms = A('terms', field=quantity.search_field, size=size, order=order, **terms_kwargs)
393

394
395
        buckets = self._search.aggs.bucket('statistics:%s' % quantity_name, terms)
        self._add_metrics(buckets, metrics_to_use)
396

397
        return self
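
    # Shape of the resulting ``statistics`` entry in the search response (a
    # sketch with made-up quantity names and numbers):
    #
    #     {'dft.code_name': {
    #         'VASP': {'code_runs': 123, 'datasets': 4},
    #         'exciting': {'code_runs': 45, 'datasets': 2}}}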
398

399
400
401
    def _add_metrics(self, parent=None, metrics_to_use: List[str] = []):
        if parent is None:
            parent = self._search.aggs
402

403
        for metric in metrics_to_use:
404
            metric_quantity = metrics[metric]
405
            field = metric_quantity.search_field
406
407
408
            parent.metric(
                'metric:%s' % metric_quantity.metric_name,
                A(metric_quantity.metric, field=field))
409

410
    def date_histogram(self, metrics_to_use: List[str] = [], interval: str = '1M'):
        '''
        Adds a date histogram on the given metrics to the statistics part.
        '''
        histogram = A('date_histogram', field='upload_time', interval=interval, format='yyyy-MM-dd')
        self._add_metrics(self._search.aggs.bucket('statistics:date_histogram', histogram), metrics_to_use)

        return self

    def quantities(self, **kwargs):
        '''
        Shorthand for adding multiple quantities. See :func:`quantity`. Keyword argument
        keys are quantity names, values are tuples of size and after value.
        '''
        for name, spec in kwargs.items():
            size, after = spec
            self.quantity(name, after=after, size=size)

        return self

    def quantity(
            self, name, size=100, after=None, examples=0, examples_source=None,
            order_by: str = None, order: str = 'desc'):
        '''
        Adds a request for values of the given quantity.

        It allows to scroll through all values via Elasticsearch's
        composite aggregations. The response will contain the quantity values and
        an example entry for each value.

        This can be used to implement continuous scrolling through authors, datasets,
        or uploads within the searched entries.

        If one or more quantities are specified,
        the search results will contain a dictionary ``quantities``. The keys are quantity
        names, the values dictionaries with 'after' and 'values' keys.
        The 'values' key holds a dict with all the values as keys and their entry count
        as values (i.e. number of entries with that value).

        Arguments:
            name: The quantity name. Must be in :data:`quantities`.
            after: The 'after' value allows to scroll over various requests, by providing
                the 'after' value of the last search. The 'after' value is part of the
                response. Use ``None`` in the first request.
            size:
                The size gives the amount of maximum values in the next scroll window.
                If the size is None, a maximum of 100 quantity values will be requested.
            examples:
                Number of results to return that have each value.
            order_by:
                A sortable quantity that should be used to order. By default, the max of each
                value bucket is used.
            order:
                "desc" or "asc"
        '''
        if size is None:
            size = 100

        quantity = search_quantities[name]
        terms = A('terms', field=quantity.search_field)

        # We are using Elasticsearch's 'composite aggregations' here. We do not really
        # compose aggregations, but only those pseudo composites allow us to use the
        # 'after' feature that allows to scan through all aggregation values.
        if order_by is None:
            composite = dict(sources={name: terms}, size=size)
        else:
            sort_terms = A('terms', field=order_by, order=order)
            composite = dict(sources=[{order_by: sort_terms}, {name: terms}], size=size)
        if after is not None:
            if order_by is None:
                composite['after'] = {name: after}
            else:
                composite['after'] = {order_by: after, name: ''}

        composite_agg = self._search.aggs.bucket('quantity:%s' % name, 'composite', **composite)

        if examples > 0:
            kwargs: Dict[str, Any] = {}
            if examples_source is not None:
                kwargs.update(_source=dict(includes=examples_source))

            composite_agg.metric('examples', A('top_hits', size=examples, **kwargs))

        return self
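
    # Scrolling through all values of a quantity (a sketch; 'datasets' is an
    # illustrative quantity name):
    #
    #     request = SearchRequest().owner('public').quantity('datasets', size=100)
    #     result = request.execute()
    #     # pass result['quantities']['datasets']['after'] as ``after`` to a new
    #     # request to receive the next 100 values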

    def global_statistics(self):
        '''
        Adds general statistics to the request. The results will have a key called
        global_statistics.
        '''
        self.owner('public')
        self._search.aggs.metric(
            'global_statistics:n_entries', A('value_count', field='calc_id'))
        self._search.aggs.metric(
            'global_statistics:n_uploads', A('cardinality', field='upload_id'))
        self._search.aggs.metric(
            'global_statistics:n_calculations', A('sum', field='dft.n_calculations'))
        self._search.aggs.metric(
            'global_statistics:n_quantities', A('sum', field='dft.n_quantities'))
        self._search.aggs.metric(
            'global_statistics:n_materials', A('cardinality', field='encyclopedia.material.material_id'))

        return self

    def exclude(self, *args):
        ''' Exclude certain elastic fields from the search results. '''
        self._search = self._search.source(excludes=args)
        return self

    def include(self, *args):
        ''' Include only the given fields in the search results. '''
        self._search = self._search.source(includes=args)
        return self

    def execute(self):
        '''
        Executes without returning actual results. Only makes sense if the request
        was configured for statistics or quantity values.
        '''
        search = self._search.query(self.q)[0:0]
        response = search.execute()
        return self._response(response)

    def execute_scan(self, order_by: str = None, order: int = -1, **kwargs):
        '''
        This executes the search as a scan. The result will be a generator over the found
        entries. Everything but the query part of this object will be ignored.
        '''
        search = self._search.query(self.q)

        if order_by is not None:
            order_by_quantity = search_quantities[order_by]

            if order == 1:
                search = search.sort(order_by_quantity.search_field)
            else:
                search = search.sort('-%s' % order_by_quantity.search_field)  # pylint: disable=no-member

            search = search.params(preserve_order=True)

        for hit in search.params(**kwargs).scan():
            yield hit.to_dict()

    def execute_paginated(
            self, page: int = 1, per_page=10, order_by: str = None,
            order: int = -1):
        '''
        Executes the search and returns paginated results. Those are sorted.

        Arguments:
            page: The requested page, starts with 1.
            per_page: The number of entries per page.
            order_by: The quantity to order by.
            order: -1 or 1 for descending or ascending order.
        '''
        if order_by is None:
            order_by_quantity = order_default_quantities[self._domain]
        else:
            order_by_quantity = search_quantities[order_by]

        search = self._search.query(self.q)

        if order == 1:
            search = search.sort(order_by_quantity.search_field)
        else:
            search = search.sort('-%s' % order_by_quantity.search_field)  # pylint: disable=no-member
        search = search[(page - 1) * per_page: page * per_page]

        es_result = search.execute()

        result = self._response(es_result, with_hits=True)

        result.update(pagination=dict(total=result['total'], page=page, per_page=per_page))
        return result

    def execute_scrolled(
            self, scroll_id: str = None, size: int = 1000, scroll: str = u'5m',
            order_by: str = None, order: int = -1):
        '''
        Executes a scrolling search based on the ES scroll API. Pagination is replaced
        with scrolling; no ordering, no statistics, and no quantities will be provided.

        Scrolling is done by calling this function again and again with the same ``scroll_id``.
        Each time, this function will return the next batch of search results. If the
        ``scroll_id`` is not available anymore, a new ``scroll_id`` is assigned and scrolling
        starts from the beginning again.

        The response will contain a 'scroll' part with attributes 'total', 'scroll_id',
        and 'size'.

        Arguments:
            scroll_id: The scroll id to receive the next batch from. None will create a new
                scroll.
            size: The batch size in number of hits.
            scroll: The time the scroll should be kept alive (i.e. the time between requests
                to this method) in ES time units. Default is 5 minutes.

        TODO support order and order_by
        '''
        es = infrastructure.elastic_client

        if scroll_id is None:
            # initiate scroll
            resp = es.search(  # pylint: disable=E1123
                body=self._search.query(self.q).to_dict(), scroll=scroll, size=size,
                index=config.elastic.index_name)

            scroll_id = resp.get('_scroll_id')
            if scroll_id is None:
                # no results for search query
                return dict(scroll=dict(total=0, size=size), results=[])
        else:
            try:
                resp = es.scroll(scroll_id, scroll=scroll)  # pylint: disable=E1123
            except NotFoundError:
                raise ScrollIdNotFound()

        total = resp['hits']['total']
        results = list(hit['_source'] for hit in resp['hits']['hits'])

        # since we are using the low level api here, we should check errors
        if resp["_shards"]["successful"] < resp["_shards"]["total"]:
            utils.get_logger(__name__).error('es operation was unsuccessful on at least one shard')
            raise ElasticSearchError('es operation was unsuccessful on at least one shard')

        if len(results) == 0:
            es.clear_scroll(body={'scroll_id': [scroll_id]}, ignore=(404, ))  # pylint: disable=E1123
            scroll_id = None

        scroll_info = dict(total=total, size=size, scroll=True)
        if scroll_id is not None:
            scroll_info.update(scroll_id=scroll_id)

        return dict(scroll=scroll_info, results=results)
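
    # A typical scrolling loop (a sketch):
    #
    #     request, scroll_id = SearchRequest().owner('public'), None
    #     while True:
    #         response = request.execute_scrolled(scroll_id=scroll_id)
    #         # ... process response['results'] ...
    #         scroll_id = response['scroll'].get('scroll_id')
    #         if scroll_id is None:
    #             break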

    def execute_aggregated(
            self, after: str = None, per_page: int = 1000, includes: List[str] = None):
        '''
        Uses a composite aggregation on top of the search to go through the result
        set. This allows to go arbitrarily deep without using scroll. But it will
        only return results with ``upload_id``, ``calc_id`` and the given
        quantities. The results will be 'ordered' by ``upload_id``.

        Arguments:
            after: The key that determines the start of the current page. This after
                key is returned with each response. Use None (default) for the first
                request.
            per_page: The size of each page.
            includes: A list of quantity names that should be returned in addition to
                ``upload_id`` and ``calc_id``.
        '''
        upload_id_agg = A('terms', field="upload_id")
        calc_id_agg = A('terms', field="calc_id")

        composite = dict(
            sources=[dict(upload_id=upload_id_agg), dict(calc_id=calc_id_agg)],
            size=per_page)

        if after is not None:
            upload_id, calc_id = after.split(':')
            composite['after'] = dict(upload_id=upload_id, calc_id=calc_id)

        composite_agg = self._search.aggs.bucket('ids', 'composite', **composite)
        if includes is not None:
            composite_agg.metric('examples', A('top_hits', size=1, _source=dict(includes=includes)))

        search = self._search.query(self.q)[0:0]
        response = search.execute()

        ids = response['aggregations']['ids']
        if 'after_key' in ids:
            after_dict = ids['after_key']
            after = '%s:%s' % (after_dict['upload_id'], after_dict['calc_id'])
        else:
            after = None

        id_agg_info = dict(total=response['hits']['total'], after=after, per_page=per_page)

        def transform_result(es_result):
            result = dict(
                upload_id=es_result['key']['upload_id'],
                calc_id=es_result['key']['calc_id'])

            if includes is not None:
                source = es_result['examples']['hits']['hits'][0]['_source']
                for key in source:
                    result[key] = source[key]

            return result

        results = [
            transform_result(item) for item in ids['buckets']]

        return dict(aggregation=id_agg_info, results=results)
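
    # Paging through all (upload_id, calc_id) pairs (a sketch):
    #
    #     after = None
    #     while True:
    #         response = SearchRequest().owner('public').execute_aggregated(after=after)
    #         # ... process response['results'] ...
    #         after = response['aggregation']['after']
    #         if after is None:
    #             break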

    def _response(self, response, with_hits: bool = False) -> Dict[str, Any]:
        '''
        Prepares a response object covering the total number of results, hits, statistics,
        and quantities. Other aspects like pagination and scrolling have to be added
        elsewhere.
        '''
        result: Dict[str, Any] = dict()
        aggs = response.aggregations.to_dict()

        # total
        total = response.hits.total if hasattr(response, 'hits') else 0
        result.update(total=total)

        # hits
        if len(response.hits) > 0 or with_hits:
            result.update(results=[hit.to_dict() for hit in response.hits])

        # statistics
        def get_metrics(bucket, code_runs):
            result = {}
            # TODO optimize ... go through the buckets not the metrics
            for metric in metrics:
                agg_name = 'metric:%s' % metric
                if agg_name in bucket:
                    result[metric] = bucket[agg_name]['value']
            result.update(code_runs=code_runs)
            return result

        statistics_results = {
            quantity_name[11:]: {
                str(bucket['key']): get_metrics(bucket, bucket['doc_count'])
                for bucket in quantity['buckets']
            }
            for quantity_name, quantity in aggs.items()
            if quantity_name.startswith('statistics:')
        }

        # global statistics
        global_statistics_results = {
            agg_name[18:]: agg.get('value')
            for agg_name, agg in aggs.items()
            if agg_name.startswith('global_statistics:')
        }
        if len(global_statistics_results) > 0:
            result.update(global_statistics=global_statistics_results)

        # totals
        totals_result = get_metrics(aggs, total)
        statistics_results['total'] = dict(all=totals_result)

        if len(statistics_results) > 0:
            result.update(statistics=statistics_results)

        # quantities
        def create_quantity_result(quantity_name, quantity):
            values = {}
            for bucket in quantity['buckets']:
                value = dict(
                    total=bucket['doc_count'])
                if 'examples' in bucket:
                    examples = [hit['_source'] for hit in bucket['examples']['hits']['hits']]
                    value.update(examples=examples)

                values[bucket['key'][quantity_name]] = value

            result = dict(values=values)
            if 'after_key' in quantity:
                after = quantity['after_key']
                if len(after) == 1:
                    result.update(after=after[quantity_name])
                else:
                    for key in after:
                        if key != quantity_name:
                            result.update(after=after[key])
                            break

            return result

        quantity_results = {
            quantity_name[9:]: create_quantity_result(quantity_name[9:], quantity)
            for quantity_name, quantity in aggs.items()
            if quantity_name.startswith('quantity:')
        }

        if len(quantity_results) > 0:
            result.update(quantities=quantity_results)

        return result

    def __str__(self):
        return json.dumps(self._search.to_dict(), indent=2)


def flat(obj, prefix=None):
    '''
    Helper that translates nested result objects into flattened dicts with
    ``domain.quantity`` as keys.
    '''
    if isinstance(obj, dict):
        result = {}
        for key, value in obj.items():
            if isinstance(value, dict):
                value = flat(value)
                for child_key, child_value in value.items():
                    result['%s.%s' % (key, child_key)] = child_value
            else:
                result[key] = value

        return result
    else:
        return obj
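
# For example (a sketch): ``flat({'dft': {'code_name': 'VASP'}})`` returns
# ``{'dft.code_name': 'VASP'}``.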


def _api_to_es_query(query: api_models.Query) -> Q:
    '''
    Creates an ES query based on the API's query model. This needs to be a normalized
    query expression with explicit objects for logical, set, and comparison operators.
    Shorthand notations a la ``quantity:operator`` are not supported here; this
    needs to be resolved via the respective pydantic validator. There is also no
    validation of quantities and types.
    '''
    def quantity_to_es(name: str, value: api_models.Value) -> Q:
        # TODO depends on keyword or not, value might need normalization, etc.
        quantity = search_quantities[name]
        return Q('match', **{quantity.search_field: value})

    def parameter_to_es(name: str, value: api_models.QueryParameterValue) -> Q:

        if isinstance(value, api_models.All):
            return Q('bool', must=[
                quantity_to_es(name, item)
                for item in value.op])

        if isinstance(value, api_models.Any_):
            return Q('bool', should=[
                quantity_to_es(name, item)
                for item in value.op])

        if isinstance(value, api_models.None_):
            return Q('bool', must_not=[
                quantity_to_es(name, item)
                for item in value.op])

        if isinstance(value, api_models.ComparisonOperator):
            quantity = search_quantities[name]
            return Q('range', **{quantity.search_field: {
                type(value).__name__.lower(): value.op}})

        # list of values is treated as an "all" over the items
        if isinstance(value, list):
            return Q('bool', must=[
                quantity_to_es(name, item)
                for item in value])

        return quantity_to_es(name, value)

    def query_to_es(query: api_models.Query) -> Q:
        if isinstance(query, api_models.LogicalOperator):
            if isinstance(query, api_models.And):
                return Q('bool', must=[query_to_es(operand) for operand in query.op])

            if isinstance(query, api_models.Or):
                return Q('bool', should=[query_to_es(operand) for operand in query.op])

            if isinstance(query, api_models.Not):
                return Q('bool', must_not=query_to_es(query.op))

            raise NotImplementedError()

        if not isinstance(query, dict):
            raise NotImplementedError()

        # dictionary is like an "and" of all items in the dict
        if len(query) == 0:
            return Q()

        if len(query) == 1:
            key = next(iter(query))
            return parameter_to_es(key, query[key])

        return Q('bool', must=[
            parameter_to_es(name, value) for name, value in query.items()])

    return query_to_es(query)
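
# A sketch of the translation: a dict acts as an implicit 'and', and a list of
# values as an 'all' over the items (the quantity name is illustrative and the
# actual search field comes from the quantity definition):
#
#     _api_to_es_query({'dft.code_name': ['VASP', 'exciting']})
#     # -> Q('bool', must=[Q('match', **{<search field>: 'VASP'}),
#     #                    Q('match', **{<search field>: 'exciting'})])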


def _api_to_es_statistic(es_search: Search, name: str, statistic: Statistic) -> None:
    '''
    Creates an ES aggregation based on the API's statistic model.
    '''

    quantity = search_quantities[statistic.quantity.value]
    if quantity.statistic_values is not None:
        statistic.size = len(quantity.statistic_values)

    terms_kwargs = {}
    if statistic.value_filter is not None:
        terms_kwargs['include'] = '.*%s.*' % statistic.value_filter

    order_type = '_count' if statistic.order.type_ == AggregationOrderType.entries else '_key'
    statistic_agg = es_search.aggs.bucket('statistic:%s' % name, A(
        'terms',
        field=quantity.search_field,
        size=statistic.size,
        order={order_type: statistic.order.direction.value},
        **terms_kwargs))

    for metric in statistic.metrics:
        metric_quantity = metrics[metric.value]
        statistic_agg.metric('metric:%s' % metric_quantity.metric_name, A(
            metric_quantity.metric,
            field=metric_quantity.search_field))


def _es_to_api_statistics(es_response, name: str, statistic: Statistic) -> StatisticResponse:
    '''
    Creates a StatisticResponse from an Elasticsearch response on a request executed with
    the given statistics.
    '''
    quantity = search_quantities[statistic.quantity.value]

    es_statistic = es_response.aggs['statistic:' + name]
    statistic_data = {}
    for bucket in es_statistic.buckets:
        value_data = dict(entries=bucket.doc_count)
        for metric in statistic.metrics:
            value_data[metric.value] = bucket['metric:' + metric.value].value
        statistic_data[bucket.key] = value_data

    if quantity.statistic_values is not None:
        for value in quantity.statistic_values:
            if value not in statistic_data:
                statistic_data[value] = dict(entries=0, **{
                    metric.value: 0 for metric in statistic.metrics})

    return StatisticResponse(data=statistic_data, **statistic.dict(by_alias=True))


def _api_to_es_aggregation(es_search: Search, name: str, agg: Aggregation) -> None:
    '''
    Creates an ES aggregation based on the API's aggregation model.
    '''
    quantity = search_quantities[agg.quantity.value]
    terms = A('terms', field=quantity.search_field, order=agg.pagination.order.value)

    # We are using Elasticsearch's 'composite aggregations' here. We do not really
    # compose aggregations, but only those pseudo composites allow us to use the
    # 'after' feature that allows to scan through all aggregation values.
    order_by = agg.pagination.order_by
    if order_by is None:
        composite = dict(sources={name: terms}, size=agg.pagination.size)
    else:
        order_quantity = search_quantities[order_by.value]
        sort_terms = A('terms', field=order_quantity.search_field, order=agg.pagination.order.value)
        composite = dict(sources=[{order_by.value: sort_terms}, {quantity.name: terms}], size=agg.pagination.size)

    if agg.pagination.after is not None:
        if order_by is None:
            composite['after'] = {name: agg.pagination.after}
        else:
            order_value, quantity_value = agg.pagination.after.split(':')
            composite['after'] = {quantity.name: quantity_value, order_quantity.name: order_value}

    composite_agg = es_search.aggs.bucket('agg:%s' % name, 'composite', **composite)

    if agg.entries is not None and agg.entries.size > 0:
        kwargs: Dict[str, Any] = {}
        if agg.entries.required is not None:
            if agg.entries.required.include is not None:
                kwargs.update(_source=dict(includes=agg.entries.required.include))
            else:
                kwargs.update(_source=dict(excludes=agg.entries.required.exclude))

        composite_agg.metric('entries', A('top_hits', size=agg.entries.size, **kwargs))

    # additional cardinality to get total
    es_search.aggs.metric('agg:%s:total' % name, 'cardinality', field=quantity.search_field)


def _es_to_api_aggregation(es_response, name: str, agg: Aggregation) -> AggregationResponse:
    '''
    Creates an AggregationResponse from an Elasticsearch response on a request executed with
    the given aggregation.
    '''
    order_by = agg.pagination.order_by
    quantity = search_quantities[agg.quantity.value]
    es_agg = es_response.aggs['agg:' + name]

    def get_entries(agg):
        if 'entries' in agg:
            return [item['_source'] for item in agg.entries.hits.hits]
        else:
            return None
