base.py 19.1 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
# Copyright 2018 Markus Scheidgen
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

Markus Scheidgen's avatar
Markus Scheidgen committed
15
from typing import List, Any
16
import logging
17
import time
Markus Scheidgen's avatar
Markus Scheidgen committed
18
from celery import Celery
19
20
from celery.signals import after_setup_task_logger, after_setup_logger, worker_process_init
from mongoengine import Document, StringField, ListField, DateTimeField, IntField, \
21
    ValidationError, BooleanField
22
23
24
25
from mongoengine.connection import MongoEngineConnectionError
from mongoengine.base.metaclasses import TopLevelDocumentMetaclass
from pymongo import ReturnDocument
from datetime import datetime
26

27
from nomad import config, utils, infrastructure
28
29
import nomad.patch  # pylint: disable=unused-import

30

31
# wire up logstash logging only when enabled in the nomad config
if config.logstash.enabled:
    utils.configure_logging()

    def initialize_logstash(logger=None, loglevel=logging.DEBUG, **kwargs):
        # celery signal receiver: attach the logstash handler to every
        # logger that celery sets up (task loggers and the worker logger)
        utils.add_logstash_handler(logger)
        return logger

    after_setup_task_logger.connect(initialize_logstash)
    after_setup_logger.connect(initialize_logstash)

41
42
43
44
45

@worker_process_init.connect
def setup(**kwargs):
    # run the nomad infrastructure setup in every newly started celery
    # worker process (presumably establishes db/search connections — the
    # details live in nomad.infrastructure)
    infrastructure.setup()

46

Markus Scheidgen's avatar
Markus Scheidgen committed
47
# the celery app used to dispatch all async process functions (see proc_task)
app = Celery('nomad.processing', broker=config.celery.broker_url)
# keep our own (logstash) root logger configuration instead of celery's
app.conf.update(worker_hijack_root_logger=False)
# re-queue tasks whose worker was lost mid-execution
app.conf.update(task_reject_on_worker_lost=True)
50

51
CREATED = 'CREATED'
52
53
54
55
56
PENDING = 'PENDING'
RUNNING = 'RUNNING'
FAILURE = 'FAILURE'
SUCCESS = 'SUCCESS'

57
PROCESS_CALLED = 'CALLED'
58
59
60
PROCESS_RUNNING = 'RUNNING'
PROCESS_COMPLETED = 'COMPLETED'

61

62
63
64
class InvalidId(Exception):
    """ Raised when a given id is not a valid document id. """
    pass


Markus Scheidgen's avatar
Markus Scheidgen committed
65
class ProcNotRegistered(Exception):
    """ Raised when a named document class is not a registered Proc subclass. """
    pass
66
67


68
69
70
class ProcessAlreadyRunning(Exception):
    """ Raised when a process function is called while another process is running. """
    pass


71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
class ProcMetaclass(TopLevelDocumentMetaclass):
    """
    Metaclass for :class:`Proc` documents. It collects all methods decorated
    with @task (identified by their ``__task_name`` attribute) into the class
    attribute ``tasks``, preserving their definition order.
    """
    def __new__(mcs, name, bases, attrs):
        created_cls = super().__new__(mcs, name, bases, attrs)

        task_names = []
        setattr(created_cls, 'tasks', task_names)

        for attr_name, attr_value in attrs.items():
            task_name = getattr(attr_value, '__task_name', None)
            if task_name is not None and task_name not in task_names:
                task_names.append(task_name)

        return created_cls


