__init__.py 15.5 KB
Newer Older
Markus Scheidgen's avatar
Markus Scheidgen committed
1
2
3
4
#
# Copyright The NOMAD Authors.
#
# This file is part of NOMAD. See https://nomad-lab.eu for further info.
Markus Scheidgen's avatar
Markus Scheidgen committed
5
6
7
8
9
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
Markus Scheidgen's avatar
Markus Scheidgen committed
10
#     http://www.apache.org/licenses/LICENSE-2.0
Markus Scheidgen's avatar
Markus Scheidgen committed
11
12
#
# Unless required by applicable law or agreed to in writing, software
Markus Scheidgen's avatar
Markus Scheidgen committed
13
# distributed under the License is distributed on an "AS IS" BASIS,
Markus Scheidgen's avatar
Markus Scheidgen committed
14
15
16
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
Markus Scheidgen's avatar
Markus Scheidgen committed
17
#
Markus Scheidgen's avatar
Markus Scheidgen committed
18

19
'''
.. autofunc::nomad.utils.create_uuid
.. autofunc::nomad.utils.hash
.. autofunc::nomad.utils.timer

Logging in nomad is structured. Structured logging means that log entries contain
dictionaries with quantities related to respective events. E.g. having the code,
parser, parser version, calc_id, mainfile, etc. for all events that happen during
calculation processing. This means the :func:`get_logger` and all logger functions
take keyword arguments for structured data. Otherwise :func:`get_logger` can
be used similar to the standard *logging.getLogger*.

Depending on the configuration all logs will also be sent to a central logstash.

.. autofunc::nomad.utils.get_logger
.. autofunc::nomad.utils.hash
.. autofunc::nomad.utils.create_uuid
.. autofunc::nomad.utils.timer
.. autofunc::nomad.utils.lnr
.. autofunc::nomad.utils.strip
'''
Markus Scheidgen's avatar
Markus Scheidgen committed
40

41
42
from typing import List, Iterable
from collections import OrderedDict
43
import base64
44
from contextlib import contextmanager
45
import json
Markus Scheidgen's avatar
Markus Scheidgen committed
46
import uuid
47
import time
48
import hashlib
49
50
import sys
from datetime import timedelta
51
import collections
52
import logging
53
import inspect
Markus Scheidgen's avatar
Markus Scheidgen committed
54
import orjson
55
56
import resource
import os
57

58
59
from nomad import config

Markus Scheidgen's avatar
Markus Scheidgen committed
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75

def dump_json(data):
    ''' Serializes *data* to indented JSON bytes via orjson. '''
    def serialize_fallback(value):
        # invoked by orjson for types it does not serialize natively
        if isinstance(value, collections.OrderedDict):
            return dict(value)

        if value.__class__.__name__ == 'BaseList':
            return list(value)

        raise TypeError

    options = orjson.OPT_INDENT_2 | orjson.OPT_NON_STR_KEYS
    return orjson.dumps(data, default=serialize_fallback, option=options)


76
default_hash_len = 28
77
''' Length of hashes and hash-based ids (e.g. calc, upload) in nomad. '''
Markus Scheidgen's avatar
Markus Scheidgen committed
78

79
80
81
# Structured logging is optional: if the structlogging module (and its
# dependencies) can be imported, use it; otherwise fall back to a thin
# emulation (ClassicLogger) on top of the standard library logging module.
# NOTE(review): legacy_logger and configure_logging appear to be imported
# for re-export to other modules — confirm before removing.
try:
    from . import structlogging
    from .structlogging import legacy_logger
    from .structlogging import configure_logging

    def get_logger(name, **kwargs):
        '''
        Returns a structlog logger that is already attached with a logstash handler.
        Use additional *kwargs* to pre-bind some values to all events.
        '''
        return structlogging.get_logger(name, **kwargs)

except ImportError:
    def get_logger(name, **kwargs):
        # Fallback that mimics the structlog interface with classic logging.
        return ClassicLogger(name, **kwargs)

    def configure_logging(console_log_level=config.console_log_level):
        import logging
        logging.basicConfig(level=console_log_level)

99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134

class ClassicLogger:
    '''
    Emulates the structlog logger interface on top of the built-in Python
    logging module. Keyword arguments given at construction or via :meth:`bind`
    are appended to every log message.
    '''
    def __init__(self, name, **kwargs):
        self.kwargs = kwargs
        self.logger = logging.getLogger(name)

    def bind(self, **kwargs):
        ''' Returns a new logger with the given values pre-bound in addition to ours. '''
        merged = dict(self.kwargs, **kwargs)
        return ClassicLogger(self.logger.name, **merged)

    def __log(self, method_name, event, **kwargs):
        # render bound and per-call data into the classic message string
        data = dict(self.kwargs, **kwargs)
        rendered = ', '.join('%s=%s' % (str(key), str(value)) for key, value in data.items())
        log = getattr(self.logger, method_name)
        log('%s (%s)' % (event, rendered))

    def __getattr__(self, key):
        # treat any unknown attribute (info, warning, error, ...) as a log method
        def call(*args, **kwargs):
            return self.__log(key, *args, **kwargs)
        return call


