base.py 18.9 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
# Copyright 2018 Markus Scheidgen
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

Markus Scheidgen's avatar
Markus Scheidgen committed
15
from typing import List, Any
16
import logging
17
import time
18
import os
19
20
from celery import Celery, Task
from celery.worker.request import Request
21
22
from celery.signals import after_setup_task_logger, after_setup_logger, worker_process_init, \
    celeryd_after_setup
23
from celery.utils import worker_direct
24
from celery.exceptions import SoftTimeLimitExceeded
25
26
from billiard.exceptions import WorkerLostError
from mongoengine import Document, StringField, ListField, DateTimeField, ValidationError
27
28
29
from mongoengine.connection import MongoEngineConnectionError
from mongoengine.base.metaclasses import TopLevelDocumentMetaclass
from datetime import datetime
30

31
from nomad import config, utils, infrastructure
32
33
import nomad.patch  # pylint: disable=unused-import

34

35
if config.logstash.enabled:
    # route all logs through the project's structured logging setup
    utils.configure_logging()

    def initialize_logstash(logger=None, loglevel=logging.DEBUG, **kwargs):
        # celery signal receiver: attach the logstash handler to celery's loggers
        utils.add_logstash_handler(logger)
        return logger

    # hook into celery's logger setup so both task and worker loggers ship to logstash
    after_setup_task_logger.connect(initialize_logstash)
    after_setup_logger.connect(initialize_logstash)

45
46
47
48

@worker_process_init.connect
def setup(**kwargs):
    """ Celery signal receiver: initializes infrastructure (db connections, etc.)
    in every freshly forked worker process. """
    infrastructure.setup()
    utils.get_logger(__name__).info(
        'celery configured with acks_late=%s' % str(config.celery.acks_late))
51

52

53
54
55
56
57
58
59
60
61
# hostname of this worker as reported by celery; set once at worker startup and
# later stored on processed documents (see proc_task)
worker_hostname = None


@celeryd_after_setup.connect
def capture_worker_name(sender, instance, **kwargs):
    # celery signal receiver: `sender` is the worker's nodename/hostname string
    global worker_hostname
    worker_hostname = sender


62
# the celery app used for all nomad processing tasks
app = Celery('nomad.processing', broker=config.rabbitmq_url())
# keep our own logging configuration instead of celery's root-logger takeover
app.conf.update(worker_hijack_root_logger=False)
# recycle worker child processes that exceed the configured memory budget
app.conf.update(worker_max_memory_per_child=config.celery.max_memory)
if config.celery.routing == config.CELERY_WORKER_ROUTING:
    # enables per-worker queues so tasks can be pinned to a specific worker
    app.conf.update(worker_direct=True)

app.conf.task_queue_max_priority = 10
69

70
CREATED = 'CREATED'
71
72
73
74
75
PENDING = 'PENDING'
RUNNING = 'RUNNING'
FAILURE = 'FAILURE'
SUCCESS = 'SUCCESS'

76
PROCESS_CALLED = 'CALLED'
77
78
79
PROCESS_RUNNING = 'RUNNING'
PROCESS_COMPLETED = 'COMPLETED'

80

81
82
83
class InvalidId(Exception):
    """ Raised when a given id is not a valid document id. """


class ProcNotRegistered(Exception):
    """ Raised when a class name does not refer to a known Proc subclass. """


class ProcessAlreadyRunning(Exception):
    """ Raised when a process function is called while another process is running. """


class ProcObjectDoesNotExist(Exception):
    """ Raised when the Proc document for a celery task cannot be found. """


93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
class ProcMetaclass(TopLevelDocumentMetaclass):
    """ Metaclass that collects the names of all @task-decorated methods of a
    Proc subclass into its ``tasks`` attribute, in definition order. """

    def __new__(mcs, name, bases, attrs):
        proc_cls = super().__new__(mcs, name, bases, attrs)

        task_names = []
        setattr(proc_cls, 'tasks', task_names)

        # @task stores the method name under '__task_name' on the wrapper
        for attr in attrs.values():
            task_name = getattr(attr, '__task_name', None)
            if task_name is not None and task_name not in task_names:
                task_names.append(task_name)

        return proc_cls