class Proc(Document, metaclass=ProcMetaclass):
    """
    Base class for objects that are involved in processing and need persistent processing
    state.

    It solves two issues. First, distributed operation (via celery) and second keeping
    state of a chain of potentially failing processing tasks. Both are controlled via
    decorators @process and @task. Subclasses should use these decorators on their
    methods. Parameters are not supported for decorated functions. Use fields on the
    document instead.

    Processing state will be persisted at appropriate
    times and must not be persisted manually. All attributes are stored to mongodb.

    Possible processing states are PENDING, RUNNING, FAILURE, and SUCCESS.

    Attributes:
        current_task: the currently running or last completed task
        tasks_status: the overall status of the processing
        errors: a list of errors that happened during processing. Errors fail a processing
            run
        warnings: a list of warnings that happened during processing. Warnings do not
            fail a processing run
        create_time: the time of creation (not the start of processing)
        complete_time: the time that processing completed (successfully or not)
        current_process: the currently or last run asynchronous process
        process_status: the status of the currently or last run asynchronous process
    """

    meta: Any = {
        'abstract': True,
    }

    tasks: List[str] = None
    """ the ordered list of tasks that comprise a processing run """

    current_task = StringField(default=None)
    tasks_status = StringField(default=CREATED)
    create_time = DateTimeField(required=True)
    complete_time = DateTimeField()

    errors = ListField(StringField())
    warnings = ListField(StringField())

    current_process = StringField(default=None)
    process_status = StringField(default=None)

    # celery task id of the last process request; used by proc_task to detect
    # re-queued tasks after catastrophic worker failures
    _celery_task_id = StringField(default=None)

    @property
    def tasks_running(self) -> bool:
        """ Returns True while the task chain has neither failed nor succeeded. """
        return self.tasks_status not in [SUCCESS, FAILURE]

    @property
    def process_running(self) -> bool:
        """ Returns True if an asynchronous process is currently running. """
        return self.process_status is not None and self.process_status != PROCESS_COMPLETED

    def get_logger(self, **kwargs):
        """
        Returns a logger with this document's processing state bound to it.

        Fixed to accept **kwargs: :meth:`fail` and :meth:`warning` call this
        method with additional keyword arguments, which previously raised a
        TypeError. All kwargs are bound to the returned logger.
        """
        return utils.get_logger(
            'nomad.processing', current_task=self.current_task, proc=self.__class__.__name__,
            current_process=self.current_process, process_status=self.process_status,
            tasks_status=self.tasks_status, **kwargs)

    @classmethod
    def create(cls, **kwargs):
        """ Factory method that must be used instead of regular constructor. """
        assert cls.tasks is not None and len(cls.tasks) > 0, \
            """ the class attribute tasks must be overwritten with an actual list """
        assert 'tasks_status' not in kwargs, \
            """ do not set the status manually, its managed """

        kwargs.setdefault('create_time', datetime.now())
        self = cls(**kwargs)
        self.tasks_status = PENDING if self.current_task is None else RUNNING
        self.save()

        return self

    @classmethod
    def get_by_id(cls, id: str, id_field: str):
        """
        Returns the document with the given id, looked up via the given id field.

        Raises:
            InvalidId: if the id is not valid for the id field
            KeyError: if no document with the given id exists
        """
        try:
            obj = cls.objects(**{id_field: id}).first()
        except ValidationError as e:
            # chain the cause so the original validation error is not lost
            raise InvalidId('%s is not a valid id' % id) from e
        except MongoEngineConnectionError:
            # connection problems are not id problems; let them propagate as-is
            raise

        if obj is None:
            raise KeyError('%s with id %s does not exist' % (cls.__name__, id))

        return obj

    @classmethod
    def get(cls, obj_id):
        """ Returns the document with the given primary id. """
        return cls.get_by_id(str(obj_id), 'id')

    @staticmethod
    def log(logger, log_level, msg, **kwargs):
        """ Logs msg on the given logger at the given (stdlib) log level. """
        # TODO there seems to be a bug in structlog, cannot use logger.log
        if log_level == logging.ERROR:
            logger.error(msg, **kwargs)
        elif log_level == logging.WARNING:
            logger.warning(msg, **kwargs)
        elif log_level == logging.INFO:
            logger.info(msg, **kwargs)
        elif log_level == logging.DEBUG:
            logger.debug(msg, **kwargs)
        else:
            logger.critical(msg, **kwargs)

    def fail(self, *errors, log_level=logging.ERROR, **kwargs):
        """
        Allows to fail the process. Takes strings or exceptions as args.
        Additional kwargs are bound to the logger used for the failure log.
        """
        assert self.process_running or self.tasks_running, 'Cannot fail a completed process.'

        failed_with_exception = False

        self.tasks_status = FAILURE

        logger = self.get_logger(**kwargs)
        for error in errors:
            if isinstance(error, Exception):
                failed_with_exception = True
                Proc.log(logger, log_level, 'task failed with exception', exc_info=error)

        self.errors = [str(error) for error in errors]
        self.complete_time = datetime.now()

        if not failed_with_exception:
            errors_str = "; ".join([str(error) for error in errors])
            Proc.log(logger, log_level, 'task failed', errors=errors_str)

        logger.info('process failed')

        self.save()

    def warning(self, *warnings, log_level=logging.WARNING, **kwargs):
        """
        Allows to save warnings. Takes strings or exceptions as args.
        Additional kwargs are bound to the logger used for the warning logs.
        """
        assert self.process_running or self.tasks_running

        logger = self.get_logger(**kwargs)

        for warning in warnings:
            warning = str(warning)
            self.warnings.append(warning)
            Proc.log(logger, log_level, 'task with warning', warning=warning)

    def _continue_with(self, task):
        """
        Advances the task chain to the given task, validating the task order.
        Returns False (and does nothing) if the processing has already failed.
        """
        tasks = self.__class__.tasks
        assert task in tasks, 'task %s must be one of the classes tasks %s' % (task, str(tasks))  # pylint: disable=E1135
        if self.current_task is None:
            assert task == tasks[0], "process has to start with first task"  # pylint: disable=E1136
        else:
            assert tasks.index(task) == tasks.index(self.current_task) + 1, \
                "tasks must be processed in the right order"

        if self.tasks_status == FAILURE:
            return False

        if self.tasks_status == PENDING:
            assert self.current_task is None
            assert task == tasks[0]  # pylint: disable=E1136
            self.tasks_status = RUNNING
            self.current_task = task
            self.get_logger().info('started process')
        else:
            self.current_task = task
            self.get_logger().info('successfully completed task')

        self.save()
        return True

    def _complete(self):
        """ Marks the task chain as successfully completed (unless it failed). """
        if self.tasks_status != FAILURE:
            assert self.tasks_status == RUNNING, 'Can only complete a running process, process is %s' % self.tasks_status
            self.tasks_status = SUCCESS
            self.complete_time = datetime.now()
            self.save()
            self.get_logger().info('completed process')

    def block_until_complete(self, interval=0.01):
        """
        Reloads the process constantly until it sees a completed process. Should be
        used with care as it can block indefinitely. Just intended for testing purposes.
        """
        while self.tasks_running:
            time.sleep(interval)
            self.reload()


class InvalidChordUsage(Exception):
    """ Raised when a chord is joined twice or its child counter runs out of limits. """
    pass


class Chord(Proc):
    """
    A special Proc base class that manages a chord of child processes. It saves some
    additional state to track child processes and provides methods to control that
    state.

    It uses a counter approach with atomic updates to track the number of processed
    children.

    TODO the joined attribute is not strictly necessary and only serves debugging purposes.
    Maybe it should be removed, since it also requires another save.

    TODO it is vital that sub classes and children don't miss any calls. This might
    not be practical, because in reality processes might even fail to fail.

    TODO in the current upload processing, the join functionality is not strictly necessary.
    Nothing is done after join. We only need it to report the upload completed on API
    request. We could check the join condition on each of these API queries.

    Attributes:
        total_children (int): the number of spawned children, -1 denotes that number was not
            saved yet
        completed_children (int): the number of completed child procs
        joined (bool): true if all children are completed and the join method was already called
    """
    total_children = IntField(default=-1)
    completed_children = IntField(default=0)
    joined = BooleanField(default=False)

    meta = {
        'abstract': True
    }

    def spwaned_childred(self, total_children=1):
        """
        Subclasses must call this method after all children have been spawned.

        NOTE: the typos in this method's name are kept on purpose, since renaming
        would break existing callers.

        Arguments:
            total_children (int): the number of spawned children
        """
        self.total_children = total_children
        self.modify(total_children=self.total_children)

        # children may already have completed before the total was known
        self._check_join(children=0)

    def completed_child(self):
        """ Children must call this, when they completed processing. """
        self._check_join(children=1)

    def _check_join(self, children):
        """
        Atomically adds children to the completed counter and joins the chord
        when all children have completed.

        Raises:
            InvalidChordUsage: if joined twice or the counter exceeds the total
        """
        # incr the counter and get reference values atomically
        completed_children, others = self.incr_counter(
            'completed_children', children, ['total_children', 'joined'])
        total_children, joined = others

        self.get_logger().debug(
            'check for join', total_children=total_children,
            completed_children=completed_children, joined=joined)

        # check the join condition and raise errors if chord is in bad state
        if completed_children == total_children:
            if not joined:
                self.join()
                self.joined = True
                self.modify(joined=self.joined)
                self.get_logger().debug('chord is joined')
            else:
                raise InvalidChordUsage('chord cannot be joined twice.')
        elif completed_children > total_children and total_children != -1:
            raise InvalidChordUsage('chord counter is out of limits.')

    def join(self):
        """ Subclasses might overwrite to do something after all children have completed. """
        pass

    def incr_counter(self, field, value=1, other_fields=None):
        """
        Atomically increases the given field by value and return the new value.
        Optionally return also other values from the updated object to avoid
        reloads.

        Arguments:
            field: the name of the field to increment, must be a :class:`IntField`
            value: the value to increment the field by, default is 1
            other_fields: an optional list of field names that should also be returned

        Returns:
            either the value of the updated field, or a tuple with this value and a list
            of values for the given other fields

        Raises:
            KeyError: if the document does not (yet) exist in the database
        """
        # use a primitive but atomic pymongo call
        updated_raw = self._get_collection().find_one_and_update(
            {'_id': self.id},
            {'$inc': {field: value}},
            return_document=ReturnDocument.AFTER)

        if updated_raw is None:
            raise KeyError('object does not exist, was probably not yet written to db')

        if other_fields is None:
            return updated_raw[field]

        # distinct loop variable; the original shadowed the incremented field name
        return updated_raw[field], [updated_raw[name] for name in other_fields]