def set_console_log_level(level):
    ''' Sets the given log level on all stream/file handlers of the root logger. '''
    for handler in logging.getLogger().handlers:
        if isinstance(handler, (logging.StreamHandler, logging.FileHandler)):
            handler.setLevel(level)

135

136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
def decode_handle_id(handle_str: str):
    '''
    Decodes a handle suffix given in a base-32 alphabet (digits ``0-9`` and
    letters ``a-v``, case-insensitive) into its integer value.

    Raises:
        ValueError: If the string contains a character outside the alphabet.
    '''
    result = 0
    for c in handle_str:
        ordinal = ord(c.lower())
        if 48 <= ordinal <= 57:  # '0'-'9' -> 0-9
            number = ordinal - 48
        elif 97 <= ordinal <= 118:  # 'a'-'v' -> 10-31
            number = ordinal - 87
        else:
            # improved over a bare ValueError(): name the offending character
            raise ValueError('invalid character %r in handle id %r' % (c, handle_str))

        result = result * 32 + number

    return result


152
def hash(*args, length: int = default_hash_len) -> str:
    ''' Creates a websafe hash of the given length based on the repr of the given arguments. '''
    digest = hashlib.sha512()
    for arg in args:
        digest.update(str(arg).encode('utf-8'))

    return make_websave(digest, length=length)


def make_websave(hash, length: int = default_hash_len) -> str:
    ''' Creates a websafe string for a hashlib hash object. '''
    encoded = base64.b64encode(hash.digest(), altchars=b'-_')
    # positive length truncates; otherwise only the two '=' padding chars
    # of the sha512 encoding are stripped
    if length > 0:
        return encoded[:length].decode('utf-8')
    return encoded[0:-2].decode('utf-8')
Markus Scheidgen's avatar
Markus Scheidgen committed
167
168


169
def base64_encode(string):
    '''
    URL-safe base64 encodes the given bytes and removes any ``=`` used as
    padding from the encoded string.
    '''
    encoded = base64.urlsafe_b64encode(string).decode('utf-8')
    without_padding = encoded.rstrip('=')
    return without_padding


def base64_decode(string):
    '''
    URL-safe base64 decodes the given string, adding back the required ``=``
    padding before decoding.
    '''
    # -len % 4 yields 0-3 padding characters. The original computed
    # 4 - (len % 4), which appended four superfluous '=' when the length was
    # already a multiple of 4 and only worked because the decoder is lenient
    # about extra padding.
    padding = -len(string) % 4
    data = (string + ('=' * padding)).encode('utf-8')
    return base64.urlsafe_b64decode(data)


Markus Scheidgen's avatar
Markus Scheidgen committed
186
def create_uuid() -> str:
    ''' Returns a web-save base64 encoded random uuid (type 4). '''
    random_bytes = uuid.uuid4().bytes
    encoded = base64.b64encode(random_bytes, altchars=b'-_').decode('utf-8')
    # drop the trailing '==' padding of the 16-byte encoding
    return encoded[:-2]


191
192
193
194
195
196
197
def adjust_uuid_size(uuid, length: int = default_hash_len):
    ''' Adds prefixing spaces to a uuid to ensure the default uuid length. '''
    padded = uuid.rjust(length, ' ')
    assert len(padded) == length, 'uuids must have the right fixed size'
    return padded


Markus Scheidgen's avatar
Markus Scheidgen committed
198
199
@contextmanager
def lnr(logger, event, **kwargs):
    '''
    A context manager that Logs aNd Raises all exceptions with the given logger.

    Arguments:
        logger: The logger that should be used for logging exceptions.
        event: the log message
        **kwargs: additional properties for the structured log
    '''
    try:
        yield

    except Exception as exception:
        # HTTPException is part of the normal app error handling and is
        # therefore not logged here
        is_http_exception = exception.__class__.__name__ == 'HTTPException'
        if not is_http_exception:
            logger.error(event, exc_info=exception, **kwargs)
        raise exception
216
217
218