class Proc(Document, metaclass=ProcMetaclass):
    """
    Base class for objects that are involved in processing and need persistent processing
    state.

    It solves two issues. First, distributed operation (via celery) and second keeping
    state of a chain of potentially failing processing tasks. Both are controlled via
    decorators @process and @task. Subclasses should use these decorators on their
    methods. Parameters are not supported for decorated functions. Use fields on the
    document instead.

    Processing state will be persisted at appropriate
    times and must not be persisted manually. All attributes are stored to mongodb.

    Possible processing states are PENDING, RUNNING, FAILURE, and SUCCESS.

    Attributes:
        current_task: the currently running or last completed task
        tasks_status: the overall status of the processing
        errors: a list of errors that happened during processing. Errors fail a processing
            run
        warnings: a list of warnings that happened during processing. Warnings do not
            fail a processing run
        create_time: the time of creation (not the start of processing)
        complete_time: the time that processing completed (successfully or not)
        current_process: the currently or last run asynchronous process
        process_status: the status of the currently or last run asynchronous process
    """

    # mongoengine meta: this class is abstract, only subclasses become collections
    meta: Any = {
        'abstract': True,
    }

    tasks: List[str] = None
    """ the ordered list of tasks that comprise a processing run """

    current_task = StringField(default=None)
    tasks_status = StringField(default=CREATED)
    create_time = DateTimeField(required=True)
    complete_time = DateTimeField()

    errors = ListField(StringField())
    warnings = ListField(StringField())

    current_process = StringField(default=None)
    process_status = StringField(default=None)

    # celery bookkeeping: which worker ran the process and under which celery task id
    worker_hostname = StringField(default=None)
    celery_task_id = StringField(default=None)

    @property
    def tasks_running(self) -> bool:
        """ Returns True if the task chain has not yet failed or succeeded. """
        return self.tasks_status not in [SUCCESS, FAILURE]

    @property
    def process_running(self) -> bool:
        """ Returns True if an asynchronous process is currently running. """
        return self.process_status is not None and self.process_status != PROCESS_COMPLETED
167

Markus Scheidgen's avatar
Markus Scheidgen committed
168
169
    def get_logger(self, **kwargs):
        """
        Returns a logger bound with this instance's processing context.

        Accepts additional keyword arguments that are passed on as extra log
        context. Without ``**kwargs`` the base implementation breaks when
        :func:`fail` or :func:`warning` forward their kwargs via
        ``self.get_logger(**kwargs)``.
        """
        return utils.get_logger(
            'nomad.processing', task=self.current_task, proc=self.__class__.__name__,
            process=self.current_process, process_status=self.process_status,
            tasks_status=self.tasks_status, **kwargs)
Markus Scheidgen's avatar
Markus Scheidgen committed
173

174
175
    @classmethod
    def create(cls, **kwargs):
        """ Factory method that must be used instead of regular constructor. """
        assert 'tasks_status' not in kwargs, \
            """ do not set the status manually, its managed """

        kwargs.setdefault('create_time', datetime.utcnow())
        self = cls(**kwargs)
        # a class without tasks is trivially successful; otherwise the initial
        # status depends on whether a current_task was already supplied
        if len(cls.tasks) == 0:
            self.tasks_status = SUCCESS
        else:
            self.tasks_status = PENDING if self.current_task is None else RUNNING
        self.save()

        return self
189

190
191
    def reset(self):
        """ Resets the task chain. Assumes there is no currently running process. """
        assert not self.process_running

        # back to the pristine pre-processing state; note: not saved here,
        # callers decide when to persist
        self.current_task = None
        self.tasks_status = PENDING
        self.errors = []
        self.warnings = []
        self.worker_hostname = None
199

200
    @classmethod
    def get_by_id(cls, id: str, id_field: str):
        """
        Loads the instance whose ``id_field`` equals ``id``.

        Raises:
            InvalidId: if ``id`` is not valid for the field's type.
            KeyError: if no document with that id exists.
        """
        try:
            obj = cls.objects(**{id_field: id}).first()
        except ValidationError as e:
            # chain the cause so the original validation error stays visible;
            # connection errors (e.g. MongoEngineConnectionError) propagate as-is
            raise InvalidId('%s is not a valid id' % id) from e

        if obj is None:
            raise KeyError('%s with id %s does not exist' % (cls.__name__, id))

        return obj

Markus Scheidgen's avatar
Markus Scheidgen committed
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
    @classmethod
    def get(cls, obj_id):
        """ Returns the instance with the given mongo object id, see :func:`get_by_id`. """
        return cls.get_by_id(str(obj_id), 'id')

    @staticmethod
    def log(logger, log_level, msg, **kwargs):
        """ Logs msg via the logger method matching log_level; unknown levels
        are logged as critical. """
        # TODO there seems to be a bug in structlog, cannot use logger.log
        level_methods = {
            logging.ERROR: logger.error,
            logging.WARNING: logger.warning,
            logging.INFO: logger.info,
            logging.DEBUG: logger.debug,
        }
        level_methods.get(log_level, logger.critical)(msg, **kwargs)

    def fail(self, *errors, log_level=logging.ERROR, **kwargs):
        """ Allows to fail the process. Takes strings or exceptions as args. """
        assert self.process_running or self.tasks_running, 'Cannot fail a completed process.'

        failed_with_exception = False

        self.tasks_status = FAILURE

        # exceptions are logged individually with their tracebacks
        logger = self.get_logger(**kwargs)
        for error in errors:
            if isinstance(error, Exception):
                failed_with_exception = True
                Proc.log(
                    logger, log_level, 'task failed with exception',
                    exc_info=error, error=str(error))

        self.errors = [str(error) for error in errors]
        self.complete_time = datetime.utcnow()

        # plain string errors get one combined log entry instead
        if not failed_with_exception:
            errors_str = "; ".join([str(error) for error in errors])
            Proc.log(logger, log_level, 'task failed', errors=errors_str)

        logger.info('process failed')

        self.save()

Markus Scheidgen's avatar
Markus Scheidgen committed
259
    def warning(self, *warnings, log_level=logging.WARNING, **kwargs):
        """ Allows to save warnings. Takes strings or exceptions as args. """
        assert self.process_running or self.tasks_running

        logger = self.get_logger(**kwargs)

        # warnings are recorded on the document and logged, but do not fail
        # the processing run
        for entry in warnings:
            message = str(entry)
            self.warnings.append(message)
            Proc.log(logger, log_level, 'task with warning', warning=message)
269

270
    def _continue_with(self, task):
        """
        Transitions the task chain to the given task and persists the change.

        Returns True if the task should actually run, False if the chain has
        already failed. Asserts that tasks run in their declared order.
        """
        tasks = self.__class__.tasks
        assert task in tasks, 'task %s must be one of the classes tasks %s' % (task, str(tasks))  # pylint: disable=E1135
        if self.current_task is None:
            assert task == tasks[0], "process has to start with first task %s" % tasks[0]  # pylint: disable=E1136
        elif tasks.index(task) <= tasks.index(self.current_task):
            # task is repeated, probably the celery task of the process was reschedule
            # due to prior worker failure
            self.current_task = task
            self.get_logger().error('task is re-run')
            self.save()
            return True
        else:
            assert tasks.index(task) == tasks.index(self.current_task) + 1, \
                "tasks must be processed in the right order"

        if self.tasks_status == FAILURE:
            # a failed chain does not run further tasks
            return False

        if self.tasks_status == PENDING:
            # very first task of the chain: move PENDING -> RUNNING
            assert self.current_task is None
            assert task == tasks[0]  # pylint: disable=E1136
            self.tasks_status = RUNNING
            self.current_task = task
            self.get_logger().info('started process')
        else:
            self.current_task = task
            self.get_logger().info('successfully completed task')

        self.save()
        return True

302
    def _complete(self):
        """ Transitions a RUNNING task chain to SUCCESS; no-op if it already failed. """
        if self.tasks_status != FAILURE:
            assert self.tasks_status == RUNNING, 'Can only complete a running process, process is %s' % self.tasks_status
            self.tasks_status = SUCCESS
            self.complete_time = datetime.utcnow()
            self.on_tasks_complete()
            self.save()
            self.get_logger().info('completed process')
310

311
312
313
314
315
316
317
318
    def on_tasks_complete(self):
        """ Callback that is called when the list of tasks is completed """
        pass

    def on_process_complete(self, process_name):
        """ Callback that is called when the current process completed """
        pass

319
320
    def block_until_complete(self, interval=0.01):
        """
        Reloads the process constantly until it sees a completed process. Should be
        used with care as it can block indefinitely. Just intended for testing purposes.

        Args:
            interval: seconds to sleep between reloads from mongodb
        """
        while self.tasks_running or self.process_running:
            time.sleep(interval)
            self.reload()

328
    def __str__(self):
        # short human-readable identification for logging/debugging
        return 'proc celery_task_id=%s worker_hostname=%s' % (self.celery_task_id, self.worker_hostname)
330

331

332
def task(func):
    """
    The decorator for tasks that will be wrapped in exception handling that will fail the process.

    The task methods of a :class:`Proc` class/document comprise a sequence
    (order of methods in class namespace) of tasks. Tasks must be executed in that order.
    Completion of the last task, will put the :class:`Proc` instance into the
    SUCCESS state. Calling the first task will put it into RUNNING state. Tasks will
    only be executed, if the process has not yet reached FAILURE state.
    """
    task_name = func.__name__

    def task_wrapper(self, *args, **kwargs):
        # a failed chain runs no further tasks
        if self.tasks_status == FAILURE:
            return

        self._continue_with(task_name)
        try:
            func(self, *args, **kwargs)
        except Exception as e:
            # any exception in the task body fails the whole process
            self.fail(e)

        # completing the last task (without failure) completes the chain
        if self.__class__.tasks[-1] == self.current_task and self.tasks_running:
            self._complete()

    # mark the wrapper so ProcMetaclass can collect it into cls.tasks
    setattr(task_wrapper, '__task_name', task_name)
    task_wrapper.__name__ = task_name
    return task_wrapper
357
358


359
360
361
362
363
def all_subclasses(cls):
    """ Helper method to calculate set of all (transitive) subclasses of a given class. """
    result = set()
    pending = list(cls.__subclasses__())
    while pending:
        subclass = pending.pop()
        if subclass not in result:
            result.add(subclass)
            pending.extend(subclass.__subclasses__())
    return result

Markus Scheidgen's avatar
Markus Scheidgen committed
364

365
366
367
368
# registry of all known Proc subclasses by class name; refreshed lazily in
# unwarp_task when a class is not found (more modules may be imported later)
all_proc_cls = {cls.__name__: cls for cls in all_subclasses(Proc)}
""" Name dictionary for all Proc classes. """


369
class NomadCeleryRequest(Request):
    """
    A custom celery request class that allows to catch error in the worker main
    thread, which cannot be caught on the worker threads themselves.
    """

    def _fail(self, event, **kwargs):
        """ Marks the Proc behind this request as failed and completed. """
        # the task arguments (cls_name, self_id, ...) are in the message payload
        args = self._payload[0]
        # this might be run in the worker main thread, which does not have a mongo
        # connection by default
        if infrastructure.mongo_client is None:
            infrastructure.setup_mongo()
        if infrastructure.repository_db is None:
            infrastructure.setup_repository_db()
        proc = unwarp_task(self.task, *args)
        proc.fail(event, **kwargs)
        proc.process_status = PROCESS_COMPLETED
        proc.on_process_complete(None)
        proc.save()

    def on_timeout(self, soft, timeout):
        # only hard time limits kill the worker process; soft limits raise
        # SoftTimeLimitExceeded inside the task and are handled there
        if not soft:
            self._fail('task timeout occurred', timeout=timeout)

        super().on_timeout(soft, timeout)

    def on_failure(self, exc_info, send_failed_event=True, return_ok=False):
        # a lost worker means the task code never got a chance to record failure
        if isinstance(exc_info.exception, WorkerLostError):
            self._fail(
                'task failed due to worker lost: %s' % str(exc_info.exception),
                exc_info=exc_info)

        super().on_failure(
            exc_info,
            send_failed_event=send_failed_event,
            return_ok=return_ok
        )


class NomadCeleryTask(Task):
    """ Custom celery Task class that uses :class:`NomadCeleryRequest` to handle
    worker-main-thread failures (timeouts, lost workers). """
    Request = NomadCeleryRequest


def unwarp_task(task, cls_name, self_id, *args, **kwargs):
    """
    Retrieves the proc object that the given task is executed on from the database.

    Raises:
        ProcNotRegistered: if ``cls_name`` is not a known Proc subclass.
        ProcObjectDoesNotExist: if the document is still missing after retries.
    """
    logger = utils.get_logger(__name__, cls=cls_name, id=self_id)

    # get the process class
    global all_proc_cls
    cls = all_proc_cls.get(cls_name, None)
    if cls is None:
        # refind all Proc classes, since more modules might have been imported by now
        all_proc_cls = {cls.__name__: cls for cls in all_subclasses(Proc)}
        cls = all_proc_cls.get(cls_name, None)

    if cls is None:
        logger.critical('document not a subclass of Proc')
        raise ProcNotRegistered('document %s not a subclass of Proc' % cls_name)

    # get the process instance; retry on a missing document, since in sharded
    # mongo setups the object may not have propagated yet. Once retries are
    # exhausted, task.retry re-raises the original KeyError, which the outer
    # except turns into ProcObjectDoesNotExist.
    try:
        try:
            self = cls.get(self_id)
        except KeyError as e:
            logger.warning('called object is missing')
            raise task.retry(exc=e, countdown=3)
    except KeyError:
        logger.critical('called object is missing, retries exceeded', proc_id=self_id)
        raise ProcObjectDoesNotExist()

    return self