383

384
def task(func):
    """
    The decorator for tasks that will be wrapped in exception handling that will fail the process.

    The task methods of a :class:`Proc` class/document comprise a sequence
    (order of methods in class namespace) of tasks. Tasks must be executed in that order.
    Completion of the last task, will put the :class:`Proc` instance into the
    SUCCESS state. Calling the first task will put it into RUNNING state. Tasks will
    only be executed, if the process has not yet reached FAILURE state.
    """
    task_name = func.__name__

    def wrapper(self, *args, **kwargs):
        # a failed processing run never executes further tasks
        if self.tasks_status == FAILURE:
            return

        self._continue_with(task_name)
        try:
            func(self, *args, **kwargs)
        except Exception as e:
            # any exception in a task fails the whole processing run
            self.fail(e)

        # finishing the last task (without failure) completes the run
        reached_last_task = self.__class__.tasks[-1] == self.current_task
        if reached_last_task and self.tasks_running:
            self._complete()

    setattr(wrapper, '__task_name', task_name)
    wrapper.__name__ = task_name
    return wrapper
409
410


411
412
413
414
415
def all_subclasses(cls):
    """ Helper method to calculate set of all subclasses of a given class. """
    result = set()
    for subclass in cls.__subclasses__():
        result.add(subclass)
        result |= all_subclasses(subclass)
    return result

Markus Scheidgen's avatar
Markus Scheidgen committed
416

417
418
419
420
# name index of all Proc subclasses known so far; proc_task lazily rebuilds
# it when it receives a class name that is not yet in the index
all_proc_cls = {cls.__name__: cls for cls in all_subclasses(Proc)}
""" Name dictionary for all Proc classes. """