@contextmanager
def timer(logger, event, method='info', lnr_event: str = None, log_memory: bool = False, **kwargs):
    '''
    A context manager that takes execution time and produces a log entry with said time.

    Arguments:
        logger: The logger that should be used to produce the log entry. If None,
            the event and time are printed to stdout instead.
        event: The log message/event.
        method: The log method that should be used. Must be a valid logger method name.
            Default is 'info'.
        lnr_event: If given, exceptions raised within the context are logged as
            errors under this event before being re-raised.
        log_memory: Log process memory usage before and after.
        **kwargs: Additional logger data that is passed to the log entry.

    Returns:
        The method yields a dictionary that can be used to add further log data.
    '''
    # copy so additions via the yielded dict never mutate the caller's dict
    kwargs = dict(kwargs)
    start = time.time()
    if log_memory:
        rss_before = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss
        kwargs['pid'] = os.getpid()
        kwargs['exec_rss_before'] = rss_before

    try:
        yield kwargs
    except Exception as e:
        if lnr_event is not None:
            stop = time.time()
            if log_memory:
                rss_after = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss
                kwargs['exec_rss_after'] = rss_after
                # NOTE(review): delta is before - after, i.e. negative when
                # memory grew — confirm this sign convention is intended
                kwargs['exec_rss_delta'] = rss_before - rss_after
            logger.error(lnr_event, exc_info=e, exec_time=stop - start, **kwargs)
        raise e
    finally:
        stop = time.time()
        if log_memory:
            rss_after = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss
            kwargs['exec_rss_after'] = rss_after
            kwargs['exec_rss_delta'] = rss_before - rss_after

    if logger is None:
        print(event, stop - start)
        return

    # Fix: look up the configured *method*. The original used
    # getattr(logger, 'info', None), silently ignoring the `method` argument
    # and making the unknown-method error branch unreachable.
    logger_method = getattr(logger, method, None)
    if logger_method is not None:
        logger_method(event, exec_time=stop - start, **kwargs)
    else:
        logger.error('Unknown logger method %s.' % method)
268
269
270
271


class archive:
    ''' Utilities for combined ``<upload_id>/<calc_id>`` archive identifiers. '''

    @staticmethod
    def create(upload_id: str, calc_id: str) -> str:
        ''' Combines upload and calc id into a single archive id. '''
        return '%s/%s' % (upload_id, calc_id)

    @staticmethod
    def items(archive_id: str) -> List[str]:
        ''' Splits an archive id into its components. '''
        return archive_id.split('/')

    @staticmethod
    def item(archive_id: str, index: int) -> str:
        ''' Returns the component with the given index. '''
        components = archive.items(archive_id)
        return components[index]

    @staticmethod
    def calc_id(archive_id: str) -> str:
        ''' The calc id component of an archive id. '''
        return archive.item(archive_id, 1)

    @staticmethod
    def upload_id(archive_id: str) -> str:
        ''' The upload id component of an archive id. '''
        return archive.item(archive_id, 0)
290
291
292
293


def to_tuple(self, *args):
    ''' Returns the values of *self* at the given keys/indices as a tuple. '''
    return tuple(map(self.__getitem__, args))
294
295


296
def chunks(list, n):
    ''' Generator that chunks up the given list into successive parts of size n. '''
    for offset in range(0, len(list), n):
        chunk = list[offset:offset + n]
        yield chunk


302
class SleepTimeBackoff:
    '''
    Provides increasingly larger sleeps (doubling up to a maximum). Useful when
    observing long running processes with unknown runtime.
    '''

    def __init__(self, start_time: float = 0.1, max_time: float = 5):
        self.current_time = start_time
        self.max_time = max_time

    def __call__(self):
        # calling the instance is a shorthand for sleep()
        self.sleep()

    def sleep(self):
        ''' Sleeps for the current interval, then doubles it (capped at max_time). '''
        time.sleep(self.current_time)
        doubled = self.current_time * 2
        self.current_time = doubled if doubled < self.max_time else self.max_time
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346


class ETA:
    '''
    Writes rudimentary progress and estimated-remaining-time output to stdout.
    ``message`` must contain format specifiers for count, total and the eta string.
    '''

    def __init__(self, total: int, message: str, interval: int = 1000):
        self.start = time.time()
        self.total = total
        self.count = 0
        self.interval = interval
        self.interval_count = 0
        self.message = message

    def add(self, amount: int = 1):
        ''' Advances progress; prints at most once per crossed interval. '''
        self.count += amount
        crossed = int(self.count / self.interval)
        if crossed <= self.interval_count:
            return

        self.interval_count = crossed
        elapsed = time.time() - self.start
        remaining = elapsed * (self.total - self.count) / self.count
        remaining_str = str(timedelta(seconds=remaining))
        sys.stdout.write('\r' + (self.message % (self.count, self.total, remaining_str)))
        sys.stdout.flush()

    def __enter__(self):
        self.start = time.time()
        return self

    def __exit__(self, *args, **kwargs):
        print('')
347
348
349