444
445
@app.task(
    bind=True, base=NomadCeleryTask, ignore_results=True, max_retries=3,
    acks_late=config.celery.acks_late, soft_time_limit=config.celery.timeout,
    time_limit=config.celery.timeout + 120)
def proc_task(task, cls_name, self_id, func_attr):
    """
    The celery task that is used to execute async process functions.
    It ignores results, since all results are handled via the self document.
    It retries for 3 times with a countdown of 3 on missing 'selfs', since this
    might happen in sharded, distributed mongo setups where the object might not
    have yet been propagated and therefore appear missing.
    """
    self = unwarp_task(task, cls_name, self_id)

    logger = self.get_logger()
    logger.debug('received process function call')

    # record where and under which celery task id this process runs
    self.worker_hostname = worker_hostname
    self.celery_task_id = task.request.id

    # get the process function
    func = getattr(self, func_attr, None)
    if func is None:
        logger.error('called function not a function of proc class')
        self.fail('called function %s is not a function of proc class %s' % (func_attr, cls_name))
        self.process_status = PROCESS_COMPLETED
        self.save()
        return

    # unwrap the process decorator
    func = getattr(func, '__process_unwrapped', None)
    if func is None:
        logger.error('called function was not decorated with @process')
        self.fail('called function %s was not decorated with @process' % func_attr)
        self.process_status = PROCESS_COMPLETED
        self.on_process_complete(None)
        self.save()
        return

    # call the process function; `deleted` being truthy means the document was
    # removed by the process and must not be saved again
    deleted = False
    try:
        self.process_status = PROCESS_RUNNING
        os.chdir(config.fs.working_directory)
        deleted = func(self)
    except SoftTimeLimitExceeded as e:
        logger.error('exceeded the celery task soft time limit')
        self.fail(e)
    except Exception as e:
        self.fail(e)
    except SystemExit as e:
        # SystemExit is a BaseException and not caught by the clause above
        self.fail(e)
    finally:
        if deleted is None or not deleted:
            self.process_status = PROCESS_COMPLETED
            self.on_process_complete(func.__name__)
            self.save()
501
502


503
def process(func):
    """
    The decorator for process functions that will be called async via celery.
    All calls to the decorated method will result in celery task requests.
    To transfer state, the instance will be saved to the database and loading on
    the celery task worker. Process methods can call other (process) functions/methods on
    other :class:`Proc` instances. Each :class:`Proc` instance can only run one
    any process at a time.
    """
    func_name = func.__name__

    def wrapper(self, *args, **kwargs):
        assert len(args) == 0 and len(kwargs) == 0, 'process functions must not have arguments'
        if self.process_running:
            raise ProcessAlreadyRunning('Tried to call a processing function on an already processing process.')

        # persist the call so the worker-side proc_task sees it
        self.current_process = func_name
        self.process_status = PROCESS_CALLED
        self.save()

        proc_id = str(self.id)
        proc_cls_name = self.__class__.__name__

        # pin the task to this instance's worker if worker routing is enabled
        target_queue = None
        if config.celery.routing == config.CELERY_WORKER_ROUTING and self.worker_hostname is not None:
            target_queue = worker_direct(self.worker_hostname).name

        task_priority = config.celery.priorities.get('%s.%s' % (proc_cls_name, func_name), 1)

        logger = utils.get_logger(__name__, cls=proc_cls_name, id=proc_id, func=func_name)
        logger.debug('calling process function', queue=target_queue, priority=task_priority)

        return proc_task.apply_async(
            args=[proc_cls_name, proc_id, func_name],
            queue=target_queue, priority=task_priority)

    # propagate a @task marker if the function was also a task
    task_name = getattr(func, '__task_name', None)
    if task_name is not None:
        setattr(wrapper, '__task_name', task_name)
    wrapper.__name__ = func_name
    # proc_task unwraps the decorator through this attribute
    setattr(wrapper, '__process_unwrapped', func)

    return wrapper