#
# Copyright The NOMAD Authors.
#
# This file is part of NOMAD. See https://nomad-lab.eu for further info.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

'''
This module represents calculations in Elasticsearch.
'''

from typing import cast, Iterable, Dict, List, Any
from elasticsearch_dsl import Search, Q, A, analyzer, tokenizer
import elasticsearch.helpers
from elasticsearch.exceptions import NotFoundError, RequestError, TransportError
from datetime import datetime
import json

from nomad.datamodel.material import Material
from nomad import config, datamodel, infrastructure, utils
from nomad.metainfo.search_extension import (  # pylint: disable=unused-import
    search_quantities, metrics, order_default_quantities, groups)
from nomad.app.v1 import models as api_models
from nomad.app.v1.models import (
    EntryPagination, PaginationResponse, Query, MetadataRequired, SearchResponse, Aggregation,
    Statistic, StatisticResponse, AggregationOrderType, AggregationResponse, AggregationDataItem)


_entry_metadata_defaults = {
    quantity.name: quantity.default
    for quantity in datamodel.EntryMetadata.m_def.quantities  # pylint: disable=not-an-iterable
    if quantity.default not in [None, [], False, 0]
}


_all_author_quantities = [
    quantity.name
    for quantity in datamodel.EntryMetadata.m_def.all_quantities.values()
    if quantity.type in [datamodel.user_reference, datamodel.author_reference]]


def _es_to_entry_dict(hit, required: MetadataRequired) -> Dict[str, Any]:
    '''
    Elasticsearch entry metadata does not contain default values when a quantity is not
    set. This adds default values to entry metadata in dict form obtained from
    Elasticsearch.
    '''
    entry_dict = hit.to_dict()
    for key, value in _entry_metadata_defaults.items():
        if key not in entry_dict:
            if required is not None:
                if required.exclude and key in required.exclude:
                    continue
                if required.include and key not in required.include:
                    continue

            entry_dict[key] = value

    for author_quantity in _all_author_quantities:
        authors = entry_dict.get(author_quantity)
        if authors is None:
            continue
        if isinstance(authors, dict):
            authors = [authors]
        for author in authors:
            if 'email' in author:
                del author['email']

    return entry_dict


path_analyzer = analyzer(
    'path_analyzer',
    tokenizer=tokenizer('path_tokenizer', 'pattern', pattern='/'))


class AlreadyExists(Exception): pass


class ElasticSearchError(Exception): pass


class AuthenticationRequiredError(Exception): pass


class ScrollIdNotFound(Exception): pass


class InvalidQuery(Exception): pass


entry_document = datamodel.EntryMetadata.m_def.a_elastic.document
material_document = Material.m_def.a_elastic.document

for domain in datamodel.domains:
    order_default_quantities.setdefault(domain, order_default_quantities.get('__all__'))


def delete_upload(upload_id):
    ''' Delete all entries with the given ``upload_id`` from the index. '''
    index = entry_document._default_index()
    Search(index=index).query('match', upload_id=upload_id).delete()


def delete_entry(calc_id):
    ''' Delete the entry with the given ``calc_id`` from the index. '''
    index = entry_document._default_index()
    Search(index=index).query('match', calc_id=calc_id).delete()


def publish(calcs: Iterable[datamodel.EntryMetadata]) -> None:
    ''' Update all given calcs with their metadata and set ``publish = True``. '''
    def elastic_updates():
        for calc in calcs:
            entry = calc.a_elastic.create_index_entry()
            entry.published = True
            entry = entry.to_dict(include_meta=True)
            source = entry.pop('_source')
            entry['doc'] = source
            entry['_op_type'] = 'update'
            yield entry

    elasticsearch.helpers.bulk(infrastructure.elastic_client, elastic_updates())
    refresh()

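# Each action yielded by elastic_updates() above is, roughly, a bulk partial
# update action (a hedged sketch; the meta fields come from to_dict(include_meta=True)):
#
#     {'_op_type': 'update', '_index': ..., '_id': ..., 'doc': {...entry fields...}}
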
def index_all(calcs: Iterable[datamodel.EntryMetadata], do_refresh=True) -> int:
    '''
    Adds all given calcs with their metadata to the index.

    Returns:
        Number of failed entries.
    '''
    def elastic_updates():
        for calc in calcs:
            try:
                entry = calc.a_elastic.create_index_entry()
                entry = entry.to_dict(include_meta=True)
                entry['_op_type'] = 'index'

                yield entry
            except Exception as e:
                utils.get_logger(__name__).error(
                    'could not create index doc', exc_info=e,
                    upload_id=calc.upload_id, calc_id=calc.calc_id)

    _, failed = elasticsearch.helpers.bulk(infrastructure.elastic_client, elastic_updates(), stats_only=True)

    if do_refresh:
        refresh()

    return failed

def refresh():
    infrastructure.elastic_client.indices.refresh(config.elastic.index_name)


def _owner_es_query(owner: str, user_id: str = None):
    if owner == 'all':
        q = Q('term', published=True)
        if user_id is not None:
            q = q | Q('term', owners__user_id=user_id)
    elif owner == 'public':
        q = Q('term', published=True) & Q('term', with_embargo=False)
    elif owner == 'visible':
        q = Q('term', published=True) & Q('term', with_embargo=False)
        if user_id is not None:
            q = q | Q('term', owners__user_id=user_id)
    elif owner == 'shared':
        if user_id is None:
            raise AuthenticationRequiredError('Authentication required for owner value shared.')

        q = Q('term', owners__user_id=user_id)
    elif owner == 'user':
        if user_id is None:
            raise AuthenticationRequiredError('Authentication required for owner value user.')

        q = Q('term', uploader__user_id=user_id)
    elif owner == 'staging':
        if user_id is None:
            raise AuthenticationRequiredError('Authentication required for owner value staging.')
        q = Q('term', published=False) & Q('term', owners__user_id=user_id)
    elif owner == 'admin':
        if user_id is None or not datamodel.User.get(user_id=user_id).is_admin:
            raise AuthenticationRequiredError('This can only be used by the admin user.')
        q = None
    else:
        raise KeyError('Unsupported owner value')

    if q is not None:
        return q
    return Q()
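
# For illustration, a hedged sketch of using these owner queries directly
# (the user id is hypothetical): owner='visible' yields roughly
# (published AND NOT with_embargo) OR owners.user_id == user_id.
#
#     q = _owner_es_query('visible', user_id='some-user-id')
#     s = Search(index=config.elastic.index_name).query(q)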


class SearchRequest:
    '''
    Represents a search request and allows that request to be executed.
    It composes the following features: a query;
    statistics (metrics and aggregations); quantity values; scrolling and pagination
    for entries; scrolling for quantity values.

    The query part filters NOMAD data before the other features come into effect. There
    are specialized methods for configuring the :func:`owner` and :func:`time_range` queries.
    Quantities can be searched for by setting them as attributes.

    The aggregations for statistics can be requested for pre-configured quantities. These
    bucket aggregations come with a metric calculated for each possible
    quantity value.

    The other possible form of aggregation allows one to get quantity values as results
    (e.g. get all datasets, get all users, etc.). Each value can be accompanied by metrics
    (over all entries with that value) and an example value.

    Of course, searches can return a set of search results. Search objects can be
    configured with pagination or scrolling for these results. Pagination is the default
    and also allows ordering of results. Scrolling can be used if all entries need to be
    'scrolled through'. This might be necessary, since Elasticsearch has limits on
    possible pages (e.g. 'from' must be smaller than 10000). On the downside, there is no
    ordering on scrolling.

    There is also scrolling for quantities to go through all quantity values. There is no
    paging for aggregations.
    '''
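
    # A minimal usage sketch (quantity names and values are illustrative; assumes
    # a configured Elasticsearch index and the standard NOMAD domains):
    #
    #     request = SearchRequest() \
    #         .owner('public') \
    #         .search_parameter('dft.code_name', 'VASP') \
    #         .statistic('dft.basis_set', size=10)
    #     result = request.execute_paginated(page=1, per_page=10)
    #     entries = result['results']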
    def __init__(self, domain: str = config.meta.default_domain, query=None):
        self._domain = domain
        self._query = query
        self._search = Search(index=config.elastic.index_name)
        self._required = None

    def domain(self, domain: str = None):
        '''
        Applies the domain of this request to the query, optionally updating the
        domain of this request first.
        '''
        if domain is not None:
            self._domain = domain

        self.q = self.q & Q('term', domain=self._domain)
        return self

    def owner(self, owner_type: str = 'all', user_id: str = None):
        '''
        Uses the query part of the search to restrict the results based on the owner.
        The possible types are: ``all`` for all calculations; ``public`` for
        calculations visible to everyone, excluding embargoed entries and entries only visible
        to the given user; ``visible`` for all data that is visible to the user, excluding
        embargoed entries from other users; ``shared`` for all calculations where the given
        user is among the owners; ``user`` for all calculations of the given
        user; ``staging`` for all calculations in staging of the given user; ``admin``
        for all data (admin users only).

        Arguments:
            owner_type: The type of the owner query, see above.
            user_id: The 'owner' given as the user's unique id.

        Raises:
            KeyError: If the given owner_type is not supported.
            ValueError: If the owner_type requires a user but none is given, or the
                given user is not allowed to use the given owner_type.
        '''
        self.q &= _owner_es_query(owner=owner_type, user_id=user_id)
        return self

    def search_parameters(self, **kwargs):
        '''
        Configures the existing query with additional search parameters. Kwargs are
        interpreted as key-value pairs. Keys have to correspond to valid entry quantities
        in the domain's (e.g. DFT calculations) datamodel. Alternatively, search parameters
        can be set via attributes.
        '''
        for name, value in kwargs.items():
            self.search_parameter(name, value)

        return self
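
    # E.g., a hedged sketch (quantity names are illustrative):
    #
    #     request.search_parameters(**{'dft.code_name': 'VASP', 'atoms': ['Si', 'O']})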

    def _search_parameter_to_es(self, name, value):
        quantity = search_quantities[name]

        if quantity.many and not isinstance(value, list):
            value = [value]

        if quantity.many_or and isinstance(value, List):
            return Q('terms', **{quantity.search_field: value})

        if quantity.derived:
            if quantity.many and not isinstance(value, list):
                value = [value]
            value = quantity.derived(value)

        if isinstance(value, list):
            values = value
        else:
            values = [value]

        return Q('bool', must=[
            Q('match', **{quantity.search_field: item})
            for item in values])

    def search_parameter(self, name, value):
        self.q &= self._search_parameter_to_es(name, value)
        return self

    def query(self, query):
        ''' Adds the given query as an 'and' (i.e. 'must') clause to the request. '''
        self.q &= query

        return self
    def query_expression(self, expression) -> 'SearchRequest':

        bool_operators = {'$and': 'must', '$or': 'should', '$not': 'must_not'}
        comp_operators = {'$%s' % op: op for op in ['gt', 'gte', 'lt', 'lte']}

        def _to_es(key, values):
            if key in bool_operators:
                if isinstance(values, dict):
                    values = [values]
                assert isinstance(values, list), 'bool operator requires a list of dicts or dict'
                child_es_queries = [
                    _to_es(child_key, child_value)
                    for child_query in values
                    for child_key, child_value in child_query.items()]
                return Q('bool', **{bool_operators[key]: child_es_queries})

            if key in comp_operators:
                assert isinstance(values, dict), 'comparison operator requires a dict'
                assert len(values) == 1, 'comparison operator requires exactly one quantity'
                quantity_name, value = next(iter(values.items()))
                quantity = search_quantities.get(quantity_name)
                assert quantity is not None, 'quantity %s does not exist' % quantity_name
                return Q('range', **{quantity.search_field: {comp_operators[key]: value}})

            try:
                return self._search_parameter_to_es(key, values)
            except KeyError:
                assert False, 'quantity %s does not exist' % key

        if len(expression) == 0:
            self.q &= Q()
        else:
            self.q &= Q('bool', must=[_to_es(key, value) for key, value in expression.items()])

        return self
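
    # A hedged sketch of the accepted expression format (quantity names are
    # illustrative):
    #
    #     request.query_expression({
    #         '$and': [
    #             {'dft.code_name': 'VASP'},
    #             {'$not': {'dft.basis_set': 'gaussians'}},
    #             {'$lt': {'dft.n_atoms': 100}}
    #         ]
    #     })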
    def time_range(self, start: datetime, end: datetime):
        ''' Adds a time range to the query. '''
        if start is None and end is None:
            return self

        if start is None:
            start = datetime.fromtimestamp(0)
        if end is None:
            end = datetime.utcnow()

        self.q &= Q('range', upload_time=dict(gte=start, lte=end))

        return self

    @property
    def q(self):
        ''' The underlying elasticsearch_dsl query object. '''
        if self._query is None:
            return Q('match_all')
        else:
            return self._query

    @q.setter
    def q(self, q):
        self._query = q

    def totals(self, metrics_to_use: List[str] = []):
        '''
        Configures the request to return overall totals for the given metrics.

        The statistics are returned with the other quantity statistics under the pseudo
        quantity name 'total'. 'total' contains the pseudo value 'all'. It is used to
        store the metrics aggregated over all entries in the search results.
        '''
        self._add_metrics(self._search.aggs, metrics_to_use)
        return self

    def statistics(self, statistics: List[str], metrics_to_use: List[str] = []):
        '''
        Configures the given statistics quantities with the domain's default
        statistic sizes and orders.
        '''
        for statistic in statistics:
            search_quantity = search_quantities[statistic]
            statistic_order = search_quantity.statistic_order
            self.statistic(
                search_quantity.qualified_name,
                search_quantity.statistic_size,
                metrics_to_use=metrics_to_use,
                order={statistic_order: 'asc' if statistic_order == '_key' else 'desc'})

        return self

    def statistic(
            self, quantity_name: str, size: int, metrics_to_use: List[str] = [],
            order: Dict[str, str] = dict(_key='asc'), include: str = None):
        '''
        This can be used to display statistics over the searched entries and allows to
        implement faceted search on the top values for each quantity.

        The metrics contain overall and per quantity value sums of code runs (calcs),
        unique code runs, datasets, and additional domain specific metrics
        (e.g. total energies, and unique geometries for DFT calculations). The quantities
        that can be aggregated to metrics are defined in module ``datamodel``. Aggregations
        and respective metrics are calculated for aggregations given in ``aggregations``
        and metrics in ``aggregation_metrics``. As a pseudo aggregation, ``total_metrics``
        are calculated over all search results. The ``aggregations`` gives tuples of
        quantities and default aggregation sizes.

        The search results will contain a dictionary ``statistics``. This has a key
        for each configured quantity. Each quantity key will hold a dict
        with a key for each quantity value. Each quantity value key will hold a dict
        with a key for each metric. The values will be the actual aggregated metric values.

        Arguments:
            quantity_name: The quantity to aggregate statistics for. Only works on *keyword* fields.
            metrics_to_use: The metrics calculated over the aggregations. Can be
                ``unique_code_runs``, ``datasets``, or other domain specific metrics.
                The basic doc_count metric ``code_runs`` is always given.
            order: The order dictionary is passed to the Elasticsearch aggregation.
            include:
                Uses a regular expression in ES to only return values that include
                the given substring.
        '''
        quantity = search_quantities[quantity_name]

        terms_kwargs = {}
        if include is not None:
            terms_kwargs['include'] = '.*%s.*' % include
        terms = A('terms', field=quantity.search_field, size=size, order=order, **terms_kwargs)

        buckets = self._search.aggs.bucket('statistics:%s' % quantity_name, terms)
        self._add_metrics(buckets, metrics_to_use)

        return self
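
    # E.g., a hedged faceted-search sketch (quantity and metric names are
    # illustrative):
    #
    #     request.statistic('dft.code_name', size=20, metrics_to_use=['datasets'])
    #     result = request.execute()
    #     per_code = result['statistics']['dft.code_name']
    #     # per_code maps each code name to {'code_runs': ..., 'datasets': ...}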
    def _add_metrics(self, parent=None, metrics_to_use: List[str] = []):
        if parent is None:
            parent = self._search.aggs

        for metric in metrics_to_use:
            metric_quantity = metrics[metric]
            field = metric_quantity.search_field
            parent.metric(
                'metric:%s' % metric_quantity.metric_name,
                A(metric_quantity.metric, field=field))

    def date_histogram(self, metrics_to_use: List[str] = [], interval: str = '1M'):
        '''
        Adds a date histogram on the given metrics to the statistics part.
        '''
        histogram = A('date_histogram', field='upload_time', interval=interval, format='yyyy-MM-dd')
        self._add_metrics(self._search.aggs.bucket('statistics:date_histogram', histogram), metrics_to_use)

        return self

    def quantities(self, **kwargs):
        '''
        Shorthand for adding multiple quantities. See :func:`quantity`. Keyword argument
        keys are quantity names, values are tuples of size and after value.
        '''
        for name, spec in kwargs.items():
            size, after = spec
            self.quantity(name, after=after, size=size)

        return self

    def quantity(
            self, name, size=100, after=None, examples=0, examples_source=None,
            order_by: str = None, order: str = 'desc'):
        '''
        Adds a request for values of the given quantity.

        It allows to scroll through all values via Elasticsearch's
        composite aggregations. The response will contain the quantity values and
        an example entry for each value.

        This can be used to implement continuous scrolling through authors, datasets,
        or uploads within the searched entries.

        If one or more quantities are specified,
        the search results will contain a dictionary ``quantities``. The keys are quantity
        names, the values dictionaries with 'after' and 'values' keys.
        The 'values' key holds a dict with all the values as keys and their entry count
        as values (i.e. number of entries with that value).

        Arguments:
            name: The quantity name. Must be in :data:`quantities`.
            after: The 'after' value allows to scroll over various requests, by providing
                the 'after' value of the last search. The 'after' value is part of the
                response. Use ``None`` in the first request.
            size:
                The size gives the maximum amount of values in the next scroll window.
                If the size is None, a maximum of 100 quantity values will be requested.
            examples:
                Number of example results to return for each value.
            order_by:
                A sortable quantity that should be used to order. By default, the max of each
                value bucket is used.
            order:
                "desc" or "asc"
        '''
        if size is None:
            size = 100

        quantity = search_quantities[name]
        terms = A('terms', field=quantity.search_field)

        # We are using Elasticsearch's 'composite aggregations' here. We do not really
        # compose aggregations, but only those pseudo composites allow us to use the
        # 'after' feature that allows to scan through all aggregation values.
        if order_by is None:
            composite = dict(sources={name: terms}, size=size)
        else:
            sort_terms = A('terms', field=order_by, order=order)
            composite = dict(sources=[{order_by: sort_terms}, {name: terms}], size=size)
        if after is not None:
            if order_by is None:
                composite['after'] = {name: after}
            else:
                composite['after'] = {order_by: after, name: ''}

        composite_agg = self._search.aggs.bucket('quantity:%s' % name, 'composite', **composite)

        if examples > 0:
            kwargs: Dict[str, Any] = {}
            if examples_source is not None:
                kwargs.update(_source=dict(includes=examples_source))

            composite_agg.metric('examples', A('top_hits', size=examples, **kwargs))

        return self
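
    # A hedged sketch of scrolling through all values of a quantity (quantity
    # name illustrative; a fresh request is built per page):
    #
    #     after = None
    #     while True:
    #         result = SearchRequest().owner('public').quantity(
    #             'dft.code_name', size=100, after=after).execute()
    #         data = result['quantities']['dft.code_name']
    #         for value, info in data['values'].items():
    #             ...  # info['total'] is the number of entries with this value
    #         after = data.get('after')
    #         if after is None:
    #             break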
    def global_statistics(self):
        '''
        Adds general statistics to the request. The results will have a key called
        global_statistics.
        '''
        self.owner('public')
        self._search.aggs.metric(
            'global_statistics:n_entries', A('value_count', field='calc_id'))
        self._search.aggs.metric(
            'global_statistics:n_uploads', A('cardinality', field='upload_id'))
        self._search.aggs.metric(
            'global_statistics:n_calculations', A('sum', field='dft.n_calculations'))
        self._search.aggs.metric(
            'global_statistics:n_quantities', A('sum', field='dft.n_quantities'))
        self._search.aggs.metric(
            'global_statistics:n_materials', A('cardinality', field='encyclopedia.material.material_id'))

        return self

    def exclude(self, *args):
        ''' Excludes certain elastic fields from the search results. '''
        self._search = self._search.source(excludes=args)
        self._required = MetadataRequired(exclude=args)
        return self

    def include(self, *args):
        ''' Includes only the given fields in the search results. '''
        self._search = self._search.source(includes=args)
        return self

    def execute(self):
        '''
        Executes without returning actual results. Only makes sense if the request
        was configured for statistics or quantity values.
        '''
        search = self._search.query(self.q)[0:0]
        response = search.execute()
        return self._response(response)

    def execute_scan(self, order_by: str = None, order: int = -1, **kwargs):
        '''
        This executes the search as a scan. The result will be a generator over the found
        entries. Everything but the query part of this object will be ignored.
        '''
        search = self._search.query(self.q)

        if order_by is not None:
            order_by_quantity = search_quantities[order_by]

            if order == 1:
                search = search.sort(order_by_quantity.search_field)
            else:
                search = search.sort('-%s' % order_by_quantity.search_field)  # pylint: disable=no-member

            search = search.params(preserve_order=True)

        for hit in search.params(**kwargs).scan():
            yield _es_to_entry_dict(hit, self._required)

    def execute_paginated(
            self,
            page: int = 1, per_page=10, page_offset: int = None,
            order_by: str = None, order: int = -1):
        '''
        Executes the search and returns paginated results. Those are sorted.

        Arguments:
            page: The requested page, starts with 1.
            per_page: The number of entries per page.
            page_offset: Instead of a page number, use this absolute offset.
            order_by: The quantity to order by.
            order: -1 or 1 for descending or ascending order.
        '''
        if order_by is None:
            order_by_quantity = order_default_quantities[self._domain]
        else:
            order_by_quantity = search_quantities[order_by]

        search = self._search.query(self.q)

        if order == 1:
            search = search.sort(order_by_quantity.search_field)
        else:
            search = search.sort('-%s' % order_by_quantity.search_field)  # pylint: disable=no-member

        if page_offset is not None:
            search = search[page_offset: page_offset + per_page]  # pylint: disable=unsubscriptable-object
        else:
            search = search[(page - 1) * per_page: page * per_page]  # pylint: disable=unsubscriptable-object

        es_result = search.execute()
        result = self._response(es_result, with_hits=True)

        if page_offset is not None:
            result.update(pagination=dict(
                total=result['total'],
                page_offset=page_offset, per_page=per_page))
        else:
            result.update(
                pagination=dict(total=result['total'], page=page, per_page=per_page))

        return result
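
    # E.g., a hedged sketch of fetching the second page of ten entries, ordered
    # by upload time, newest first (quantity name illustrative):
    #
    #     result = request.execute_paginated(
    #         page=2, per_page=10, order_by='upload_time', order=-1)
    #     entries, pagination = result['results'], result['pagination']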

    def execute_scrolled(
            self, scroll_id: str = None, size: int = 1000, scroll: str = u'5m',
            order_by: str = None, order: int = -1):
        '''
        Executes a scrolling search based on the ES scroll API. Pagination is replaced with
        scrolling, no ordering is available, and no statistics or quantities will be provided.

        Scrolling is done by calling this function again and again with the same ``scroll_id``.
        Each time, this function will return the next batch of search results. If the
        ``scroll_id`` is not available anymore, a new ``scroll_id`` is assigned and scrolling
        starts from the beginning again.

        The response will contain a 'scroll' part with attributes 'total', 'scroll_id',
        and 'size'.

        Arguments:
            scroll_id: The scroll id to receive the next batch from. None will create a new
                scroll.
            size: The batch size in number of hits.
            scroll: The time the scroll should be kept alive (i.e. the time between requests
                to this method) in ES time units. Default is 5 minutes.

        TODO support order and order_by
        '''
        es = infrastructure.elastic_client

        if scroll_id is None:
            # initiate scroll
            resp = es.search(  # pylint: disable=E1123
                body=self._search.query(self.q).to_dict(), scroll=scroll, size=size,
                index=config.elastic.index_name)

            scroll_id = resp.get('_scroll_id')
            if scroll_id is None:
                # no results for search query
                return dict(scroll=dict(total=0, size=size), results=[])
        else:
            try:
                resp = es.scroll(scroll_id, scroll=scroll)  # pylint: disable=E1123
            except NotFoundError:
                raise ScrollIdNotFound()

        total = resp['hits']['total']
        results = list(hit['_source'] for hit in resp['hits']['hits'])

        # since we are using the low level api here, we should check errors
        if resp['_shards']['successful'] < resp['_shards']['total']:
            utils.get_logger(__name__).error('es operation was unsuccessful on at least one shard')
            raise ElasticSearchError('es operation was unsuccessful on at least one shard')

        if len(results) == 0:
            es.clear_scroll(body={'scroll_id': [scroll_id]}, ignore=(404, ))  # pylint: disable=E1123
            scroll_id = None

        scroll_info = dict(total=total, size=size, scroll=True)
        if scroll_id is not None:
            scroll_info.update(scroll_id=scroll_id)

        return dict(scroll=scroll_info, results=results)
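
    # A hedged sketch of draining a scroll:
    #
    #     request = SearchRequest().owner('public')
    #     scroll_id = None
    #     while True:
    #         response = request.execute_scrolled(scroll_id=scroll_id)
    #         for entry in response['results']:
    #             ...
    #         scroll_id = response['scroll'].get('scroll_id')
    #         if scroll_id is None:
    #             break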

    def execute_aggregated(
            self, after: str = None, per_page: int = 1000, includes: List[str] = None):
        '''
        Uses a composite aggregation on top of the search to go through the result
        set. This allows to go arbitrarily deep without using scroll. But it will
        only return results with ``upload_id``, ``calc_id`` and the given
        quantities. The results will be 'ordered' by ``upload_id``.

        Arguments:
            after: The key that determines the start of the current page. This after
                key is returned with each response. Use None (default) for the first
                request.
            per_page: The size of each page.
            includes: A list of quantity names that should be returned in addition to
                ``upload_id`` and ``calc_id``.
        '''
        upload_id_agg = A('terms', field="upload_id")
        calc_id_agg = A('terms', field="calc_id")

        composite = dict(
            sources=[dict(upload_id=upload_id_agg), dict(calc_id=calc_id_agg)],
            size=per_page)

        if after is not None:
            upload_id, calc_id = after.split(':')
            composite['after'] = dict(upload_id=upload_id, calc_id=calc_id)

        composite_agg = self._search.aggs.bucket('ids', 'composite', **composite)
        if includes is not None:
            composite_agg.metric('examples', A('top_hits', size=1, _source=dict(includes=includes)))

        search = self._search.query(self.q)[0:0]
        response = search.execute()

        ids = response['aggregations']['ids']
        if 'after_key' in ids:
            after_dict = ids['after_key']
            after = '%s:%s' % (after_dict['upload_id'], after_dict['calc_id'])
        else:
            after = None

        id_agg_info = dict(total=response['hits']['total'], after=after, per_page=per_page)

        def transform_result(es_result):
            result = dict(
                upload_id=es_result['key']['upload_id'],
                calc_id=es_result['key']['calc_id'])

            if includes is not None:
                source = es_result['examples']['hits']['hits'][0]['_source']
                for key in source:
                    result[key] = source[key]

            return result

        results = [
            transform_result(item) for item in ids['buckets']]

        return dict(aggregation=id_agg_info, results=results)
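
    # A hedged paging sketch using the returned 'after' key (a fresh request is
    # built per page):
    #
    #     after = None
    #     while True:
    #         response = SearchRequest().owner('public').execute_aggregated(
    #             after=after, per_page=1000)
    #         for item in response['results']:
    #             ...  # each item has 'upload_id' and 'calc_id'
    #         after = response['aggregation']['after']
    #         if after is None:
    #             break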

    def _response(self, response, with_hits: bool = False) -> Dict[str, Any]:
        '''
        Prepares a response object covering the total number of results, hits, statistics,
        and quantities. Other aspects like pagination and scrolling have to be added
        elsewhere.
        '''
        result: Dict[str, Any] = dict()
        aggs = response.aggregations.to_dict()

        # total
        total = response.hits.total if hasattr(response, 'hits') else 0
        result.update(total=total)

        # hits
        if len(response.hits) > 0 or with_hits:
            result.update(results=[_es_to_entry_dict(hit, self._required) for hit in response.hits])

        # statistics
        def get_metrics(bucket, code_runs):
            result = {}
            # TODO optimize ... go through the buckets not the metrics
            for metric in metrics:
                agg_name = 'metric:%s' % metric
                if agg_name in bucket:
                    result[metric] = bucket[agg_name]['value']
            result.update(code_runs=code_runs)
            return result

        statistics_results = {
            quantity_name[11:]: {
                str(bucket['key']): get_metrics(bucket, bucket['doc_count'])
                for bucket in quantity['buckets']
            }
            for quantity_name, quantity in aggs.items()
            if quantity_name.startswith('statistics:')
        }

        # global statistics
        global_statistics_results = {
            agg_name[18:]: agg.get('value')
            for agg_name, agg in aggs.items()
            if agg_name.startswith('global_statistics:')
        }
        if len(global_statistics_results) > 0:
            result.update(global_statistics=global_statistics_results)

        # totals
        totals_result = get_metrics(aggs, total)
        statistics_results['total'] = dict(all=totals_result)

        if len(statistics_results) > 0:
            result.update(statistics=statistics_results)

        # quantities
        def create_quantity_result(quantity_name, quantity):
            values = {}
            for bucket in quantity['buckets']:
                value = dict(
                    total=bucket['doc_count'])
                if 'examples' in bucket:
                    examples = [hit['_source'] for hit in bucket['examples']['hits']['hits']]
                    value.update(examples=examples)

                values[bucket['key'][quantity_name]] = value

            result = dict(values=values)
            if 'after_key' in quantity:
                after = quantity['after_key']
                if len(after) == 1:
                    result.update(after=after[quantity_name])
                else:
                    for key in after:
                        if key != quantity_name:
                            result.update(after=after[key])
                            break

            return result

        quantity_results = {
            quantity_name[9:]: create_quantity_result(quantity_name[9:], quantity)
            for quantity_name, quantity in aggs.items()
            if quantity_name.startswith('quantity:')
        }

        if len(quantity_results) > 0:
            result.update(quantities=quantity_results)

        return result

    def __str__(self):
        return json.dumps(self._search.to_dict(), indent=2)


def flat(obj, prefix=None):
    '''
    Helper that translates nested result objects into flattened dicts with
    ``domain.quantity`` as keys.
    '''
    if isinstance(obj, dict):
        result = {}
        for key, value in obj.items():
            if isinstance(value, dict):
                value = flat(value)
                for child_key, child_value in value.items():
                    result['%s.%s' % (key, child_key)] = child_value
            else:
                result[key] = value

        return result
    else:
        return obj
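
# For example (values illustrative):
#
#     flat({'dft': {'code_name': 'VASP'}}) == {'dft.code_name': 'VASP'}
#
# Non-dict values are returned unchanged.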


def _api_to_es_query(query: api_models.Query) -> Q:
    '''
    Creates an ES query based on the API's query model. This needs to be a normalized
    query expression with explicit objects for logical, set, and comparison operators.
    Shorthand notations à la ``quantity:operator`` are not supported here; this
    needs to be resolved via the respective pydantic validator. There is also no
    validation of quantities and types.
    '''
    def quantity_to_es(name: str, value: api_models.Value) -> Q:
        # TODO depends on keyword or not, value might need normalization, etc.
        quantity = search_quantities[name]
        return Q('match', **{quantity.search_field: value})

    def parameter_to_es(name: str, value: api_models.QueryParameterValue) -> Q:

        if isinstance(value, api_models.All):
            return Q('bool', must=[
                quantity_to_es(name, item)
                for item in value.op])

        if isinstance(value, api_models.Any_):
            return Q('bool', should=[
                quantity_to_es(name, item)
                for item in value.op])

        if isinstance(value, api_models.None_):
            return Q('bool', must_not=[
                quantity_to_es(name, item)
                for item in value.op])

        if isinstance(value, api_models.ComparisonOperator):
            quantity = search_quantities[name]
            return Q('range', **{quantity.search_field: {
                type(value).__name__.lower(): value.op}})

        # list of values is treated as an "all" over the items
        if isinstance(value, list):
            return Q('bool', must=[
                quantity_to_es(name, item)
                for item in value])

        return quantity_to_es(name, value)

    def query_to_es(query: api_models.Query) -> Q:
        if isinstance(query, api_models.LogicalOperator):
            if isinstance(query, api_models.And):
                return Q('bool', must=[query_to_es(operand) for operand in query.op])

            if isinstance(query, api_models.Or):
                return Q('bool', should=[query_to_es(operand) for operand in query.op])

            if isinstance(query, api_models.Not):
                return Q('bool', must_not=query_to_es(query.op))

            raise NotImplementedError()

        if not isinstance(query, dict):
            raise NotImplementedError()

        # dictionary is like an "and" of all items in the dict
        if len(query) == 0:
            return Q()

        if len(query) == 1:
            key = next(iter(query))
            return parameter_to_es(key, query[key])

        return Q('bool', must=[
            parameter_to_es(name, value) for name, value in query.items()])

    return query_to_es(query)
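
# A hedged sketch of the translation (model classes as used above; the quantity
# name is illustrative): the API query
#
#     {'dft.code_name': api_models.Any_(op=['VASP', 'exciting'])}
#
# becomes an ES bool query with one 'should' match clause per listed value on
# the quantity's search field.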


def _api_to_es_statistic(es_search: Search, name: str, statistic: Statistic) -> A:
    '''
    Creates an ES aggregation based on the API's statistic model.
    '''

    quantity = search_quantities[statistic.quantity.value]
    if quantity.statistic_values is not None:
        statistic.size = len(quantity.statistic_values)

    terms_kwargs = {}
    if statistic.value_filter is not None:
        terms_kwargs['include'] = '.*%s.*' % statistic.value_filter

    order_type = '_count' if statistic.order.type_ == AggregationOrderType.entries else '_key'
    statistic_agg = es_search.aggs.bucket('statistic:%s' % name, A(
        'terms',
        field=quantity.search_field,
        size=statistic.size,
        order={order_type: statistic.order.direction.value},
        **terms_kwargs))

    for metric in statistic.metrics:
        metric_quantity = metrics[metric.value]
        statistic_agg.metric('metric:%s' % metric_quantity.metric_name, A(
            metric_quantity.metric,
            field=metric_quantity.search_field))


def _es_to_api_statistics(es_response, name: str, statistic: Statistic) -> StatisticResponse:
    '''
    Creates a StatisticResponse from elasticsearch response on a request executed with
    the given statistics.
    '''
    quantity = search_quantities[statistic.quantity.value]

    es_statistic = es_response.aggs['statistic:' + name]
    statistic_data = {}
    for bucket in es_statistic.buckets:
        value_data = dict(entries=bucket.doc_count)
        for metric in statistic.metrics:
            value_data[metric.value] = bucket['metric:' + metric.value].value
        statistic_data[bucket.key] = value_data

    if quantity.statistic_values is not None:
        for value in quantity.statistic_values: