common.py 7.56 KB
Newer Older
Markus Scheidgen's avatar
Markus Scheidgen committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
# Copyright 2018 Markus Scheidgen
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Common data, variables, decorators, models used throughout the API.
"""
18
from typing import Callable, IO, Set, Tuple, Iterable
19
from flask_restplus import fields, abort
20
21
22
import zipstream
from flask import stream_with_context, Response
import sys
23
import json
Markus Scheidgen's avatar
Markus Scheidgen committed
24

25
from nomad.app.utils import RFC3339DateTime
26
from nomad.files import Restricted
27

Markus Scheidgen's avatar
Markus Scheidgen committed
28
from .api import api
Markus Scheidgen's avatar
Markus Scheidgen committed
29
30


31
32
33
34
35
36
if sys.version_info >= (3, 7):
    import zipfile
else:
    import zipfile37 as zipfile


37
38
39
40
41
42
43
44
# API model 'MetaData': the metadata fields that can be given for uploaded data.
# Keys with a leading underscore ('_upload_time', '_uploader') override values that
# are otherwise determined by the system (see their descriptions).
metadata_model = api.model('MetaData', {
    'with_embargo': fields.Boolean(default=False, description='Data with embargo is only visible to the upload until the embargo period ended.'),
    'comment': fields.String(description='The comment are shown in the repository for each calculation.'),
    # The keyword has to be `description`; a misspelled keyword is swallowed by the
    # field's **kwargs and the text never shows up in the generated API documentation.
    'references': fields.List(fields.String, description='References allow to link calculations to external source, e.g. URLs.'),
    'coauthors': fields.List(fields.String, description='A list of co-authors given by user_id.'),
    'shared_with': fields.List(fields.String, description='A list of users to share calculations with given by user_id.'),
    '_upload_time': RFC3339DateTime(description='Overrride the upload time.'),
    '_uploader': fields.String(description='Override the uploader with the given user id.'),
    'datasets': fields.List(fields.String, description='A list of dataset ids.')
})

Markus Scheidgen's avatar
Markus Scheidgen committed
48
# NOTE(review): the 'page' description below says pages start with 0, while the help
# text of the 'page' request argument in this module says they start with 1 -- confirm
# which one is correct.
pagination_model = api.model('Pagination', dict(
    total=fields.Integer(description='Number of total elements.'),
    page=fields.Integer(description='Number of the current page, starting with 0.'),
    per_page=fields.Integer(description='Number of elements per page.')
))
""" Model used in responses with pagination. """
Markus Scheidgen's avatar
Markus Scheidgen committed
54
55
56
57
58
59
60
61
62
63
64


pagination_request_parser = api.parser()
""" Parser used for requests with pagination. """

# Every pagination argument is declared with location='args'; the table keeps the
# name, type and help text of each argument next to each other.
for _arg_name, _arg_type, _arg_help in [
        ('page', int, 'The page, starting with 1.'),
        ('per_page', int, 'Desired calcs per page.'),
        ('order_by', str, 'The field to sort by.'),
        ('order', int, 'Use -1 for decending and 1 for acending order.')]:
    pagination_request_parser.add_argument(
        _arg_name, type=_arg_type, help=_arg_help, location='args')
# do not leave the loop variables behind as module attributes
del _arg_name, _arg_type, _arg_help
Markus Scheidgen's avatar
Markus Scheidgen committed
67
68
69
70
71


def calc_route(ns, prefix: str = ''):
    """
    A resource decorator for /<upload>/<calc> based routes.

    Arguments:
        ns: The object the route is registered on; its ``route`` method is used.
        prefix: A path that is put in front of the
            ``/<string:upload_id>/<string:calc_id>`` part of the route.

    Returns:
        A decorator that documents the ``upload_id`` and ``calc_id`` parameters on the
        decorated resource, registers it under the route, and returns the resource.
    """
    def decorator(func):
        register = ns.route('%s/<string:upload_id>/<string:calc_id>' % prefix)
        register(
            api.doc(params={
                'upload_id': 'The unique id for the requested upload.',
                'calc_id': 'The unique id for the requested calculation.'
            })(func)
        )
        # Hand the resource back to the caller. Without a return value the decorated
        # name would be rebound to None in the module that uses this decorator.
        return func
    return decorator
79
80
81
82
83
84
85
86
87
88
89


def upload_route(ns, prefix: str = ''):
    """
    A resource decorator for /<upload> based routes.

    Arguments:
        ns: The object the route is registered on; its ``route`` method is used.
        prefix: A path that is put in front of the ``/<string:upload_id>`` part of
            the route.

    Returns:
        A decorator that documents the ``upload_id`` parameter on the decorated
        resource, registers it under the route, and returns the resource.
    """
    def decorator(func):
        register = ns.route('%s/<string:upload_id>' % prefix)
        register(
            api.doc(params={
                'upload_id': 'The unique id for the requested upload.'
            })(func)
        )
        # Hand the resource back to the caller. Without a return value the decorated
        # name would be rebound to None in the module that uses this decorator.
        return func
    return decorator
90
91
92
93


def streamed_zipfile(
        files: Iterable[Tuple[str, str, Callable[[str], IO], Callable[[str], int]]],
        zipfile_name: str, compress: bool = False):
    """
    Builds a response that delivers the given files as one zip file that is produced
    while it is sent. A file is only added once; files are told apart by the filename
    they get within the zip file, later duplicates are skipped.

    Arguments:
        files: An iterable of tuples, each with the filename to use within the zip file,
            a file id, a callable that opens an IO object for the file id, and a
            callable that returns the file size for the file id.
        zipfile_name: The filename put into the content disposition attachment header
            of the response.
        compress: If true, the contents are deflated. By default they are only stored.

    Returns:
        A flask ``Response`` with mimetype ``application/zip`` that streams the zip file.
    """
    already_added: Set[str] = set()
    chunk_size = 1024 * 64

    def read_chunks(source):
        """ Yields the contents of the given IO object block by block until it is empty. """
        while True:
            block = source.read(chunk_size)
            if not block:
                return
            yield block

    def zip_entries():
        """
        Yields the keyword arguments for one zip entry per given file. This takes the
        place of the paths that zipstream would write otherwise.
        """
        for arcname, file_id, open_io, file_size in files:
            if arcname in already_added:
                continue
            already_added.add(arcname)

            try:
                source = open_io(file_id)
                try:
                    # The IO object is only closed once this generator is resumed or
                    # finalized; presumably zipstream has read the whole entry by
                    # then -- keep the yield inside the try/finally.
                    yield dict(
                        arcname=arcname, iterable=read_chunks(source),
                        buffer_size=file_size(file_id))
                finally:
                    source.close()
            except KeyError:
                # files that cannot be found are left out of the download
                pass
            except Restricted:
                # the response is already streaming, a 401 cannot be raised anymore;
                # the file is left out of the download instead
                pass

    def zip_chunks():
        """ Produces the bytes of the zip file with zipstream. """
        if compress:
            compression = zipfile.ZIP_DEFLATED
        else:
            compression = zipfile.ZIP_STORED

        archive = zipstream.ZipFile(mode='w', compression=compression, allowZip64=True)
        archive.paths_to_write = zip_entries()

        for chunk in archive:
            yield chunk

    response = Response(stream_with_context(zip_chunks()), mimetype='application/zip')
    response.headers['Content-Disposition'] = 'attachment; filename={}'.format(zipfile_name)
    return response
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195


def to_json(files: Iterable[Tuple[str, str, Callable[[str], IO], Callable[[str], int]]]):
    """
    Reads the given files as JSON and returns their parsed contents by file id.

    Arguments:
        files: An iterable of tuples in the same layout that :func:`streamed_zipfile`
            takes. Only the file id (second item) and the callable that opens an IO
            object for the file id (third item) are used.

    Returns:
        A dict that maps each file id to the parsed JSON contents of its file. Files
        for which a ``KeyError`` is raised are left out.

    Aborts the request with 401 if ``Restricted`` is raised for one of the files.
    """
    data = {}
    for _, file_id, open_io, _ in files:
        try:
            f = open_io(file_id)
            try:
                data[file_id] = json.loads(f.read())
            finally:
                # always release the IO object, also if reading or parsing fails
                f.close()
        except KeyError:
            # files that are not found are not part of the result
            pass
        except Restricted:
            abort(401, message='Not authorized to access %s.' % file_id)
    return data


def build_snippet(args, base_url):
    """
    Creates the source code of a small Python script that repeats a query with the
    ``requests`` library.

    Arguments:
        args: A dict with the query arguments. Arguments with value ``None`` are left
            out. String values are written as string literals, all other values are
            written with their ``str`` representation.
        base_url: The URL the arguments are appended to as query string.

    Returns:
        The source code as a single string.
    """
    def quoted(value):
        # json.dumps gives a double quoted literal with quotes, backslashes and control
        # characters escaped. These escapes are also valid in Python string literals,
        # so special characters in the input cannot break the generated code.
        return json.dumps(str(value), ensure_ascii=False)

    arg_items = []
    for key, val in args.items():
        if val is None:
            continue
        val_code = quoted(val) if isinstance(val, str) else str(val)
        arg_items.append('%s: %s, ' % (quoted(key), val_code))

    lines = [
        'import requests',
        'from urllib.parse import urlencode',
        '',
        '',
        'def query_repository(args, base_url):',
        '    url = "%s?%s" % (base_url, urlencode(args))',
        '    response = requests.get(url)',
        '    if response.status_code != 200:',
        '        raise Exception("nomad return status %d" % response.status_code)',
        '    return response.json()',
        '',
        '',
        'args = {%s}' % ''.join(arg_items),
        'base_url = %s' % quoted(base_url),
        'JSON_DATA = query_repository(args, base_url)',
    ]
    return '\n'.join(lines) + '\n'