base.py 17.6 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
# Copyright 2018 Markus Scheidgen
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

Markus Scheidgen's avatar
Markus Scheidgen committed
15
from typing import List, Any
16
import logging
17
import time
Markus Scheidgen's avatar
Markus Scheidgen committed
18
from celery import Celery
19
20
from celery.signals import after_setup_task_logger, after_setup_logger, worker_process_init
from mongoengine import Document, StringField, ListField, DateTimeField, IntField, \
21
    ValidationError, BooleanField
22
23
24
25
from mongoengine.connection import MongoEngineConnectionError
from mongoengine.base.metaclasses import TopLevelDocumentMetaclass
from pymongo import ReturnDocument
from datetime import datetime
26

27
from nomad import config, utils, infrastructure
28
29
import nomad.patch  # pylint: disable=unused-import

30

31
32
33
34
35
36
37
38
# Wire logstash log shipping into celery's logging setup, but only when
# enabled in the nomad configuration.
if config.logstash.enabled:
    def initialize_logstash(logger=None, loglevel=logging.DEBUG, **kwargs):
        # Celery passes the logger it just configured; attach the logstash
        # handler to it so worker logs are shipped as well.
        utils.add_logstash_handler(logger)
        return logger

    # Hook both the per-task logger and the general worker logger setup.
    after_setup_task_logger.connect(initialize_logstash)
    after_setup_logger.connect(initialize_logstash)

39
40
41
42
43

@worker_process_init.connect
def setup(**kwargs):
    # Initialize the nomad infrastructure (connections etc., see
    # nomad.infrastructure) once in every newly forked celery worker process.
    infrastructure.setup()

44

Markus Scheidgen's avatar
Markus Scheidgen committed
45
# The celery app used for all async process function calls in this module.
app = Celery('nomad.processing', broker=config.celery.broker_url)
# Logging is configured by nomad itself; celery must not hijack the root logger.
app.conf.update(worker_hijack_root_logger=False)

# The possible overall processing states stored in :attr:`Proc.status`.
PENDING = 'PENDING'
RUNNING = 'RUNNING'
FAILURE = 'FAILURE'
SUCCESS = 'SUCCESS'


54
55
56
class InvalidId(Exception): pass


Markus Scheidgen's avatar
Markus Scheidgen committed
57
class ProcNotRegistered(Exception): pass
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76


class ProcMetaclass(TopLevelDocumentMetaclass):
    """
    Metaclass that collects the names of all @task-decorated methods of a class
    (in definition order) into a fresh ``tasks`` class attribute.
    """
    def __new__(mcs, name, bases, attrs):
        new_cls = super().__new__(mcs, name, bases, attrs)

        # each class gets its own task list, built from its own attributes
        task_names = []
        setattr(new_cls, 'tasks', task_names)

        for attr in attrs.values():
            # @task sets '__task_name' on the wrapper it returns
            task_name = getattr(attr, '__task_name', None)
            if task_name is not None and task_name not in task_names:
                task_names.append(task_name)

        return new_cls


