common.py 18.4 KB
Newer Older
Markus Scheidgen's avatar
Markus Scheidgen committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
# Copyright 2018 Markus Scheidgen
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

15
'''
Common data, variables, decorators, models used throughout the API.
'''
18
from typing import Callable, IO, Set, Tuple, Iterable, Dict, Any
19
from flask_restplus import fields
Lauri Himanen's avatar
Lauri Himanen committed
20
from flask import request, make_response
21
import zipstream
22
from flask import stream_with_context, Response, g, abort
23
from urllib.parse import urlencode
24
25
import pprint
import io
26
import json
27

28
import sys
29
import os.path
30
31
import gzip
from functools import wraps
Markus Scheidgen's avatar
Markus Scheidgen committed
32

33
from nomad import search, config, datamodel
34
from nomad.app.optimade import filterparser
35
from nomad.app.common import RFC3339DateTime, rfc3339DateTime
36
from nomad.files import Restricted
37

Markus Scheidgen's avatar
Markus Scheidgen committed
38
from .api import api
Markus Scheidgen's avatar
Markus Scheidgen committed
39
40


41
42
43
44
45
46
if sys.version_info >= (3, 7):
    import zipfile
else:
    import zipfile37 as zipfile


47
48
49
50
51
52
53
54
# Model for user-editable entry metadata accepted by the API.
metadata_model = api.model('MetaData', {
    'with_embargo': fields.Boolean(default=False, description='Data with embargo is only visible to the upload until the embargo period ended.'),
    'comment': fields.String(description='The comment is shown in the repository for each calculation.'),
    # was ``descriptions=`` (unknown kwarg, silently dropped the documentation)
    'references': fields.List(fields.String, description='References allow to link calculations to external source, e.g. URLs.'),
    'coauthors': fields.List(fields.String, description='A list of co-authors given by user_id.'),
    'shared_with': fields.List(fields.String, description='A list of users to share calculations with given by user_id.'),
    '_upload_time': RFC3339DateTime(description='Override the upload time.'),
    '_uploader': fields.String(description='Override the uploader with the given user id.'),
    'datasets': fields.List(fields.String, description='A list of dataset ids.')
})

Markus Scheidgen's avatar
Markus Scheidgen committed
58
pagination_model = api.model('Pagination', {
    'total': fields.Integer(description='Number of total elements.'),
    'page': fields.Integer(description='Number of the current page, starting with 0.'),
    'per_page': fields.Integer(description='Number of elements per page.'),
    'order_by': fields.String(description='Sorting criterion.'),
    # fixed typo in user-facing description: "asceding" -> "ascending"
    'order': fields.Integer(description='Sorting order -1 for descending, 1 for ascending.')
})
''' Model used in responses with pagination. '''
Markus Scheidgen's avatar
Markus Scheidgen committed
66

67
68
69
70
71
scroll_model = api.model('Scroll', {
    'scroll': fields.Boolean(default=False, description='Flag if scrolling is enabled.'),
    'total': fields.Integer(default=0, description='The total amount of hits for the search.'),
    'scroll_id': fields.String(default=None, allow_null=True, description='The scroll_id that can be used to retrieve the next page.'),
    # was ``help=`` which flask-restplus model fields do not use; the
    # documentation was lost from the generated API docs
    'size': fields.Integer(default=0, description='The size of the returned scroll page.')})
''' Model used in responses with scroll. '''

aggregation_model = api.model('Aggregation', {
    'after': fields.String(description='The after key for the current request.', allow_null=True),
    'total': fields.Integer(default=0, description='The total amount of hits for the search.'),
    # was ``help=`` which flask-restplus model fields do not use; use
    # ``description=`` like the sibling fields so the docs are generated
    'per_page': fields.Integer(default=0, description='The size of the requested page.', allow_null=True)})
''' Model used in responses with id aggregation. '''
79

Alvin Noe Ladines's avatar
Alvin Noe Ladines committed
80
81
82
83
84
85
86
87
88
89
90
91
# Flask fields for all known search quantities, keyed by qualified name.
query_model_fields = {
    qualified_name: quantity.flask_field
    for qualified_name, quantity in search.search_quantities.items()}

# Additional query parameters that are not search quantities themselves.
query_model_fields.update(**{
    'owner': fields.String(description='The group the calculations belong to.', allow_null=True, skip_none=True),
    'domain': fields.String(description='Specify the domain to search in: %s, default is ``%s``' % (
        ', '.join(['``%s``' % domain for domain in datamodel.domains]), config.meta.default_domain)),
    'from_time': fields.Raw(description='The minimum entry time.', allow_null=True, skip_none=True),
    'until_time': fields.Raw(description='The maximum entry time.', allow_null=True, skip_none=True)
})

Alvin Noe Ladines's avatar
Alvin Noe Ladines committed
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
# Logical and comparison operators supported in structured queries.
# Fixed repeated typo in user-facing descriptions: "quantiy" -> "quantity".
query_model_fields.update(**{
    '$and': fields.List(fields.Raw, description=(
        'List of queries which must be present in search results.')),
    '$or': fields.List(fields.Raw, description=(
        'List of queries which should be present in search results.')),
    '$not': fields.List(fields.Raw, description=(
        'List of queries which must not be present in search results.')),
    '$lt': fields.Raw(description=(
        'Dict of quantity name: value such that search results should have values '
        'less than value.')),
    '$lte': fields.Raw(description=(
        'Dict of quantity name: value such that search results should have values '
        'less than or equal to value')),
    '$gt': fields.Raw(description=(
        'Dict of quantity name: value such that search results should have values '
        'greater than value')),
    '$gte': fields.Raw(description=(
        'Dict of quantity name: value such that search results should have values '
        'greater than or equal to value')),
})

query_model = api.model('Query', query_model_fields)
Alvin Noe Ladines's avatar
Alvin Noe Ladines committed
114

115
# Top-level model for search requests/responses combining query, pagination,
# scroll, aggregation, results, and generated client code snippets.
search_model_fields = {
    'query': fields.Nested(query_model, allow_null=True, skip_none=True),
    'pagination': fields.Nested(pagination_model, allow_null=True, skip_none=True),
    'scroll': fields.Nested(scroll_model, allow_null=True, skip_none=True),
    'aggregation': fields.Nested(aggregation_model, allow_null=True),
    # fixed typo in user-facing description: "quantitie" -> "quantity"
    'results': fields.List(fields.Raw(allow_null=True, skip_none=True), description=(
        'A list of search results. Each result is a dict with quantity names as key and '
        'values as values'), allow_null=True, skip_none=True),
    'code': fields.Nested(api.model('Code', {
        'repo_url': fields.String(description=(
            'An encoded URL for the search query on the repo api.')),
        'python': fields.String(description=(
            'A piece of python code snippet which can be executed to reproduce the api result.')),
        'curl': fields.String(description=(
            'A curl command which can be executed to reproduce the api result.')),
        'clientlib': fields.String(description=(
            'A piece of python code which uses NOMAD\'s client library to access the archive.'))
    }), allow_null=True, skip_none=True)}

search_model = api.model('Search', search_model_fields)
135
136
137


def add_pagination_parameters(request_parser):
    '''
    Add pagination parameters to Flask querystring parser.

    Arguments:
        request_parser: The flask-restplus request parser to add the
            ``page``, ``per_page``, ``order_by``, and ``order`` arguments to.
    '''
    request_parser.add_argument(
        'page', type=int, help='The page, starting with 1.', location='args')
    request_parser.add_argument(
        'per_page', type=int, help='Desired calcs per page.', location='args')
    request_parser.add_argument(
        'order_by', type=str, help='The field to sort by.', location='args')
    # fixed typos in user-facing help: "decending"/"acending"
    request_parser.add_argument(
        'order', type=int, help='Use -1 for descending and 1 for ascending order.', location='args')


# Shared base parser carrying only the pagination arguments; endpoints
# ``copy()`` it to add their own arguments without mutating this instance.
request_parser = api.parser()
add_pagination_parameters(request_parser)
pagination_request_parser = request_parser.copy()


def add_scroll_parameters(request_parser):
    ''' Add scroll parameters to Flask querystring parser. '''
    scroll_arguments = (
        ('scroll', bool, 'Enable scrolling'),
        ('scroll_id', str, 'The id of the current scrolling window to use.'))
    for argument_name, argument_type, argument_help in scroll_arguments:
        request_parser.add_argument(
            argument_name, type=argument_type, help=argument_help)


