# Copyright 2018 Markus Scheidgen
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import List, Any
import logging
import time

from celery import Celery, Task
from celery.worker.request import Request
from celery.signals import after_setup_task_logger, after_setup_logger, worker_process_init
from celery.utils import worker_direct
from celery.exceptions import SoftTimeLimitExceeded
from billiard.exceptions import WorkerLostError

from mongoengine import Document, StringField, ListField, DateTimeField, ValidationError
from mongoengine.connection import MongoEngineConnectionError
from mongoengine.base.metaclasses import TopLevelDocumentMetaclass
from datetime import datetime

from nomad import config, utils, infrastructure
import nomad.patch  # pylint: disable=unused-import


if config.logstash.enabled:
    utils.configure_logging()

    def initialize_logstash(logger=None, loglevel=logging.DEBUG, **kwargs):
        utils.add_logstash_handler(logger)
        return logger

    after_setup_task_logger.connect(initialize_logstash)
    after_setup_logger.connect(initialize_logstash)


@worker_process_init.connect
def setup(**kwargs):
    infrastructure.setup()
    utils.get_logger(__name__).info(
        'celery configured with acks_late=%s' % str(config.celery.acks_late))


app = Celery('nomad.processing', broker=config.rabbitmq_url())
app.conf.update(worker_hijack_root_logger=False)
app.conf.update(worker_max_memory_per_child=config.celery.max_memory)
if config.celery.routing == config.CELERY_WORKER_ROUTING:
    app.conf.update(worker_direct=True)

app.conf.task_queues = config.celery.task_queues

CREATED = 'CREATED'
PENDING = 'PENDING'
RUNNING = 'RUNNING'
FAILURE = 'FAILURE'
SUCCESS = 'SUCCESS'

PROCESS_CALLED = 'CALLED'
PROCESS_RUNNING = 'RUNNING'
PROCESS_COMPLETED = 'COMPLETED'


class InvalidId(Exception): pass


class ProcNotRegistered(Exception): pass


class ProcessAlreadyRunning(Exception): pass


class ProcObjectDoesNotExist(Exception): pass


class ProcMetaclass(TopLevelDocumentMetaclass):
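    """
    Metaclass that collects the names of all @task-decorated methods of the created
    class (in definition order) into its ``tasks`` class attribute.
    """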
    def __new__(cls, name, bases, attrs):
        cls = super().__new__(cls, name, bases, attrs)

        tasks = []
        setattr(cls, 'tasks', tasks)

        for name, attr in attrs.items():
            task = getattr(attr, '__task_name', None)
            if task is not None and task not in tasks:
                tasks.append(task)

        return cls


class Proc(Document, metaclass=ProcMetaclass):
    """
    Base class for objects that are involved in processing and need persistent processing
    state.

    It solves two issues: first, distributed operation (via celery), and second, keeping
    the state of a chain of potentially failing processing tasks. Both are controlled via
    the decorators @process and @task. Subclasses should use these decorators on their
    methods. Parameters are not supported for decorated functions. Use fields on the
    document instead.

    Processing state will be persisted at appropriate times and must not be persisted
    manually. All attributes are stored to mongodb.

    Possible processing states are PENDING, RUNNING, FAILURE, and SUCCESS.

    Attributes:
        current_task: the currently running or last completed task
        tasks_status: the overall status of the processing
        errors: a list of errors that happened during processing. Errors fail a processing
            run
        warnings: a list of warnings that happened during processing. Warnings do not
            fail a processing run
        create_time: the time of creation (not the start of processing)
        complete_time: the time that processing completed (successfully or not)
        current_process: the currently or last run asynchronous process
        process_status: the status of the currently or last run asynchronous process
    """

    meta: Any = {
        'abstract': True,
    }

    tasks: List[str] = None
    """ the ordered list of tasks that comprise a processing run """

    current_task = StringField(default=None)
    tasks_status = StringField(default=CREATED)
    create_time = DateTimeField(required=True)
    complete_time = DateTimeField()

    errors = ListField(StringField())
    warnings = ListField(StringField())

    current_process = StringField(default=None)
    process_status = StringField(default=None)

    worker_hostname = StringField(default=None)
    celery_task_id = StringField(default=None)

    @property
    def tasks_running(self) -> bool:
        """ Returns True if the processing tasks have not yet failed or succeeded. """
        return self.tasks_status not in [SUCCESS, FAILURE]

    @property
    def process_running(self) -> bool:
        """ Returns True if an asynchronous process is currently running. """
        return self.process_status is not None and self.process_status != PROCESS_COMPLETED

    def get_logger(self, **kwargs):
        return utils.get_logger(
            'nomad.processing', task=self.current_task, proc=self.__class__.__name__,
            process=self.current_process, process_status=self.process_status,
            tasks_status=self.tasks_status, **kwargs)

    @classmethod
    def create(cls, **kwargs):
        """ Factory method that must be used instead of the regular constructor. """
        assert cls.tasks is not None and len(cls.tasks) > 0, \
            """ the class attribute tasks must be overwritten with an actual list """
        assert 'tasks_status' not in kwargs, \
            """ do not set the status manually, it is managed """

        kwargs.setdefault('create_time', datetime.now())
        self = cls(**kwargs)
        self.tasks_status = PENDING if self.current_task is None else RUNNING
        self.save()

        return self

    @classmethod
    def get_by_id(cls, id: str, id_field: str):
        try:
            obj = cls.objects(**{id_field: id}).first()
        except ValidationError as e:
            raise InvalidId('%s is not a valid id' % id)
        except MongoEngineConnectionError as e:
            raise e

        if obj is None:
            raise KeyError('%s with id %s does not exist' % (cls.__name__, id))

        return obj

    @classmethod
    def get(cls, obj_id):
        return cls.get_by_id(str(obj_id), 'id')

    @staticmethod
    def log(logger, log_level, msg, **kwargs):
        # TODO there seems to be a bug in structlog, cannot use logger.log
        if log_level == logging.ERROR:
            logger.error(msg, **kwargs)
        elif log_level == logging.WARNING:
            logger.warning(msg, **kwargs)
        elif log_level == logging.INFO:
            logger.info(msg, **kwargs)
        elif log_level == logging.DEBUG:
            logger.debug(msg, **kwargs)
        else:
            logger.critical(msg, **kwargs)

    def fail(self, *errors, log_level=logging.ERROR, **kwargs):
        """ Fails the process. Takes strings or exceptions as args. """
        assert self.process_running or self.tasks_running, 'Cannot fail a completed process.'

        failed_with_exception = False

        self.tasks_status = FAILURE

        logger = self.get_logger(**kwargs)
        for error in errors:
            if isinstance(error, Exception):
                failed_with_exception = True
                Proc.log(
                    logger, log_level, 'task failed with exception',
                    exc_info=error, error=str(error))

        self.errors = [str(error) for error in errors]
        self.complete_time = datetime.now()

        if not failed_with_exception:
            errors_str = "; ".join([str(error) for error in errors])
            Proc.log(logger, log_level, 'task failed', errors=errors_str)

        logger.info('process failed')

        self.save()

    def warning(self, *warnings, log_level=logging.WARNING, **kwargs):
        """ Records warnings. Takes strings or exceptions as args. """
        assert self.process_running or self.tasks_running

        logger = self.get_logger(**kwargs)

        for warning in warnings:
            warning = str(warning)
            self.warnings.append(warning)
            Proc.log(logger, log_level, 'task with warning', warning=warning)

    def _continue_with(self, task):
        tasks = self.__class__.tasks
        assert task in tasks, 'task %s must be one of the class tasks %s' % (task, str(tasks))  # pylint: disable=E1135
        if self.current_task is None:
            assert task == tasks[0], "process has to start with first task"  # pylint: disable=E1136
        elif tasks.index(task) <= tasks.index(self.current_task):
            # task is repeated, probably the celery task of the process was rescheduled
            # due to prior worker failure
            self.current_task = task
            self.get_logger().warning('task is re-run')
            self.save()
            return True
        else:
            assert tasks.index(task) == tasks.index(self.current_task) + 1, \
                "tasks must be processed in the right order"

        if self.tasks_status == FAILURE:
            return False

        if self.tasks_status == PENDING:
            assert self.current_task is None
            assert task == tasks[0]  # pylint: disable=E1136
            self.tasks_status = RUNNING
            self.current_task = task
            self.get_logger().info('started process')
        else:
            self.current_task = task
            self.get_logger().info('successfully completed task')

        self.save()
        return True

    def _complete(self):
        if self.tasks_status != FAILURE:
            assert self.tasks_status == RUNNING, 'Can only complete a running process, process is %s' % self.tasks_status
            self.tasks_status = SUCCESS
            self.complete_time = datetime.now()
            self.on_tasks_complete()
            self.save()
            self.get_logger().info('completed process')

    def on_tasks_complete(self):
        """ Callback that is called when the list of tasks is completed """
        pass

    def on_process_complete(self, process_name):
        """ Callback that is called when the current process completed """
        pass

    def block_until_complete(self, interval=0.01):
        """
        Reloads the process constantly until it sees a completed process. Should be
        used with care as it can block indefinitely. Just intended for testing purposes.
        """
        while self.tasks_running or self.process_running:
            time.sleep(interval)
            self.reload()
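    # Illustrative test usage (``MyProc`` and ``my_process`` are hypothetical):
    #
    #     proc = MyProc.create()
    #     proc.my_process()            # @process method, dispatched via celery
    #     proc.block_until_complete()  # poll until tasks and process have finished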

    def __str__(self):
        return 'proc celery_task_id=%s worker_hostname=%s' % (self.celery_task_id, self.worker_hostname)


def task(func):
    """
    The decorator for tasks. It wraps the decorated method in exception handling that
    fails the whole process if the task raises an exception.

    The task methods of a :class:`Proc` class/document comprise a sequence
    (order of methods in class namespace) of tasks. Tasks must be executed in that order.
    Completion of the last task will put the :class:`Proc` instance into the
    SUCCESS state. Calling the first task will put it into RUNNING state. Tasks will
    only be executed if the process has not yet reached FAILURE state.
    """
    def wrapper(self, *args, **kwargs):
        if self.tasks_status == FAILURE:
            return

        self._continue_with(func.__name__)
        try:
            func(self, *args, **kwargs)
        except Exception as e:
            self.fail(e)

        if self.__class__.tasks[-1] == self.current_task and self.tasks_running:
            self._complete()

    setattr(wrapper, '__task_name', func.__name__)
    wrapper.__name__ = func.__name__
    return wrapper
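# Failure-handling sketch for @task (illustrative; ``MyProc`` and its methods are
# hypothetical, not defined in this module):
#
#     class MyProc(Proc):
#         @task
#         def step1(self):
#             raise RuntimeError('boom')  # caught by the wrapper, which calls self.fail(e)
#
#         @task
#         def step2(self):
#             pass                        # never executed: tasks_status is already FAILURE
#
#     proc = MyProc.create()
#     proc.step1()
#     proc.step2()
#     assert proc.tasks_status == FAILURE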


def all_subclasses(cls):
    """ Helper method to calculate set of all subclasses of a given class. """
    return set(cls.__subclasses__()).union(
        [s for c in cls.__subclasses__() for s in all_subclasses(c)])


all_proc_cls = {cls.__name__: cls for cls in all_subclasses(Proc)}
""" Name dictionary for all Proc classes. """


class NomadCeleryRequest(Request):
    """
    A custom celery request class that allows catching errors in the worker's main
    thread, which cannot be caught on the worker threads themselves.
    """

    def _fail(self, event, **kwargs):
        args = self._payload[0]
        # this might be run in the worker main thread, which does not have a mongo
        # connection by default
        if infrastructure.mongo_client is None:
            infrastructure.setup_mongo()
        if infrastructure.repository_db is None:
            infrastructure.setup_repository_db()
        proc = unwarp_task(self.task, *args)
        proc.fail(event, **kwargs)
        proc.process_status = PROCESS_COMPLETED
        proc.on_process_complete(None)
        proc.save()

    def on_timeout(self, soft, timeout):
        if not soft:
            self._fail('task timeout occurred', timeout=timeout)

        super().on_timeout(soft, timeout)

    def on_failure(self, exc_info, send_failed_event=True, return_ok=False):
        if isinstance(exc_info.exception, WorkerLostError):
            self._fail(
                'task failed due to worker lost: %s' % str(exc_info.exception),
                exc_info=exc_info)

        super().on_failure(
            exc_info,
            send_failed_event=send_failed_event,
            return_ok=return_ok
        )


class NomadCeleryTask(Task):
    Request = NomadCeleryRequest


def unwarp_task(task, cls_name, self_id, *args, **kwargs):
    """
    Retrieves the proc object that the given task is executed on from the database.
    """
    logger = utils.get_logger(__name__, cls=cls_name, id=self_id)

    # get the process class
    global all_proc_cls
    cls = all_proc_cls.get(cls_name, None)
    if cls is None:
        # rebuild the Proc class lookup, since more modules might have been imported by now
        all_proc_cls = {cls.__name__: cls for cls in all_subclasses(Proc)}
        cls = all_proc_cls.get(cls_name, None)

    if cls is None:
        logger.critical('document not a subclass of Proc')
        raise ProcNotRegistered('document %s not a subclass of Proc' % cls_name)

    # get the process instance
    try:
        try:
            self = cls.get(self_id)
        except KeyError as e:
            logger.warning('called object is missing')
            raise task.retry(exc=e, countdown=3)
    except KeyError:
        logger.critical('called object is missing, retries exceeded')
        raise ProcObjectDoesNotExist()

    return self


@app.task(
    bind=True, base=NomadCeleryTask, ignore_results=True, max_retries=3,
    acks_late=config.celery.acks_late, soft_time_limit=config.celery.timeout,
    time_limit=config.celery.timeout + 120)
def proc_task(task, cls_name, self_id, func_attr):
    """
    The celery task that is used to execute async process functions.
    It ignores results, since all results are handled via the self document.
    It retries 3 times with a countdown of 3 seconds on missing 'self' documents, since this
    might happen in sharded, distributed mongo setups where the object might not
    have yet been propagated and therefore appear missing.
    """
    self = unwarp_task(task, cls_name, self_id)

    logger = self.get_logger()
    logger.debug('received process function call')

    self.worker_hostname = task.request.hostname
    self.celery_task_id = task.request.id

    # get the process function
    func = getattr(self, func_attr, None)
    if func is None:
        logger.error('called function not a function of proc class')
        self.fail('called function %s is not a function of proc class %s' % (func_attr, cls_name))
        self.process_status = PROCESS_COMPLETED
        self.save()
        return

    # unwrap the process decorator
    func = getattr(func, '__process_unwrapped', None)
    if func is None:
        logger.error('called function was not decorated with @process')
        self.fail('called function %s was not decorated with @process' % func_attr)
        self.process_status = PROCESS_COMPLETED
        self.on_process_complete(None)
        self.save()
        return

    # call the process function
    deleted = False
    try:
        self.process_status = PROCESS_RUNNING
        deleted = func(self)
    except SoftTimeLimitExceeded as e:
        logger.error('exceeded the celery task soft time limit')
        self.fail(e)
    except Exception as e:
        self.fail(e)
    finally:
        if deleted is None or not deleted:
            self.process_status = PROCESS_COMPLETED
            self.on_process_complete(func.__name__)
            self.save()


def process(func):
    """
    The decorator for process functions that will be called async via celery.
    All calls to the decorated method will result in celery task requests.
    To transfer state, the instance will be saved to the database and loaded on
    the celery task worker. Process methods can call other (process) functions/methods on
    other :class:`Proc` instances. Each :class:`Proc` instance can only run one
    process at a time.
    """
    def wrapper(self, *args, **kwargs):
        assert len(args) == 0 and len(kwargs) == 0, 'process functions must not have arguments'
        if self.process_running:
            raise ProcessAlreadyRunning

        self.current_process = func.__name__
        self.process_status = PROCESS_CALLED
        self.save()

        self_id = str(self.id)
        cls_name = self.__class__.__name__

        queue = getattr(self.__class__, 'queue', None)
        if config.celery.routing == config.CELERY_WORKER_ROUTING and self.worker_hostname is not None:
            queue = 'celery@%s' % worker_direct(self.worker_hostname).name

        logger = utils.get_logger(__name__, cls=cls_name, id=self_id, func=func.__name__)
        logger.debug('calling process function', queue=queue)

        return proc_task.apply_async(args=[cls_name, self_id, func.__name__], queue=queue)

    task = getattr(func, '__task_name', None)
    if task is not None:
        setattr(wrapper, '__task_name', task)
    wrapper.__name__ = func.__name__
    setattr(wrapper, '__process_unwrapped', func)

    return wrapper
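# Usage sketch for @process (illustrative; ``MyProc`` and its methods are hypothetical):
#
#     class MyProc(Proc):
#         @process
#         def process_all(self):
#             self.parse()      # @task methods are called from within the process
#             self.normalize()
#
#     proc = MyProc.create()
#     proc.process_all()  # returns immediately; the work runs on a celery worker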