class Proc(Document, metaclass=ProcMetaclass):
    """
    Base class for objects that are involved in processing and need persistent processing
    state.

    It solves two issues. First, distributed operation (via celery) and second keeping
    state of a chain of potentially failing processing tasks. Both are controlled via
    decorators @process and @task. Subclasses should use these decorators on their
    methods. Parameters are not supported for decorated functions. Use fields on the
    document instead.

    Processing state will be persisted at appropriate
    times and must not be persisted manually. All attributes are stored to mongodb.
    The class allows to render into a JSON serializable dict via :attr:`json_dict`.

    Possible processing states are PENDING, RUNNING, FAILURE, and SUCCESS.

    Attributes:
        current_task: the currently running or last completed task
        status: the overall status of the processing
        errors: a list of errors that happened during processing. Errors fail a processing
            run
        warnings: a list of warnings that happened during processing. Warnings do not
            fail a processing run
        create_time: the time of creation (not the start of processing)
        complete_time: the time that processing completed (successfully or not)
    """

    meta: Any = {
        'abstract': True,
    }

    tasks: List[str] = None
    """ the ordered list of tasks that comprise a processing run """

    current_task = StringField(default=None)
    status = StringField(default='CREATED')

    errors = ListField(StringField())
    warnings = ListField(StringField())

    create_time = DateTimeField(required=True)
    complete_time = DateTimeField()

    _async_status = StringField(default='UNCALLED')

    @property
    def completed(self) -> bool:
        """ Returns True if the process has failed or succeeded. """
        return self.status in [SUCCESS, FAILURE]

    def get_logger(self, **kwargs):
        """
        Returns a logger bound to this proc's context. Any given keyword
        arguments are bound as additional context.
        """
        # BUGFIX: fail/warning call this with **kwargs; without the **kwargs
        # parameter any extra context raised a TypeError.
        return utils.get_logger(
            'nomad.processing', current_task=self.current_task, process=self.__class__.__name__,
            status=self.status, **kwargs)

    @classmethod
    def create(cls, **kwargs):
        """ Factory method that must be used instead of regular constructor. """
        assert cls.tasks is not None and len(cls.tasks) > 0, \
            """ the class attribute tasks must be overwritten with an actual list """
        assert 'status' not in kwargs, \
            """ do not set the status manually, its managed """

        kwargs.setdefault('create_time', datetime.now())
        self = cls(**kwargs)
        # a freshly created proc is PENDING unless it was created mid-run
        self.status = PENDING if self.current_task is None else RUNNING
        self.save()

        return self

    @classmethod
    def get_by_id(cls, id: str, id_field: str):
        """
        Returns the one document whose ``id_field`` equals ``id``.

        Raises:
            InvalidId: if ``id`` is not a valid value for ``id_field``
            KeyError: if no document with that id exists
        """
        try:
            obj = cls.objects(**{id_field: id}).first()
        except ValidationError as e:
            raise InvalidId('%s is not a valid id' % id) from e
        except MongoEngineConnectionError as e:
            # connection problems must propagate and not be mistaken for bad ids
            raise e

        if obj is None:
            raise KeyError('%s with id %s does not exist' % (cls.__name__, id))

        return obj

    @classmethod
    def get(cls, obj_id):
        """ Shortcut for :meth:`get_by_id` on the primary ``id`` field. """
        return cls.get_by_id(str(obj_id), 'id')

    @staticmethod
    def log(logger, log_level, msg, **kwargs):
        """ Logs ``msg`` with the given stdlib ``log_level`` on a structlog-style logger. """
        # TODO there seems to be a bug in structlog, cannot use logger.log
        if log_level == logging.ERROR:
            logger.error(msg, **kwargs)
        elif log_level == logging.WARNING:
            logger.warning(msg, **kwargs)
        elif log_level == logging.INFO:
            logger.info(msg, **kwargs)
        elif log_level == logging.DEBUG:
            logger.debug(msg, **kwargs)
        else:
            logger.critical(msg, **kwargs)

    def fail(self, *errors, log_level=logging.ERROR, **kwargs):
        """ Allows to fail the process. Takes strings or exceptions as args. """
        assert not self.completed, 'Cannot fail a completed process.'

        failed_with_exception = False

        self.status = FAILURE

        logger = self.get_logger(**kwargs)
        for error in errors:
            if isinstance(error, Exception):
                failed_with_exception = True
                Proc.log(logger, log_level, 'task failed with exception', exc_info=error, **kwargs)

        self.errors = [str(error) for error in errors]
        self.complete_time = datetime.now()

        if not failed_with_exception:
            errors_str = "; ".join([str(error) for error in errors])
            Proc.log(logger, log_level, 'task failed', errors=errors_str, **kwargs)

        logger.info('process failed')

        self.save()

    def warning(self, *warnings, log_level=logging.WARNING, **kwargs):
        """ Allows to save warnings. Takes strings or exceptions as args. """
        # BUGFIX: default was ``logging.warning`` (the function); Proc.log
        # compares against the int constants, so warnings were logged as
        # critical. ``logging.WARNING`` is the intended level.
        assert not self.completed

        logger = self.get_logger(**kwargs)

        for warning in warnings:
            warning = str(warning)
            self.warnings.append(warning)
            Proc.log(logger, log_level, 'task with warning', warning=warning)

    def _continue_with(self, task):
        """ Validates task order and moves status/current_task along; returns False on FAILURE. """
        tasks = self.__class__.tasks
        assert task in tasks, 'task %s must be one of the classes tasks %s' % (task, str(tasks))  # pylint: disable=E1135
        if self.current_task is None:
            assert task == tasks[0], "process has to start with first task"  # pylint: disable=E1136
        else:
            assert tasks.index(task) == tasks.index(self.current_task) + 1, \
                "tasks must be processed in the right order"

        if self.status == FAILURE:
            return False

        if self.status == PENDING:
            assert self.current_task is None
            assert task == tasks[0]  # pylint: disable=E1136
            self.status = RUNNING
            self.current_task = task
            self.get_logger().info('started process')
        else:
            self.current_task = task
            self.get_logger().info('successfully completed task')

        self.save()
        return True

    def _complete(self):
        """ Marks a RUNNING process as SUCCESS; does nothing on FAILURE. """
        if self.status != FAILURE:
            assert self.status == RUNNING, 'Can only complete a running process.'
            self.status = SUCCESS
            self.save()
            self.get_logger().info('completed process')

    def block_until_complete(self, interval=0.01):
        """
        Reloads the process constantly until it sees a completed process. Should be
        used with care as it can block indefinitely. Just intended for testing purposes.
        """
        while not self.completed:
            time.sleep(interval)
            self.reload()

    @property
    def json_dict(self) -> dict:
        """ A json serializable dictionary representation. """
        data = {
            'tasks': getattr(self.__class__, 'tasks'),
            'current_task': self.current_task,
            'status': self.status,
            'completed': self.completed,
            'errors': self.errors,
            'warnings': self.warnings,
            'create_time': self.create_time.isoformat() if self.create_time is not None else None,
            'complete_time': self.complete_time.isoformat() if self.complete_time is not None else None,
            '_async_status': self._async_status
        }
        # omit unset (None) values from the rendered dict
        return {key: value for key, value in data.items() if value is not None}


class InvalidChordUsage(Exception): pass


class Chord(Proc):
    """
    A special Proc base class that manages a chord of child processes. It saves some
    additional state to track child processes and provides methods to control that
    state.

    It uses a counter approach with atomic updates to track the number of processed
    children.

    TODO the joined attribute is not strictly necessary and only serves debugging purposes.
    Maybe it should be removed, since it also requires another save.

    TODO it is vital that sub classes and children don't miss any calls. This might
    not be practical, because in reality processes might even fail to fail.

    Attributes:
        total_children (int): the number of spawned children, -1 denotes that number was not
            saved yet
        completed_children (int): the number of completed child procs
        joined (bool): true if all children are completed and the join method was already called
    """
    total_children = IntField(default=-1)
    completed_children = IntField(default=0)
    joined = BooleanField(default=False)

    meta = {
        'abstract': True
    }

    def spwaned_childred(self, total_children=1):
        """
        Subclasses must call this method after all children have been spawned.

        Arguments:
            total_children (int): the number of spawned children
        """
        # NOTE: the method name carries a historical typo; renaming it would
        # break existing subclasses, so it is kept as-is.
        self.total_children = total_children
        self.modify(total_children=self.total_children)
        # children might already have completed before this was called
        self._check_join(children=0)

    def completed_child(self):
        """ Children must call this, when they completed processing. """
        self._check_join(children=1)

    def _check_join(self, children):
        """ Atomically counts completed children and joins the chord when all are done. """
        # incr the counter and get reference values atomically
        completed_children, others = self.incr_counter(
            'completed_children', children, ['total_children', 'joined'])
        total_children, joined = others

        self.get_logger().debug(
            'Check for join', total_children=total_children,
            completed_children=completed_children, joined=joined)

        # check the join condition and raise errors if chord is in bad state
        if completed_children == total_children:
            if not joined:
                self.join()
                self.joined = True
                self.modify(joined=self.joined)
                self.get_logger().debug('Chord is joined')
            else:
                raise InvalidChordUsage('Chord cannot be joined twice.')
        elif completed_children > total_children and total_children != -1:
            raise InvalidChordUsage('Chord counter is out of limits.')

    def join(self):
        """ Subclasses might overwrite to do something after all children have completed. """
        pass

    def incr_counter(self, field, value=1, other_fields=None):
        """
        Atomically increases the given field by value and return the new value.
        Optionally return also other values from the updated object to avoid
        reloads.

        Arguments:
            field: the name of the field to increment, must be a :class:`IntField`
            value: the value to increment the field by, default is 1
            other_fields: an optional list of field names that should also be returned

        Returns:
            either the value of the updated field, or a tuple with this value and a list
            of value for the given other fields
        """
        # use a primitive but atomic pymongo call
        updated_raw = self._get_collection().find_one_and_update(
            {'_id': self.id},
            {'$inc': {field: value}},
            return_document=ReturnDocument.AFTER)

        if updated_raw is None:
            raise KeyError('object does not exist, was probably not yet written to db')

        if other_fields is None:
            return updated_raw[field]
        else:
            # BUGFIX: the comprehension variable used to shadow the ``field``
            # parameter; use a distinct name for clarity
            return updated_raw[field], [updated_raw[other] for other in other_fields]

374