def add_search_parameters(request_parser):
    '''
    Add search parameters to Flask querystring parser.

    Adds the generic search arguments (domain, owner, time range, optimade
    filter, structured query) plus one argument per known search quantity.
    '''
    # more search parameters
    request_parser.add_argument(
        'domain', type=str,
        help='Specify the domain to search in: %s, default is ``%s``' % (
            ', '.join(['``%s``' % domain for domain in datamodel.domains]),
            config.meta.default_domain))
    request_parser.add_argument(
        'owner', type=str,
        help='Specify which calcs to return: ``visible``, ``public``, ``all``, ``user``, ``staging``, default is ``visible``')
    request_parser.add_argument(
        'from_time', type=lambda x: rfc3339DateTime.parse(x),
        help='A yyyy-MM-ddTHH:mm:ss (RFC3339) minimum entry time (e.g. upload time)')
    request_parser.add_argument(
        'until_time', type=lambda x: rfc3339DateTime.parse(x),
        help='A yyyy-MM-ddTHH:mm:ss (RFC3339) maximum entry time (e.g. upload time)')
    request_parser.add_argument(
        'dft.optimade', type=str,
        help='A search query in the optimade filter language.')
    # fixed typo in user-facing help: "reuquests" -> "requests"
    request_parser.add_argument(
        'query', type=str,
        help='A json serialized structured search query (as used in POST requests).')

    # main search parameters
    for qualified_name, quantity in search.search_quantities.items():
        request_parser.add_argument(
            qualified_name, help=quantity.description, action=quantity.argparse_action)
190
191


192
# Cached set of all known search quantity names; used to filter request args
# down to actual search parameters in apply_search_parameters.
_search_quantities = set(search.search_quantities.keys())
193
194


195
def apply_search_parameters(search_request: search.SearchRequest, args: Dict[str, Any]):
    '''
    Helper that adds query relevant request args to the given SearchRequest.

    Arguments:
        search_request: The search request to amend with domain, owner, time range,
            optimade filter, structured query, and search quantity parameters.
        args: A dict of request arguments; entries with ``None`` values are ignored.

    Aborts the flask request with 400 on invalid values (401 for owner values
    that require authentication).
    '''
    # drop unset parameters so the ``get`` defaults below take effect
    args = {key: value for key, value in args.items() if value is not None}

    # domain
    domain = args.get('domain')
    if domain is not None:
        search_request.domain(domain=domain)

    # owner: restrict visibility; passes the authenticated user's id (flask
    # ``g.user``) if available, since some owner values require it
    owner = args.get('owner', 'visible')
    try:
        search_request.owner(
            owner,
            g.user.user_id if g.user is not None else None)
    except ValueError as e:
        # ValueError is raised for owner values that need authentication -> 401
        abort(401, getattr(e, 'message', 'Invalid owner parameter: %s' % owner))
    except Exception as e:
        abort(400, getattr(e, 'message', 'Invalid owner parameter'))

    # time range
    from_time_str = args.get('from_time', None)
    until_time_str = args.get('until_time', None)

    try:
        # values are parsed here; presumably they arrive as RFC3339 strings
        # (e.g. from a JSON body) — TODO confirm callers never pass datetimes
        from_time = rfc3339DateTime.parse(from_time_str) if from_time_str is not None else None
        until_time = rfc3339DateTime.parse(until_time_str) if until_time_str is not None else None
        search_request.time_range(start=from_time, end=until_time)
    except Exception:
        abort(400, 'bad datetime format')

    # optimade: translate an optimade filter expression into a search query
    try:
        optimade = args.get('dft.optimade', None)
        if optimade is not None:
            q = filterparser.parse_filter(
                optimade, nomad_properties=domain, without_prefix=True)
            search_request.query(q)
    except filterparser.FilterException as e:
        abort(400, 'Could not parse optimade query: %s' % str(e))

    # search expression: a JSON serialized structured query (see query_model)
    query_str = args.get('query', None)
    if query_str is not None:
        try:
            query = json.loads(query_str)
        except Exception as e:
            abort(400, 'Could not JSON parse query expression: %s' % str(e))

        try:
            search_request.query_expression(query)
        except Exception as e:
            abort(400, 'Invalid query expression: %s' % str(e))

    # search parameter: everything that names a known search quantity
    search_request.search_parameters(**{
        key: value for key, value in args.items()
        if key in _search_quantities})
