Commit 7f0d9424 authored by Markus Scheidgen's avatar Markus Scheidgen
Browse files

Refactored the base celery processing. Made everything resilient agains deadlock, segfaults, etc.

parent b8f6f076
Pipeline #43768 passed with stages
in 21 minutes and 39 seconds
......@@ -30,7 +30,7 @@ FilesConfig = namedtuple(
'FilesConfig', ['uploads_bucket', 'raw_bucket', 'archive_bucket', 'staging_bucket', 'public_bucket'])
""" API independent configuration for the object storage. """
CeleryConfig = namedtuple('Celery', ['broker_url'])
CeleryConfig = namedtuple('Celery', ['broker_url', 'max_memory', 'timeout'])
""" Used to configure the RabbitMQ for celery. """
FSConfig = namedtuple('FSConfig', ['tmp', 'objects'])
......@@ -81,7 +81,9 @@ def get_loglevel_from_env(key, default_level=logging.INFO):
celery = CeleryConfig(
broker_url=rabbit_url
broker_url=rabbit_url,
max_memory=int(os.environ.get('NOMAD_CELERY_MAXMEMORY', 64e6)), # 64 GB
timeout=int(os.environ.get('NOMAD_CELERY_TIMEOUT', 3 * 3600)) # 3h
)
fs = FSConfig(
......
......@@ -104,6 +104,7 @@ def match_parser(mainfile: str, upload_files: files.StagingUploadFiles) -> 'Pars
for parser in parsers:
if parser.is_mainfile(mainfile_path, mime_type, buffer.decode('utf-8'), compression):
# TODO: deal with multiple possible parser specs
return parser
return None
......
......@@ -151,8 +151,11 @@ class ChaosParser(ArtificalParser):
time.sleep(1)
elif chaos == 'consume_ram':
data = []
i = 0
while True:
data.append('a' * 10**6)
i += 1
logger.info('ate %d mb' % i)
elif chaos == 'exception':
raise Exception('Some chaos happened, muhuha...')
elif chaos == 'segfault':
......
......@@ -15,13 +15,13 @@
from typing import List, Any
import logging
import time
from celery import Celery
from celery import Celery, Task
from celery.worker.request import Request
from celery.signals import after_setup_task_logger, after_setup_logger, worker_process_init
from mongoengine import Document, StringField, ListField, DateTimeField, IntField, \
ValidationError, BooleanField
from billiard.exceptions import WorkerLostError
from mongoengine import Document, StringField, ListField, DateTimeField, ValidationError
from mongoengine.connection import MongoEngineConnectionError
from mongoengine.base.metaclasses import TopLevelDocumentMetaclass
from pymongo import ReturnDocument
from datetime import datetime
from nomad import config, utils, infrastructure
......@@ -46,7 +46,8 @@ def setup(**kwargs):
app = Celery('nomad.processing', broker=config.celery.broker_url)
app.conf.update(worker_hijack_root_logger=False)
app.conf.update(task_reject_on_worker_lost=True)
app.conf.update(worker_max_memory_per_child=config.celery.max_memory)
app.conf.update(task_time_limit=config.celery.timeout)
CREATED = 'CREATED'
PENDING = 'PENDING'
......@@ -68,6 +69,9 @@ class ProcNotRegistered(Exception): pass
class ProcessAlreadyRunning(Exception): pass
class ProcObjectDoesNotExist(Exception): pass
class ProcMetaclass(TopLevelDocumentMetaclass):
def __new__(cls, name, bases, attrs):
cls = super().__new__(cls, name, bases, attrs)
......@@ -130,7 +134,7 @@ class Proc(Document, metaclass=ProcMetaclass):
current_process = StringField(default=None)
process_status = StringField(default=None)
_celery_task_id = StringField(default=None)
# _celery_task_id = StringField(default=None)
@property
def tasks_running(self) -> bool:
......@@ -261,9 +265,18 @@ class Proc(Document, metaclass=ProcMetaclass):
assert self.tasks_status == RUNNING, 'Can only complete a running process, process is %s' % self.tasks_status
self.tasks_status = SUCCESS
self.complete_time = datetime.now()
self.on_tasks_complete()
self.save()
self.get_logger().info('completed process')
def on_tasks_complete(self):
""" Callback that is called when the list of task are completed """
pass
def on_process_complete(self, process_name):
""" Callback that is called when the corrent process completed """
pass
def block_until_complete(self, interval=0.01):
"""
Reloads the process constantly until it sees a completed process. Should be
......@@ -274,113 +287,6 @@ class Proc(Document, metaclass=ProcMetaclass):
self.reload()
class InvalidChordUsage(Exception): pass
class Chord(Proc):
"""
A special Proc base class that manages a chord of child processes. It saves some
additional state to track child processes and provides methods to control that
state.
It uses a counter approach with atomic updates to track the number of processed
children.
TODO the joined attribute is not strictly necessary and only serves debugging purposes.
Maybe it should be removed, since it also requires another save.
TODO it is vital that sub classes and children don't miss any calls. This might
not be practical, because in reality processes might even fail to fail.
TODO in the current upload processing, the join functionality is not strictly necessary.
Nothing is done after join. We only need it to report the upload completed on API
request. We could check the join condition on each of thise API queries.
Attributes:
total_children (int): the number of spawed children, -1 denotes that number was not
saved yet
completed_children (int): the number of completed child procs
joined (bool): true if all children are completed and the join method was already called
"""
total_children = IntField(default=-1)
completed_children = IntField(default=0)
joined = BooleanField(default=False)
meta = {
'abstract': True
}
def spwaned_childred(self, total_children=1):
"""
Subclasses must call this method after all childred have been spawned.
Arguments:
total_children (int): the number of spawned children
"""
self.total_children = total_children
self.modify(total_children=self.total_children)
self._check_join(children=0)
def completed_child(self):
""" Children must call this, when they completed processing. """
self._check_join(children=1)
def _check_join(self, children):
# incr the counter and get reference values atomically
completed_children, others = self.incr_counter(
'completed_children', children, ['total_children', 'joined'])
total_children, joined = others
self.get_logger().debug(
'check for join', total_children=total_children,
completed_children=completed_children, joined=joined)
# check the join condition and raise errors if chord is in bad state
if completed_children == total_children:
if not joined:
self.join()
self.joined = True
self.modify(joined=self.joined)
self.get_logger().debug('chord is joined')
else:
raise InvalidChordUsage('chord cannot be joined twice.')
elif completed_children > total_children and total_children != -1:
raise InvalidChordUsage('chord counter is out of limits.')
def join(self):
""" Subclasses might overwrite to do something after all children have completed. """
pass
def incr_counter(self, field, value=1, other_fields=None):
"""
Atomically increases the given field by value and return the new value.
Optionally return also other values from the updated object to avoid
reloads.
Arguments:
field: the name of the field to increment, must be a :class:`IntField`
value: the value to increment the field by, default is 1
other_fields: an optional list of field names that should also be returned
Returns:
either the value of the updated field, or a tuple with this value and a list
of value for the given other fields
"""
# use a primitive but atomic pymongo call
updated_raw = self._get_collection().find_one_and_update(
{'_id': self.id},
{'$inc': {field: value}},
return_document=ReturnDocument.AFTER)
if updated_raw is None:
raise KeyError('object does not exist, was probaly not yet written to db')
if other_fields is None:
return updated_raw[field]
else:
return updated_raw[field], [updated_raw[field] for field in other_fields]
def task(func):
"""
The decorator for tasks that will be wrapped in exception handling that will fail the process.
......@@ -418,19 +324,56 @@ all_proc_cls = {cls.__name__: cls for cls in all_subclasses(Proc)}
""" Name dictionary for all Proc classes. """
@app.task(bind=True, ignore_results=True, max_retries=3, acks_late=True)
def proc_task(task, cls_name, self_id, func_attr):
class NomadCeleryRequest(Request):
"""
The celery task that is used to execute async process functions.
It ignores results, since all results are handled via the self document.
It retries for 3 times with a countdown of 3 on missing 'selfs', since this
might happen in sharded, distributed mongo setups where the object might not
have yet been propagated and therefore appear missing.
A custom celery request class that allows to catch error in the worker main
thread, which cannot be caught on the worker threads themselves.
"""
def _fail(self, event, **kwargs):
args = self._payload[0]
# this might be run in the worker main thread, which does not have a mongo
# connection by default
if infrastructure.mongo_client is None:
infrastructure.setup_mongo()
if infrastructure.repository_db is None:
infrastructure.setup_repository_db()
proc = unwarp_task(self.task, *args)
proc.fail('task timeout occurred', **kwargs)
proc.process_status = PROCESS_COMPLETED
proc.on_process_complete(None)
proc.save()
def on_timeout(self, soft, timeout):
if not soft:
self._fail('task timeout occurred', timeout=timeout)
super().on_timeout(soft, timeout)
def on_failure(self, exc_info, send_failed_event=True, return_ok=False):
if isinstance(exc_info.exception, WorkerLostError):
self._fail(
'task failed due to worker lost: %s' % str(exc_info.exception),
exc_info=exc_info)
super().on_failure(
exc_info,
send_failed_event=send_failed_event,
return_ok=return_ok
)
class NomadCeleryTask(Task):
Request = NomadCeleryRequest
def unwarp_task(task, cls_name, self_id, *args, **kwargs):
"""
Retrieves the proc object that the given task is executed on from the database.
"""
logger = utils.get_logger(__name__, cls=cls_name, id=self_id, func=func_attr)
logger = utils.get_logger(__name__, cls=cls_name, id=self_id)
# get the process class
logger.debug('received process function call')
global all_proc_cls
cls = all_proc_cls.get(cls_name, None)
if cls is None:
......@@ -451,8 +394,24 @@ def proc_task(task, cls_name, self_id, func_attr):
raise task.retry(exc=e, countdown=3)
except KeyError:
logger.critical('called object is missing, retries exeeded')
raise ProcObjectDoesNotExist()
return self
@app.task(bind=True, base=NomadCeleryTask, ignore_results=True, max_retries=3, acks_late=True)
def proc_task(task, cls_name, self_id, func_attr):
"""
The celery task that is used to execute async process functions.
It ignores results, since all results are handled via the self document.
It retries for 3 times with a countdown of 3 on missing 'selfs', since this
might happen in sharded, distributed mongo setups where the object might not
have yet been propagated and therefore appear missing.
"""
self = unwarp_task(task, cls_name, self_id)
logger = self.get_logger()
logger.debug('received process function call')
# get the process function
func = getattr(self, func_attr, None)
......@@ -469,18 +428,10 @@ def proc_task(task, cls_name, self_id, func_attr):
logger.error('called function was not decorated with @process')
self.fail('called function %s was not decorated with @process' % func_attr)
self.process_status = PROCESS_COMPLETED
self.on_process_complete(None)
self.save()
return
# check requeued task after catastrophic failure, e.g. segfault
if self._celery_task_id is not None:
if self._celery_task_id == task.request.id and task.request.retries == 0:
self.fail('task failed catastrophically, probably with sys.exit or segfault')
if self._celery_task_id != task.request.id:
self._celery_task_id = task.request.id
self.save()
# call the process function
deleted = False
try:
......@@ -491,6 +442,7 @@ def proc_task(task, cls_name, self_id, func_attr):
finally:
if deleted is None or not deleted:
self.process_status = PROCESS_COMPLETED
self.on_process_complete(func.__name__)
self.save()
......
......@@ -32,7 +32,7 @@ from contextlib import contextmanager
from nomad import utils, coe_repo, config, infrastructure, search
from nomad.files import PathObject, UploadFiles, ExtractError, ArchiveBasedStagingUploadFiles
from nomad.processing.base import Proc, Chord, process, task, PENDING, SUCCESS, FAILURE
from nomad.processing.base import Proc, process, task, PENDING, SUCCESS, FAILURE
from nomad.parsing import parser_dict, match_parser
from nomad.normalizing import normalizers
from nomad.datamodel import UploadWithMetadata, CalcWithMetadata
......@@ -141,8 +141,13 @@ class Calc(Proc):
except Exception as e:
logger.error('could not close calculation proc log', exc_info=e)
# inform parent proc about completion
self.upload.completed_child()
def on_process_complete(self, process_name):
# the save might be necessary to correctly read the join condition from the db
self.save()
# in case of error, the process_name might be unknown
if process_name == 'process_calc' or process_name is None:
self.upload.reload()
self.upload.check_join()
@task
def parsing(self):
......@@ -270,7 +275,7 @@ class Calc(Proc):
return CalcWithMetadata(**_data)
class Upload(Chord):
class Upload(Proc):
"""
Represents uploads in the databases. Provides persistence access to the files storage,
and processing state.
......@@ -467,11 +472,9 @@ class Upload(Chord):
"""
logger = self.get_logger()
# TODO: deal with multiple possible parser specs
with utils.timer(
logger, 'upload extracted', step='matching',
upload_size=self.upload_files.size):
total_calcs = 0
for filename, parser in self.match_mainfiles():
calc = Calc.create(
calc_id=self.upload_files.calc_id(filename),
......@@ -479,10 +482,19 @@ class Upload(Chord):
upload_id=self.upload_id)
calc.process_calc()
total_calcs += 1
# have to save the total_calcs information for chord management
self.spwaned_childred(total_calcs)
def on_process_complete(self, process_name):
if process_name == 'process_upload':
self.check_join()
def check_join(self):
total_calcs = self.total_calcs
processed_calcs = self.processed_calcs
self.get_logger().debug('check join', processed_calcs=processed_calcs, total_calcs=total_calcs)
if not self.process_running and processed_calcs >= total_calcs:
self.get_logger().debug('join')
self.join()
def join(self):
self.cleanup()
......
......@@ -100,7 +100,7 @@ class LogstashHandler(logstash.TCPLogstashHandler):
return True
else:
LogstashHandler.legacy_logger.log(
record.levelno, record.msg, args=record.args,
record.levelno, sanitize_logevent(record.msg), args=record.args,
exc_info=record.exc_info, stack_info=record.stack_info,
legacy_logger=record.name)
......@@ -121,7 +121,7 @@ class LogstashFormatter(logstash.formatter.LogstashFormatterBase):
message = {
'@timestamp': self.format_timestamp(record.created),
'@version': '1',
'event': sanitize_logevent(structlog['event']),
'event': structlog['event'],
'message': structlog['event'],
'host': self.host,
'path': record.pathname,
......
......@@ -22,6 +22,11 @@ spec:
containers:
- name: {{ include "nomad.name" . }}-worker
image: "{{ .Values.images.nomad.name }}:{{ .Values.images.nomad.tag }}"
resources:
limits:
memory: "{{ .Values.worker.memlimit }}Gi"
requests:
memory: "{{ .Values.worker.memrequest }}Gi"
volumeMounts:
- mountPath: /app/.volumes/fs
name: files-volume
......@@ -74,7 +79,7 @@ spec:
livenessProbe:
exec:
command:
- bash
- bash
- -c
- NOMAD_LOGSTASH_LEVEL=30 python -m celery -A nomad.processing status | grep "${HOSTNAME}:.*OK"
initialDelaySeconds: 30
......@@ -82,7 +87,7 @@ spec:
readinessProbe:
exec:
command:
- bash
- bash
- -c
- NOMAD_LOGSTASH_LEVEL=30 python -m celery -A nomad.processing status | grep "${HOSTNAME}:.*OK"
initialDelaySeconds: 5
......
......@@ -40,6 +40,9 @@ api:
## Everthing concerning the nomad worker
worker:
replicas: 2
# request and limit in GB
memrequest: 192
memlimit: 320
console_loglevel: INFO
logstash_loglevel: INFO
......
......@@ -131,9 +131,9 @@ def celery_inspect(purged_app):
# It might be necessary to make this a function scoped fixture, if old tasks keep
# 'bleeding' into successive tests.
@pytest.fixture(scope='function')
def worker(celery_session_worker, celery_inspect):
def worker(mongo, celery_session_worker, celery_inspect):
""" Provides a clean worker (no old tasks) per function. Waits for all tasks to be completed. """
pass
yield
# wait until there no more active tasks, to leave clean worker and queues for the next
# test run.
......@@ -250,7 +250,7 @@ def postgres_infra(monkeysession):
@pytest.fixture(scope='function')
def proc_infra(postgres, elastic, mongo, worker):
def proc_infra(worker, postgres, elastic, mongo, raw_files):
""" Combines all fixtures necessary for processing (postgres, elastic, worker, files, mongo) """
return dict(
postgres=postgres,
......
import pytest
from mongoengine import ReferenceField
import time
import json
import random
import time
from nomad.processing.base import Proc, Chord, process, task, SUCCESS, FAILURE, RUNNING, PENDING
from nomad.processing.base import Proc, process, task, SUCCESS, FAILURE, RUNNING, PENDING
random.seed(0)
......@@ -84,7 +81,7 @@ class SimpleProc(Proc):
pass
def test_simple_process(mongo, worker, no_warn):
def test_simple_process(worker, mongo, no_warn):
p = SimpleProc.create()
p.process()
p.block_until_complete()
......@@ -99,7 +96,7 @@ class TaskInProc(Proc):
@pytest.mark.timeout(5)
def test_task_as_proc(mongo, worker, no_warn):
def test_task_as_proc(worker, mongo, no_warn):
p = TaskInProc.create()
p.process()
p.block_until_complete()
......@@ -118,46 +115,8 @@ class ProcInProc(Proc):
pass
def test_fail_on_proc_in_proc(mongo, worker):
def test_fail_on_proc_in_proc(worker, mongo):
p = ProcInProc.create()
p.one()
p.block_until_complete()
assert_proc(p, 'one', FAILURE, 1)
class ParentProc(Chord):
@process
@task
def spawn_children(self):
count = 23
for _ in range(0, count):
ChildProc.create(parent=self).process()
self.spwaned_childred(count)
@task
def join(self):
pass
class ChildProc(Proc):
parent = ReferenceField(ParentProc)
@process
@task
def process(self):
time.sleep(random.uniform(0, 0.1))
self.parent.completed_child()
@pytest.mark.timeout(10)
def test_counter(mongo, worker, no_warn):
p = ParentProc.create()
p.spawn_children()
p.block_until_complete()
p = ParentProc.get(p.id)
assert_proc(p, 'join')
# TODO there seems to be a bug, that makes this fail from time to time.
# assert p.joined
......@@ -12,12 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""
To run this test, a celery worker must be running. The test worker provided by
the celery pytest plugin is currently not working. It results on a timeout when
reading from the redis result backend, even though all task apperently ended successfully.
"""
from typing import Generator
import pytest
from datetime import datetime
......@@ -127,7 +121,7 @@ def test_process_non_existing(proc_infra, test_user, with_error):
@pytest.mark.parametrize('task', ['extracting', 'parse_all', 'cleanup', 'parsing'])
@pytest.mark.timeout(10)
def test_task_failure(monkeypatch, uploaded, worker, task, proc_infra, test_user, with_error):
def test_task_failure(monkeypatch, uploaded, task, proc_infra, test_user, with_error):
# mock the task method to through exceptions
if hasattr(Upload, task):
cls = Upload
......@@ -163,3 +157,26 @@ def test_task_failure(monkeypatch, uploaded, worker, task, proc_infra, test_user
assert calc.tasks_status == FAILURE
assert calc.current_task == 'parsing'
assert len(calc.errors) > 0
# TODO timeout
# consume_ram, segfault, and exit are not testable with the celery test worker
@pytest.mark.parametrize('failure', ['exception'])
def test_malicious_parser_task_failure(proc_infra, failure, test_user):
example_file = 'tests/data/proc/chaos_%s.zip' % failure
example_upload_id = os.path.basename(example_file).replace('.zip', '')
upload_files = ArchiveBasedStagingUploadFiles(example_upload_id, create=True)
shutil.copyfile(example_file, upload_files.upload_file_os_path)
upload = run_processing(example_upload_id, test_user)
assert not upload.tasks_running
assert upload.current_task == 'cleanup'
assert len(upload.errors) == 0
assert upload.tasks_status == SUCCESS
calcs = Calc.objects(upload_id=upload.upload_id)
assert calcs.count() == 1