375
def task(func):
    """
    The decorator for tasks that will be wrapped in exception handling that will fail the process.
    The task methods of a :class:`Proc` class/document comprise a sequence
    (order of methods in class namespace) of tasks. Tasks must be executed in that order.
    Completion of the last task, will put the :class:`Proc` instance into the
    SUCCESS state. Calling the first task will put it into RUNNING state. Tasks will
    only be executed, if the process has not yet reached FAILURE state.
    """
    def wrapper(self, *args, **kwargs):
        # once failed, all remaining tasks are skipped
        # (use the module constant FAILURE, not a duplicated string literal)
        if self.status == FAILURE:
            return

        self._continue_with(func.__name__)
        try:
            func(self, *args, **kwargs)
        except Exception as e:
            # any exception in a task fails the whole process
            self.fail(e)

        # completing the last task successfully completes the whole process
        if self.__class__.tasks[-1] == self.current_task and not self.completed:
            self._complete()

    # mark the wrapper so ProcMetaclass can collect it as a task
    setattr(wrapper, '__task_name', func.__name__)
    wrapper.__name__ = func.__name__
    return wrapper
400
401


402
403
404
405
406
def all_subclasses(cls):
    """ Helper method to calculate set of all subclasses of a given class. """
    result = set()
    for direct_subclass in cls.__subclasses__():
        result.add(direct_subclass)
        # recurse to pick up indirect subclasses as well
        result |= all_subclasses(direct_subclass)
    return result

Markus Scheidgen's avatar
Markus Scheidgen committed
407

408
409
410
411
# Cache mapping class name -> Proc subclass; proc_task uses it to resolve the
# document class from the name sent through celery and refreshes it lazily
# when a name is not found.
all_proc_cls = {cls.__name__: cls for cls in all_subclasses(Proc)}
""" Name dictionary for all Proc classes. """


412
@app.task(bind=True, ignore_results=True, max_retries=3)
def proc_task(task, cls_name, self_id, func_attr):
    """
    The celery task that is used to execute async process functions.
    It ignores results, since all results are handled via the self document.
    It retries for 3 times with a countdown of 3 on missing 'selfs', since this
    might happen in sharded, distributed mongo setups where the object might not
    have yet been propagated and therefore appear missing.
    """
    # BUGFIX: was utils.get_logger('__name__', ...) -- the quoted literal
    # created a logger literally named "__name__" instead of this module's
    logger = utils.get_logger(__name__, cls=cls_name, id=self_id, func=func_attr)

    # get the process class
    logger.debug('received process function call')
    global all_proc_cls
    cls = all_proc_cls.get(cls_name, None)
    if cls is None:
        # refind all Proc classes, since more modules might have been imported by now
        all_proc_cls = {cls.__name__: cls for cls in all_subclasses(Proc)}
        cls = all_proc_cls.get(cls_name, None)

    if cls is None:
        logger.error('document not a subclass of Proc')
        raise ProcNotRegistered('document %s not a subclass of Proc' % cls_name)

    # get the process instance
    try:
        self = cls.get(self_id)
    except KeyError as e:
        # likely not yet propagated in a sharded setup; retry a few times
        logger.warning('called object is missing')
        raise task.retry(exc=e, countdown=3)

    # get the process function
    func = getattr(self, func_attr, None)
    if func is None:
        logger.error('called function not a function of proc class')
        self.fail('called function %s is not a function of proc class %s' % (func_attr, cls_name))
        return

    # unwrap the process decorator
    func = getattr(func, '__process_unwrapped', None)
    if func is None:
        logger.error('called function was not decorated with @process')
        self.fail('called function %s was not decorated with @process' % func_attr)
        return

    # call the process function
    try:
        self._async_status = 'RECEIVED-%s' % func.__name__
        func(self)
    except Exception as e:
        # any uncaught exception fails the process document
        self.fail(e)


465
def process(func):
    """
    The decorator for process functions that will be called async via celery.
    All calls to the decorated method will result in celery task requests.
    To transfer state, the instance will be saved to the database and loading on
    the celery task worker. Process methods can call other (process) functions/methods.
    """
    def wrapper(self, *args, **kwargs):
        # state is transferred via the document, not via call arguments
        assert len(args) == 0 and len(kwargs) == 0, 'process functions must not have arguments'
        self._async_status = 'CALLED-%s' % func.__name__
        # persist before dispatch so the worker loads the current state
        self.save()

        self_id = str(self.id)  # idiomatic str() instead of .__str__()
        cls_name = self.__class__.__name__

        logger = utils.get_logger(__name__, cls=cls_name, id=self_id, func=func.__name__)
        logger.debug('calling process function')
        return proc_task.s(cls_name, self_id, func.__name__).delay()

    # if the function is also a @task, propagate the task marker to the wrapper
    task = getattr(func, '__task_name', None)
    if task is not None:
        setattr(wrapper, '__task_name', task)
    wrapper.__name__ = func.__name__
    # keep the undecorated function reachable for proc_task to call directly
    setattr(wrapper, '__process_unwrapped', func)

    return wrapper