def common_prefix(paths):
    '''
    Computes the longest common file path prefix (with respect to '/' separated
    segments) of all given paths. Returns the empty string if no common prefix
    exists.
    '''
    result = None

    for path in paths:
        # the first path is compared against itself, yielding its own directory prefix
        candidate = path if result is None else result

        # find the last '/' within the region where path and candidate agree
        last_slash = -1
        for position, (left, right) in enumerate(zip(path, candidate)):
            if left != right:
                break
            if left == '/':
                last_slash = position

        if last_slash == -1:
            return ''

        result = candidate[:last_slash + 1]

    return '' if result is None else result
379
380
381
382
383
384


class RestrictedDict(OrderedDict):
    """Dictionary-like container with predefined set of mandatory and optional
    keys and a set of forbidden values.
    """
    def __init__(self, mandatory_keys: Iterable = None, optional_keys: Iterable = None, forbidden_values: Iterable = None, lazy: bool = True):
        """
        Args:
            mandatory_keys: Keys that have to be present.
            optional_keys: Keys that are optional.
            forbidden_values: Values that are forbidden. Only supports hashable values.
            lazy: If false, the values are checked already when inserting. If
                True, the values should be manually checked by calling the
                check()-function.
        """
        super().__init__()

        if isinstance(mandatory_keys, (list, tuple, set)):
            self._mandatory_keys = set(mandatory_keys)
        elif mandatory_keys is None:
            self._mandatory_keys = set()
        else:
            raise ValueError("Please provide the mandatory_keys as a list, tuple or set.")

        if isinstance(optional_keys, (list, tuple, set)):
            self._optional_keys = set(optional_keys)
        elif optional_keys is None:
            self._optional_keys = set()
        else:
            raise ValueError("Please provide the optional_keys as a list, tuple or set.")

        if isinstance(forbidden_values, (list, tuple, set)):
            self._forbidden_values = set(forbidden_values)
        elif forbidden_values is None:
            self._forbidden_values = set()
        else:
            raise ValueError("Please provide the forbidden_values as a list or tuple of values.")

        self._lazy = lazy

    def __setitem__(self, key, value):
        if not self._lazy:
            # Check that only the defined keys are used
            if key not in self._mandatory_keys and key not in self._optional_keys:
                raise KeyError("The key '{}' is not allowed.".format(key))

            # Check that forbidden values are not used.
            try:
                match = value in self._forbidden_values
            except TypeError:
                pass  # Unhashable value will not match
            else:
                if match:
                    # Fix: report the offending *value*; the original message
                    # mistakenly formatted the key here (check() below formats
                    # the value).
                    raise ValueError("The value '{}' is not allowed.".format(value))

        super().__setitem__(key, value)

    def check(self, recursive=False):
        """Validates keys and values; raises KeyError/ValueError on violations.

        Args:
            recursive: Also check() any RestrictedDict stored as a value.
        """
        # Check that only the defined keys are used
        for key in self.keys():
            if key not in self._mandatory_keys and key not in self._optional_keys:
                raise KeyError("The key '{}' is not allowed.".format(key))

        # Check that all mandatory values are all defined
        for key in self._mandatory_keys:
            if key not in self:
                raise KeyError("The mandatory key '{}' is not present.".format(key))

        # Check that forbidden values are not used.
        for key, value in self.items():
            match = False
            try:
                match = value in self._forbidden_values
            except TypeError:
                pass  # Unhashable value will not match
            else:
                if match:
                    raise ValueError("The value '{}' is not allowed but was set for key '{}'.".format(value, key))

        # Check recursively
        if recursive:
            for value in self.values():
                if isinstance(value, RestrictedDict):
                    value.check(recursive)

    def update(self, other):
        # Route through __setitem__ so eager (non-lazy) validation applies.
        for key, value in other.items():
            self.__setitem__(key, value)

    def hash(self) -> str:
        """Creates a hash code from the contents. Ensures consistent ordering.
        """
        hash_str = json.dumps(self, sort_keys=True)

        return hash(hash_str)
476
477
478
479
480


def strip(docstring):
    ''' Normalizes a multiline doc string or description by removing unnecessary whitespace. '''
    cleaned = inspect.cleandoc(docstring)
    return cleaned
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501


def flat(obj, prefix=None):
    '''
    Helper that translates nested dict objects into flattened dicts with
    ``key.key....`` as keys. Non-dict inputs are returned unchanged.

    NOTE(review): ``prefix`` is accepted but never used — kept for interface
    compatibility; confirm whether any caller passes it.
    '''
    if not isinstance(obj, dict):
        return obj

    flattened = {}
    for key, value in obj.items():
        if isinstance(value, dict):
            for inner_key, inner_value in flat(value).items():
                flattened['%s.%s' % (key, inner_key)] = inner_value
        else:
            flattened[key] = value

    return flattened