421
@app.task(bind=True, ignore_results=True, max_retries=3, acks_late=True)
def proc_task(task, cls_name, self_id, func_attr):
    """
    The celery task that is used to execute async process functions.
    It ignores results, since all results are handled via the self document.
    It retries for 3 times with a countdown of 3 on missing 'selfs', since this
    might happen in sharded, distributed mongo setups where the object might not
    have yet been propagated and therefore appear missing.

    Arguments:
        task: the bound celery task instance
        cls_name: name of the :class:`Proc` subclass of the called document
        self_id: the id of the document to run the process on
        func_attr: the name of the @process decorated method to call
    """
    logger = utils.get_logger(__name__, cls=cls_name, id=self_id, func=func_attr)

    # get the process class
    logger.debug('received process function call')
    global all_proc_cls
    cls = all_proc_cls.get(cls_name, None)
    if cls is None:
        # refind all Proc classes, since more modules might have been imported by now
        all_proc_cls = {proc_cls.__name__: proc_cls for proc_cls in all_subclasses(Proc)}
        cls = all_proc_cls.get(cls_name, None)

    if cls is None:
        logger.critical('document not a subclass of Proc')
        raise ProcNotRegistered('document %s not a subclass of Proc' % cls_name)

    # get the process instance
    try:
        try:
            self = cls.get(self_id)
        except KeyError as e:
            logger.warning('called object is missing')
            raise task.retry(exc=e, countdown=3)
    except KeyError:
        logger.critical('called object is missing, retries exceeded')
        # bug fix: without this return, `self` is unbound below and the
        # worker would crash with a NameError instead of giving up cleanly
        return

    logger = self.get_logger()

    # get the process function
    func = getattr(self, func_attr, None)
    if func is None:
        logger.error('called function not a function of proc class')
        self.fail('called function %s is not a function of proc class %s' % (func_attr, cls_name))
        self.process_status = PROCESS_COMPLETED
        self.save()
        return

    # unwrap the process decorator
    func = getattr(func, '__process_unwrapped', None)
    if func is None:
        logger.error('called function was not decorated with @process')
        self.fail('called function %s was not decorated with @process' % func_attr)
        self.process_status = PROCESS_COMPLETED
        self.save()
        return

    # check requeued task after catastrophic failure, e.g. segfault
    if self._celery_task_id is not None:
        if self._celery_task_id == task.request.id and task.request.retries == 0:
            self.fail('task failed catastrophically, probably with sys.exit or segfault')

    if self._celery_task_id != task.request.id:
        self._celery_task_id = task.request.id
        self.save()

    # call the process function; a truthy return value means the document was
    # deleted by the process and must not be saved again
    deleted = False
    try:
        self.process_status = PROCESS_RUNNING
        deleted = func(self)
    except Exception as e:
        self.fail(e)
    finally:
        if not deleted:
            self.process_status = PROCESS_COMPLETED
            self.save()
495
496


497
def process(func):
    """
    The decorator for process functions that will be called async via celery.
    All calls to the decorated method will result in celery task requests.
    To transfer state, the instance will be saved to the database and loading on
    the celery task worker. Process methods can call other (process) functions/methods on
    other :class:`Proc` instances. Each :class:`Proc` instance can only run one
    asny process at a time.
    """
    func_name = func.__name__

    def wrapper(self, *args, **kwargs):
        # state is transferred via the document, never via call arguments
        assert len(args) == 0 and len(kwargs) == 0, 'process functions must not have arguments'
        if self.process_running:
            raise ProcessAlreadyRunning

        self.current_process = func_name
        self.process_status = PROCESS_CALLED
        self.save()

        document_id = str(self.id)
        document_cls = type(self).__name__

        logger = utils.get_logger(__name__, cls=document_cls, id=document_id, func=func_name)
        logger.debug('calling process function')
        # dispatch the actual work to a celery worker via proc_task
        return proc_task.s(document_cls, document_id, func_name).delay()

    # preserve a @task decoration of the wrapped function, if any
    inner_task_name = getattr(func, '__task_name', None)
    if inner_task_name is not None:
        setattr(wrapper, '__task_name', inner_task_name)
    wrapper.__name__ = func_name
    # keep the undecorated function reachable for proc_task to unwrap
    setattr(wrapper, '__process_unwrapped', func)

    return wrapper