Markus Scheidgen's avatar
Markus Scheidgen committed
255
256
257


def calc_route(ns, prefix: str = ''):
    ''' A resource decorator for /<upload>/<calc> based routes. '''
    def decorator(func):
        path = '%s/<string:upload_id>/<string:calc_id>' % prefix
        route_decorator = ns.route(path)
        documented_func = api.doc(params={
            'upload_id': 'The unique id for the requested upload.',
            'calc_id': 'The unique id for the requested calculation.'
        })(func)
        route_decorator(documented_func)
    return decorator
267
268
269


def upload_route(ns, prefix: str = ''):
    ''' A resource decorator for /<upload> based routes. '''
    def decorator(func):
        path = '%s/<string:upload_id>' % prefix
        route_decorator = ns.route(path)
        documented_func = api.doc(params={
            'upload_id': 'The unique id for the requested upload.'
        })(func)
        route_decorator(documented_func)
    return decorator
278
279
280
281


def streamed_zipfile(
        files: Iterable[Tuple[str, str, Callable[[str], IO], Callable[[str], int]]],
        zipfile_name: str, compress: bool = False):
    '''
    Creates a response that streams the given files as a streamed zip file. Ensures that
    each given file is only streamed once, based on its filename in the resulting zipfile.

    Arguments:
        files: An iterable of tuples with the filename to be used in the resulting zipfile,
            a file id within the upload, a callable that gives a binary IO object for the
            file id, and a callable that gives the file size for the file id.
        zipfile_name: A name that will be used in the content disposition attachment
            used as an HTTP response.
        compress: Uses compression. Default is stored only.
    '''

    # filenames already written to the zip; used to skip duplicates
    streamed_files: Set[str] = set()

    def generator():
        ''' Stream a zip file with all files using zipstream. '''
        def iterator():
            '''
            Replace the directory based iter of zipstream with an iter over all given
            files.
            '''
            # the actual contents
            for zipped_filename, file_id, open_io, file_size in files:
                # only the first occurrence of a filename is streamed
                if zipped_filename in streamed_files:
                    continue
                streamed_files.add(zipped_filename)

                # Write a file to the zipstream.
                try:
                    f = open_io(file_id)
                    try:
                        def iter_content():
                            # stream in 64 kB chunks to bound memory usage
                            while True:
                                data = f.read(1024 * 64)
                                if not data:
                                    break
                                yield data

                        # each yielded dict describes one zip entry consumed
                        # lazily via zip_stream.paths_to_write below
                        yield dict(
                            arcname=zipped_filename, iterable=iter_content(),
                            buffer_size=file_size(file_id))
                    finally:
                        f.close()
                except KeyError:
                    # files that are not found, will not be returned
                    pass
                except Restricted:
                    # due to the streaming nature, we cannot raise 401 here
                    # we just leave it out in the download
                    pass

        compression = zipfile.ZIP_DEFLATED if compress else zipfile.ZIP_STORED
        zip_stream = zipstream.ZipFile(mode='w', compression=compression, allowZip64=True)
        zip_stream.paths_to_write = iterator()

        for chunk in zip_stream:
            yield chunk

    response = Response(stream_with_context(generator()), mimetype='application/zip')
    response.headers['Content-Disposition'] = 'attachment; filename={}'.format(zipfile_name)
    return response
345
346


347
def _query_api_url(*args, query: Dict[str, Any] = None):
    '''
    Creates an API URL.

    Arguments:
        *args: URL path segments after the API base URL
        query: A dict with query string parameters; list values are expanded
            into repeated parameters (``doseq``).

    Returns:
        The full URL string, with an encoded query string if ``query`` is non-empty.
    '''
    # NOTE(review): os.path.join is used for URL path joining; segments must not
    # start with '/' or earlier components would be discarded.
    url = os.path.join(config.api_url(False), *args)
    if query is not None and len(query) > 0:
        url = '%s?%s' % (url, urlencode(query, doseq=True))

    return url


