test_search.py 6.66 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
#
# Copyright The NOMAD Authors.
#
# This file is part of NOMAD. See https://nomad-lab.eu for further info.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

19
from typing import List, Dict, Any, Union, Iterable
20
21
22
import pytest
import json

23
from nomad import config, utils, infrastructure
24
from nomad.app.v1.models import WithQuery
25
from nomad.search import quantity_values, search, update_by_query, refresh
26
from nomad.metainfo.elasticsearch_extension import entry_type, entry_index, material_index
27

28
from tests.utils import ExampleData
29
30


31
32
33
34
35
36
37
38
39
40
41
42
43
def assert_search_upload(
        entries: Union[int, Iterable] = -1,
        additional_keys: Iterable[str] = (),
        upload_id: str = None,
        **kwargs):
    '''
    Asserts that the documents found in the entries index look as expected.

    Arguments:
        entries: The expected total number of hits, given either as a number or as an
            iterable whose number of items is used. The default of -1 skips the check
            of the total.
        additional_keys: Keys that have to be present in every inspected hit with a
            value different from ``config.services.unavailable_value``.
        upload_id: If given, the search is restricted to documents that match this
            upload id. Otherwise the search body carries no query.
        **kwargs: Key/value pairs that every inspected (flattened) hit has to match.

    Raises:
        AssertionError: If the total or any inspected hit does not fulfil the expectations.
        TypeError: If ``entries`` is neither an int nor an iterable.
    '''
    if isinstance(entries, int):
        expected_total = entries
    else:
        # Accept every iterable (tuple, set, generator, ...) as promised by the
        # annotation, not only lists.
        expected_total = len(list(entries))

    required_keys = ['entry_id', 'upload_id', 'mainfile']

    # presumably makes previously indexed documents visible to the search below --
    # TODO confirm against nomad.search.refresh
    refresh()

    # Only 10 hits are requested, the per-hit checks below cover at most that many.
    body: Dict[str, Any] = dict(size=10)
    if upload_id is not None:
        body['query'] = dict(match=dict(upload_id=upload_id))

    search_results = infrastructure.elastic_client.search(
        index=config.elastic.entries_index, body=body)['hits']

    if expected_total != -1:
        assert search_results['total']['value'] == expected_total

    # An empty result simply yields no hits to iterate over.
    for raw_hit in search_results['hits']:
        hit = utils.flat(raw_hit['_source'])

        for key, value in kwargs.items():
            assert hit.get(key) == value, key

        if 'pid' in hit:
            assert int(hit['pid']) > 0

        for key in required_keys:
            assert key in hit, f'{key} is missing'

        for key in additional_keys:
            assert key in hit, f'{key} is missing'
            assert hit[key] != config.services.unavailable_value

        for coauthor in hit.get('entry_coauthors', []):
            assert coauthor.get('name') is not None


def test_mapping_compatibility(elastic_infra):
    '''
    Fetches the index information twice, extracts the mappings and asserts that every
    difference between the two mappings has a path with exactly one dot.
    '''
    from nomad.infrastructure import elastic_client

    # NOTE(review): both lookups read config.elastic.entries_index, so the two mappings
    # compared below stem from the same index -- confirm that this is intended.
    v0 = elastic_client.indices.get(config.elastic.entries_index)
    v1 = elastic_client.indices.get(config.elastic.entries_index)

    def get_mapping(index):
        # Exactly one index with exactly one entry under 'mappings' is expected.
        assert len(index) == 1
        index_name = next(iter(index))
        index_info = index[index_name]
        mappings = index_info['mappings']
        assert len(mappings) == 1
        mapping_name = next(iter(mappings))
        return mappings[mapping_name]

    v0, v1 = get_mapping(v0), get_mapping(v1)

    def compare(a, b, path='', results=None):
        # Recursively collects '<origin>:<dotted path>' strings for all keys that are
        # missing on one side or that have different (non dict) values.
        results = [] if results is None else results
        prefix = path if path == '' else f'{path}.'

        for key in {*a, *b}:
            if key in a and key in b:
                left, right = a[key], b[key]
                if isinstance(left, dict) and isinstance(right, dict):
                    compare(left, right, f'{prefix}{key}', results=results)
                    continue

                if left == right:
                    continue

            origin = 'v0' if key in a else 'v1'
            results.append(f'{origin}:{prefix}{key}')

        return results

    for diff in compare(v0, v1):
        # assert that there are only top-level differences and mapping types and fields are
        # the same
        assert diff.count('.') == 1, diff


116
@pytest.fixture()
def example_data(elastic, test_user):
    '''
    Creates a published upload ``test_upload_id`` with four entries
    ``test_entry_id_0`` to ``test_entry_id_3`` owned by the test user and saves it
    with ``with_files=False`` and ``with_mongo=False``.
    '''
    example = ExampleData(main_author=test_user)
    example.create_upload(upload_id='test_upload_id', published=True, embargo_length=12)

    for entry_number in range(4):
        example.create_entry(
            upload_id='test_upload_id',
            entry_id=f'test_entry_id_{entry_number}',
            mainfile='test_content/test_embargo_entry/mainfile.json')

    example.save(with_files=False, with_mongo=False)
127
128


129
130
def test_index(indices, example_data):
    ''' Checks that documents can be fetched by id from the material and the entry index. '''
    material_doc = material_index.get(id='test_material_id')
    assert material_doc is not None

    entry_doc = entry_index.get(id='test_entry_id_0')
    assert entry_doc is not None
132
133


134
135
136
137
138
139
140
141
142
143
144
@pytest.fixture()
def indices(elastic):
    ''' Requests the ``elastic`` fixture; the fixture body itself does nothing. '''


def test_indices(indices):
    ''' Checks that the entry type knows the ``entry_id`` and ``upload_id`` quantities. '''
    for quantity_name in ('entry_id', 'upload_id'):
        assert entry_type.quantities.get(quantity_name) is not None


@pytest.mark.parametrize('api_query, total', [
    pytest.param('{}', 4, id='empty'),
    pytest.param('{"results.method.simulation.program_name": "VASP"}', 4, id='match'),
    pytest.param('{"results.method.simulation.program_name": "VASP", "results.method.simulation.dft.xc_functional_type": "dne"}', 0, id='match_all'),
    pytest.param('{"and": [{"results.method.simulation.program_name": "VASP"}, {"results.method.simulation.dft.xc_functional_type": "dne"}]}', 0, id='and'),
    pytest.param('{"or":[{"results.method.simulation.program_name": "VASP"}, {"results.method.simulation.dft.xc_functional_type": "dne"}]}', 4, id='or'),
    pytest.param('{"not":{"results.method.simulation.program_name": "VASP"}}', 0, id='not'),
    pytest.param('{"results.method.simulation.program_name": {"all": ["VASP", "dne"]}}', 0, id='all'),
    pytest.param('{"results.method.simulation.program_name": {"any": ["VASP", "dne"]}}', 4, id='any'),
    pytest.param('{"results.method.simulation.program_name": {"none": ["VASP", "dne"]}}', 0, id='none'),
    pytest.param('{"results.method.simulation.program_name": {"gte": "VASP"}}', 4, id='gte'),
    pytest.param('{"results.method.simulation.program_name": {"gt": "A"}}', 4, id='gt'),
    pytest.param('{"results.method.simulation.program_name": {"lte": "VASP"}}', 4, id='lte'),
    pytest.param('{"results.method.simulation.program_name": {"lt": "A"}}', 0, id='lt'),
])
def test_search_query(indices, example_data, api_query, total):
    '''
    Runs the JSON encoded API query with owner ``all`` and compares the total reported
    by the pagination with the expected total.
    '''
    parsed_query = json.loads(api_query)
    validated_query = WithQuery(query=parsed_query).query
    results = search(owner='all', query=validated_query)
    assert results.pagination.total == total  # pylint: disable=no-member


165
def test_update_by_query(indices, example_data):
    '''
    Overwrites the ``entry_id`` of all documents via an update script and checks that
    a search for the new id finds all four entries.
    '''
    # The script text is passed on as is, including its surrounding whitespace.
    overwrite_entry_id = '''
            ctx._source.entry_id = "other test id";
        '''
    update_by_query(
        update_script=overwrite_entry_id, owner='all', query={}, index='v1')

    # Refresh the entry index before searching for the updated documents.
    entry_index.refresh()

    updated = search(owner='all', query=dict(entry_id='other test id'))
    assert updated.pagination.total == 4


def test_quantity_values(indices, example_data):
    ''' Checks that ``quantity_values`` yields all four entry ids in order, one per page. '''
    expected_ids = [f'test_entry_id_{entry_number}' for entry_number in range(4)]
    found_ids = list(quantity_values('entry_id', page_size=1, owner='all'))
    assert found_ids == expected_ids