361
def query_api_python(query):
    '''
    Creates a string of python code to execute a search query to the repository using
    the requests library.
    '''
    filtered_query = _filter_api_query(query)
    endpoint = _query_api_url('archive', 'query')
    # one "'key': value" line per query entry, indented to fit the template
    item_lines = [
        '        \'%s\': %s' % (name, pprint.pformat(value, compact=True))
        for name, value in filtered_query.items()]
    return '''import requests
response = requests.post('{}', json={{
    'query': {{
{}
    }}
}})
data = response.json()'''.format(endpoint, ',\n'.join(item_lines))


def _filter_api_query(query):
    '''
    Reduces a query dict to known search quantities (plus selected passthrough
    keys) and normalizes values against each quantity's cardinality.
    '''
    def coerce(name, value):
        # wrap scalars for many-valued quantities; unwrap single-element lists
        quantity = search.search_quantities.get(name)
        if quantity.many and not isinstance(value, list):
            return [value]
        if isinstance(value, list) and len(value) == 1:
            return value[0]
        return value

    filtered = {}
    for name, value in query.items():
        if name not in search.search_quantities:
            continue
        # omit the domain if it is just the default
        if name == 'domain' and value == config.meta.default_domain:
            continue
        filtered[name] = coerce(name, value)

    # keep selected non-quantity keys unchanged
    for passthrough in ['dft.optimade']:
        if passthrough in query:
            filtered[passthrough] = query[passthrough]

    return filtered


404
405
406
407
408
409
410
411
412
def query_api_repo_url(query):
    '''
    Creates an encoded URL string access a search query on the repo api.
    '''
    # work on a copy so the caller's dict is not mutated
    params = dict(query)
    for param in ('per_page', 'page', 'exclude'):
        params.pop(param, None)
    # drop parameters that merely repeat the repo defaults
    defaults = dict(order_by=['upload_time'], order=['-1'], domain=['dft'], owner=['public'])
    for name, default in defaults.items():
        if name in params and params[name] == default:
            del params[name]
    return _query_api_url('repo', query=params)


418
419
420
421
422
423
def query_api_clientlib(**kwargs):
    '''
    Creates a string of python code to execute a search query on the archive using
    the client library.
    '''
    query = _filter_api_query(kwargs)
    out = io.StringIO()
    out.write('from nomad import client, config\n')
    out.write('config.client.url = \'%s\'\n' % config.api_url(ssl=False))
    # base the newline decision on the *filtered* query (was ``kwargs``): if
    # filtering removed every entry, the old check emitted a stray blank line
    # inside the generated ``query={...}`` literal
    out.write('results = client.query_archive(query={%s' % ('' if len(query) == 0 else '\n'))
    out.write(',\n'.join([
        '    \'%s\': %s' % (key, pprint.pformat(value, compact=True))
        for key, value in query.items()]))
    out.write('})\n')
    out.write('print(results)\n')

    return out.getvalue()


438
def query_api_curl(query):
    '''
    Creates a string of curl command to execute a search query and download the respective
    archives in a .zip file.
    '''
    download_url = _query_api_url('archive', 'download', query=query)
    return 'curl "{}" --output nomad.zip'.format(download_url)
445
446


447
def enable_gzip(level: int = 1, min_size: int = 1024):
    """
    Decorator factory that gzip-compresses a view's 200 response body when the
    client accepts gzip and the payload is large enough.

    Args:
        level: The gzip compression level from 1-9
        min_size: The minimum response size in bytes for which the compression
            will be enabled.
    """
    def inner(function):
        @wraps(function)
        def wrapper(*args, **kwargs):
            response = make_response(function(*args, **kwargs))
            if response.status_code == 200:
                # use .get with defaults: clients may omit Accept-Encoding and
                # responses may lack a Content-Length header; the old direct
                # indexing raised KeyError (-> 500) in those cases
                accept_encoding = request.headers.get("Accept-Encoding", "")
                content_length = int(response.headers.get("Content-Length", 0))
                if "gzip" in accept_encoding and content_length >= min_size:
                    data = gzip.compress(response.data, level)
                    response.data = data
                    response.headers['Content-Length'] = len(data)
                    response.headers["Content-Encoding"] = "gzip"
            return response
        return wrapper
    return inner