diff --git a/nomad/config.py b/nomad/config.py
index 03aef20c9946341b0e0ae42cd214fdc8aade3f95..c9dcfe1186d507d06c1af0b2dbf4cc3d31ee5831 100644
--- a/nomad/config.py
+++ b/nomad/config.py
@@ -30,7 +30,7 @@ FilesConfig = namedtuple(
     'FilesConfig', ['uploads_bucket', 'raw_bucket', 'archive_bucket', 'staging_bucket', 'public_bucket'])
 """ API independent configuration for the object storage. """
 
-CeleryConfig = namedtuple('Celery', ['broker_url'])
+CeleryConfig = namedtuple('Celery', ['broker_url', 'max_memory', 'timeout'])
 """ Used to configure the RabbitMQ for celery. """
 
 FSConfig = namedtuple('FSConfig', ['tmp', 'objects'])
@@ -81,7 +81,9 @@ def get_loglevel_from_env(key, default_level=logging.INFO):
 
 celery = CeleryConfig(
-    broker_url=rabbit_url
+    broker_url=rabbit_url,
+    max_memory=int(os.environ.get('NOMAD_CELERY_MAXMEMORY', 64e6)),  # 64 GB
+    timeout=int(os.environ.get('NOMAD_CELERY_TIMEOUT', 3 * 3600))  # 3h
 )
 
 fs = FSConfig(
diff --git a/nomad/parsing/__init__.py b/nomad/parsing/__init__.py
index cfa4b4716b6f2aa1e9a6e9a4123db33eb8615f75..a06c486eade69ffda18f15a3be45d87080028bed 100644
--- a/nomad/parsing/__init__.py
+++ b/nomad/parsing/__init__.py
@@ -104,6 +104,7 @@ def match_parser(mainfile: str, upload_files: files.StagingUploadFiles) -> 'Pars
     for parser in parsers:
         if parser.is_mainfile(mainfile_path, mime_type, buffer.decode('utf-8'), compression):
+            # TODO: deal with multiple possible parser specs
             return parser
 
     return None
diff --git a/nomad/parsing/artificial.py b/nomad/parsing/artificial.py
index 836debe1af57f1cffcdae0c13b870e8dc6b3314a..2fabcea114c9336350b5d0c35ff567faba302f1b 100644
--- a/nomad/parsing/artificial.py
+++ b/nomad/parsing/artificial.py
@@ -151,8 +151,11 @@ class ChaosParser(ArtificalParser):
             time.sleep(1)
         elif chaos == 'consume_ram':
             data = []
+            i = 0
             while True:
                 data.append('a' * 10**6)
+                i += 1
+                logger.info('ate %d mb' % i)
         elif chaos == 'exception':
             raise Exception('Some chaos happened, muhuha...')
         elif chaos == 'segfault':
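
For orientation, a minimal sketch (not part of this patch) of what the two new celery settings do. Note that worker_max_memory_per_child is measured in kilobytes, which is why the 64e6 default above amounts to 64 GB, while task_time_limit is a hard per-task limit in seconds:

    from celery import Celery

    app = Celery('sketch', broker='pyamqp://localhost//')
    app.conf.update(
        # kB of resident memory; the pool replaces a child after it exceeds this
        worker_max_memory_per_child=int(64e6),
        # seconds; the task is killed after this and the worker child replaced
        task_time_limit=3 * 3600,
    )
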
diff --git a/nomad/processing/base.py b/nomad/processing/base.py
index 3266ad04c7121966c52930672625d601c28271e9..a7f73ef4e3da4b31438adcc33af27d5a105fbbf1 100644
--- a/nomad/processing/base.py
+++ b/nomad/processing/base.py
@@ -15,13 +15,13 @@ from typing import List, Any
 import logging
 import time
-from celery import Celery
+from celery import Celery, Task
+from celery.worker.request import Request
 from celery.signals import after_setup_task_logger, after_setup_logger, worker_process_init
-from mongoengine import Document, StringField, ListField, DateTimeField, IntField, \
-    ValidationError, BooleanField
+from billiard.exceptions import WorkerLostError
+from mongoengine import Document, StringField, ListField, DateTimeField, ValidationError
 from mongoengine.connection import MongoEngineConnectionError
 from mongoengine.base.metaclasses import TopLevelDocumentMetaclass
-from pymongo import ReturnDocument
 from datetime import datetime
 
 from nomad import config, utils, infrastructure
@@ -46,7 +46,8 @@ def setup(**kwargs):
 
 app = Celery('nomad.processing', broker=config.celery.broker_url)
 app.conf.update(worker_hijack_root_logger=False)
-app.conf.update(task_reject_on_worker_lost=True)
+app.conf.update(worker_max_memory_per_child=config.celery.max_memory)
+app.conf.update(task_time_limit=config.celery.timeout)
 
 CREATED = 'CREATED'
 PENDING = 'PENDING'
@@ -68,6 +69,9 @@ class ProcNotRegistered(Exception): pass
 class ProcessAlreadyRunning(Exception): pass
 
 
+class ProcObjectDoesNotExist(Exception): pass
+
+
 class ProcMetaclass(TopLevelDocumentMetaclass):
     def __new__(cls, name, bases, attrs):
         cls = super().__new__(cls, name, bases, attrs)
@@ -130,7 +134,7 @@ class Proc(Document, metaclass=ProcMetaclass):
     current_process = StringField(default=None)
     process_status = StringField(default=None)
 
-    _celery_task_id = StringField(default=None)
+    # _celery_task_id = StringField(default=None)
 
     @property
     def tasks_running(self) -> bool:
@@ -261,9 +265,18 @@ class Proc(Document, metaclass=ProcMetaclass):
         assert self.tasks_status == RUNNING, 'Can only complete a running process, process is %s' % self.tasks_status
         self.tasks_status = SUCCESS
         self.complete_time = datetime.now()
+        self.on_tasks_complete()
         self.save()
         self.get_logger().info('completed process')
 
+    def on_tasks_complete(self):
+        """ Callback that is called when the list of tasks is completed. """
+        pass
+
+    def on_process_complete(self, process_name):
+        """ Callback that is called when the current process has completed. """
+        pass
+
     def block_until_complete(self, interval=0.01):
         """
         Reloads the process constantly until it sees a completed process. Should be
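
The two new no-op hooks replace the Chord machinery that the next hunk removes: instead of mongo-backed counters, completion is reported through overridable callbacks. A hypothetical subclass (illustration only; check_join mirrors Upload.check_join further down) would use them like this:

    from nomad.processing.base import Proc

    class SpawningProc(Proc):
        """ Hypothetical parent proc, not part of this patch. """

        def on_process_complete(self, process_name):
            # process_name is None when the process failed before the function
            # could even be resolved; treat that like a normal completion
            if process_name in (None, 'spawn_children'):
                self.check_join()  # re-evaluate the join condition, as Upload does
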
""" - self._check_join(children=1) - - def _check_join(self, children): - # incr the counter and get reference values atomically - completed_children, others = self.incr_counter( - 'completed_children', children, ['total_children', 'joined']) - total_children, joined = others - - self.get_logger().debug( - 'check for join', total_children=total_children, - completed_children=completed_children, joined=joined) - - # check the join condition and raise errors if chord is in bad state - if completed_children == total_children: - if not joined: - self.join() - self.joined = True - self.modify(joined=self.joined) - self.get_logger().debug('chord is joined') - else: - raise InvalidChordUsage('chord cannot be joined twice.') - elif completed_children > total_children and total_children != -1: - raise InvalidChordUsage('chord counter is out of limits.') - - def join(self): - """ Subclasses might overwrite to do something after all children have completed. """ - pass - - def incr_counter(self, field, value=1, other_fields=None): - """ - Atomically increases the given field by value and return the new value. - Optionally return also other values from the updated object to avoid - reloads. - - Arguments: - field: the name of the field to increment, must be a :class:`IntField` - value: the value to increment the field by, default is 1 - other_fields: an optional list of field names that should also be returned - - Returns: - either the value of the updated field, or a tuple with this value and a list - of value for the given other fields - """ - # use a primitive but atomic pymongo call - updated_raw = self._get_collection().find_one_and_update( - {'_id': self.id}, - {'$inc': {field: value}}, - return_document=ReturnDocument.AFTER) - - if updated_raw is None: - raise KeyError('object does not exist, was probaly not yet written to db') - - if other_fields is None: - return updated_raw[field] - else: - return updated_raw[field], [updated_raw[field] for field in other_fields] - - def task(func): """ The decorator for tasks that will be wrapped in exception handling that will fail the process. @@ -418,19 +324,56 @@ all_proc_cls = {cls.__name__: cls for cls in all_subclasses(Proc)} """ Name dictionary for all Proc classes. """ -@app.task(bind=True, ignore_results=True, max_retries=3, acks_late=True) -def proc_task(task, cls_name, self_id, func_attr): +class NomadCeleryRequest(Request): """ - The celery task that is used to execute async process functions. - It ignores results, since all results are handled via the self document. - It retries for 3 times with a countdown of 3 on missing 'selfs', since this - might happen in sharded, distributed mongo setups where the object might not - have yet been propagated and therefore appear missing. + A custom celery request class that allows to catch error in the worker main + thread, which cannot be caught on the worker threads themselves. 
+ """ + + def _fail(self, event, **kwargs): + args = self._payload[0] + # this might be run in the worker main thread, which does not have a mongo + # connection by default + if infrastructure.mongo_client is None: + infrastructure.setup_mongo() + if infrastructure.repository_db is None: + infrastructure.setup_repository_db() + proc = unwarp_task(self.task, *args) + proc.fail('task timeout occurred', **kwargs) + proc.process_status = PROCESS_COMPLETED + proc.on_process_complete(None) + proc.save() + + def on_timeout(self, soft, timeout): + if not soft: + self._fail('task timeout occurred', timeout=timeout) + + super().on_timeout(soft, timeout) + + def on_failure(self, exc_info, send_failed_event=True, return_ok=False): + if isinstance(exc_info.exception, WorkerLostError): + self._fail( + 'task failed due to worker lost: %s' % str(exc_info.exception), + exc_info=exc_info) + + super().on_failure( + exc_info, + send_failed_event=send_failed_event, + return_ok=return_ok + ) + + +class NomadCeleryTask(Task): + Request = NomadCeleryRequest + + +def unwarp_task(task, cls_name, self_id, *args, **kwargs): + """ + Retrieves the proc object that the given task is executed on from the database. """ - logger = utils.get_logger(__name__, cls=cls_name, id=self_id, func=func_attr) + logger = utils.get_logger(__name__, cls=cls_name, id=self_id) # get the process class - logger.debug('received process function call') global all_proc_cls cls = all_proc_cls.get(cls_name, None) if cls is None: @@ -451,8 +394,24 @@ def proc_task(task, cls_name, self_id, func_attr): raise task.retry(exc=e, countdown=3) except KeyError: logger.critical('called object is missing, retries exeeded') + raise ProcObjectDoesNotExist() + + return self + + +@app.task(bind=True, base=NomadCeleryTask, ignore_results=True, max_retries=3, acks_late=True) +def proc_task(task, cls_name, self_id, func_attr): + """ + The celery task that is used to execute async process functions. + It ignores results, since all results are handled via the self document. + It retries for 3 times with a countdown of 3 on missing 'selfs', since this + might happen in sharded, distributed mongo setups where the object might not + have yet been propagated and therefore appear missing. + """ + self = unwarp_task(task, cls_name, self_id) logger = self.get_logger() + logger.debug('received process function call') # get the process function func = getattr(self, func_attr, None) @@ -469,18 +428,10 @@ def proc_task(task, cls_name, self_id, func_attr): logger.error('called function was not decorated with @process') self.fail('called function %s was not decorated with @process' % func_attr) self.process_status = PROCESS_COMPLETED + self.on_process_complete(None) self.save() return - # check requeued task after catastrophic failure, e.g. 
@@ -469,18 +428,10 @@ def proc_task(task, cls_name, self_id, func_attr):
         logger.error('called function was not decorated with @process')
         self.fail('called function %s was not decorated with @process' % func_attr)
         self.process_status = PROCESS_COMPLETED
+        self.on_process_complete(None)
         self.save()
         return
 
-    # check requeued task after catastrophic failure, e.g. segfault
-    if self._celery_task_id is not None:
-        if self._celery_task_id == task.request.id and task.request.retries == 0:
-            self.fail('task failed catastrophically, probably with sys.exit or segfault')
-
-        if self._celery_task_id != task.request.id:
-            self._celery_task_id = task.request.id
-            self.save()
-
     # call the process function
     deleted = False
     try:
@@ -491,6 +442,7 @@ def proc_task(task, cls_name, self_id, func_attr):
     finally:
         if deleted is None or not deleted:
             self.process_status = PROCESS_COMPLETED
+            self.on_process_complete(func.__name__)
             self.save()
diff --git a/nomad/processing/data.py b/nomad/processing/data.py
index 25c7bdf69193133d27db0a9f5b22d6738e1ff976..f971563f4013bb0043ce6442c29633d92a3345a5 100644
--- a/nomad/processing/data.py
+++ b/nomad/processing/data.py
@@ -32,7 +32,7 @@ from contextlib import contextmanager
 
 from nomad import utils, coe_repo, config, infrastructure, search
 from nomad.files import PathObject, UploadFiles, ExtractError, ArchiveBasedStagingUploadFiles
-from nomad.processing.base import Proc, Chord, process, task, PENDING, SUCCESS, FAILURE
+from nomad.processing.base import Proc, process, task, PENDING, SUCCESS, FAILURE
 from nomad.parsing import parser_dict, match_parser
 from nomad.normalizing import normalizers
 from nomad.datamodel import UploadWithMetadata, CalcWithMetadata
@@ -141,8 +141,13 @@ class Calc(Proc):
         except Exception as e:
             logger.error('could not close calculation proc log', exc_info=e)
 
-        # inform parent proc about completion
-        self.upload.completed_child()
+    def on_process_complete(self, process_name):
+        # the save might be necessary to correctly read the join condition from the db
+        self.save()
+        # in case of error, the process_name might be unknown
+        if process_name == 'process_calc' or process_name is None:
+            self.upload.reload()
+            self.upload.check_join()
 
     @task
     def parsing(self):
@@ -270,7 +275,7 @@ class Calc(Proc):
         return CalcWithMetadata(**_data)
 
 
-class Upload(Chord):
+class Upload(Proc):
     """
     Represents uploads in the databases. Provides persistence access to the files storage,
     and processing state.
@@ -467,11 +472,9 @@ class Upload(Chord):
         """
         logger = self.get_logger()
 
-        # TODO: deal with multiple possible parser specs
         with utils.timer(
                 logger, 'upload extracted', step='matching',
                 upload_size=self.upload_files.size):
-            total_calcs = 0
             for filename, parser in self.match_mainfiles():
                 calc = Calc.create(
                     calc_id=self.upload_files.calc_id(filename),
@@ -479,10 +482,19 @@ class Upload(Chord):
                     upload_id=self.upload_id)
                 calc.process_calc()
-                total_calcs += 1
-
-        # have to save the total_calcs information for chord management
-        self.spwaned_childred(total_calcs)
+
+    def on_process_complete(self, process_name):
+        if process_name == 'process_upload':
+            self.check_join()
+
+    def check_join(self):
+        total_calcs = self.total_calcs
+        processed_calcs = self.processed_calcs
+
+        self.get_logger().debug('check join', processed_calcs=processed_calcs, total_calcs=total_calcs)
+        if not self.process_running and processed_calcs >= total_calcs:
+            self.get_logger().debug('join')
+            self.join()
 
     def join(self):
         self.cleanup()
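
The join condition is now derived from document counts instead of a Chord counter. The total_calcs and processed_calcs values read by check_join are not defined in this diff; presumably they are properties that query along these lines (an assumption, shown only to make the join condition concrete):

    @property
    def total_calcs(self):
        return Calc.objects(upload_id=self.upload_id).count()

    @property
    def processed_calcs(self):
        return Calc.objects(
            upload_id=self.upload_id, tasks_status__in=[SUCCESS, FAILURE]).count()

Since every Calc.on_process_complete re-runs check_join, the last calc to finish performs the join, in whichever order the calcs complete.
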
        image: "{{ .Values.images.nomad.name }}:{{ .Values.images.nomad.tag }}"
+        resources:
+          limits:
+            memory: "{{ .Values.worker.memlimit }}Gi"
+          requests:
+            memory: "{{ .Values.worker.memrequest }}Gi"
         volumeMounts:
         - mountPath: /app/.volumes/fs
           name: files-volume
@@ -74,7 +79,7 @@ spec:
         livenessProbe:
           exec:
             command:
-            - bash
+            - bash
            - -c
            - NOMAD_LOGSTASH_LEVEL=30 python -m celery -A nomad.processing status | grep "${HOSTNAME}:.*OK"
          initialDelaySeconds: 30
@@ -82,7 +87,7 @@ spec:
         readinessProbe:
           exec:
             command:
-            - bash
+            - bash
            - -c
            - NOMAD_LOGSTASH_LEVEL=30 python -m celery -A nomad.processing status | grep "${HOSTNAME}:.*OK"
          initialDelaySeconds: 5
diff --git a/ops/helm/nomad/values.yaml b/ops/helm/nomad/values.yaml
index 689c9f428ed0d4c21754fd722623678e3d47a717..4c5a9ae8f64aea6b901f430fd0513719f2d9b01d 100644
--- a/ops/helm/nomad/values.yaml
+++ b/ops/helm/nomad/values.yaml
@@ -40,6 +40,9 @@ api:
 ## Everthing concerning the nomad worker
 worker:
   replicas: 2
+  # request and limit in GB
+  memrequest: 192
+  memlimit: 320
 
   console_loglevel: INFO
   logstash_loglevel: INFO
diff --git a/tests/conftest.py b/tests/conftest.py
index 314dca817cea1b4da2cc2600fa5b171ddbd7f8a1..13f538ce47536448703495374137961c5ccafc53 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -131,9 +131,9 @@ def celery_inspect(purged_app):
 # It might be necessary to make this a function scoped fixture, if old tasks keep
 # 'bleeding' into successive tests.
 @pytest.fixture(scope='function')
-def worker(celery_session_worker, celery_inspect):
+def worker(mongo, celery_session_worker, celery_inspect):
     """ Provides a clean worker (no old tasks) per function. Waits for all tasks to be completed. """
-    pass
+    yield
 
     # wait until there no more active tasks, to leave clean worker and queues for the next
     # test run.
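
The pass-to-yield change is more than cosmetic: with pass, the wait-for-idle code that follows in the fixture body ran during setup; a bare yield turns it into a generator fixture, so that code now runs at teardown, after the test. A stripped-down sketch (worker_like and wait_for_idle are hypothetical names):

    import pytest

    @pytest.fixture(scope='function')
    def worker_like():
        yield              # the test body runs here
        wait_for_idle()    # runs at teardown, leaving clean queues behind
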
@@ -250,7 +250,7 @@ def postgres_infra(monkeysession):
 
 
 @pytest.fixture(scope='function')
-def proc_infra(postgres, elastic, mongo, worker):
+def proc_infra(worker, postgres, elastic, mongo, raw_files):
     """ Combines all fixtures necessary for processing (postgres, elastic, worker, files, mongo) """
     return dict(
         postgres=postgres,
diff --git a/tests/data/proc/chaos_consume_ram.zip b/tests/data/proc/chaos_consume_ram.zip
new file mode 100644
index 0000000000000000000000000000000000000000..4b6819480c2a772421d741f62ecd0707abeced1d
Binary files /dev/null and b/tests/data/proc/chaos_consume_ram.zip differ
diff --git a/tests/data/proc/chaos_deadlock.zip b/tests/data/proc/chaos_deadlock.zip
new file mode 100644
index 0000000000000000000000000000000000000000..3365d2da75f6d444155853f5d988499a1ee14aaa
Binary files /dev/null and b/tests/data/proc/chaos_deadlock.zip differ
diff --git a/tests/data/proc/chaos_exit.zip b/tests/data/proc/chaos_exit.zip
new file mode 100644
index 0000000000000000000000000000000000000000..4f5ad5a4935ea9df2c85eed6c7b1b0931491229c
Binary files /dev/null and b/tests/data/proc/chaos_exit.zip differ
diff --git a/tests/processing/test_base.py b/tests/processing/test_base.py
index ec524b05475ef7653c2b49383cb071bdde990510..d3f7d1e203a790f9b9c2f17688a6703ecd9be51d 100644
--- a/tests/processing/test_base.py
+++ b/tests/processing/test_base.py
@@ -1,11 +1,8 @@
 import pytest
-from mongoengine import ReferenceField
-import time
 import json
 import random
-import time
 
-from nomad.processing.base import Proc, Chord, process, task, SUCCESS, FAILURE, RUNNING, PENDING
+from nomad.processing.base import Proc, process, task, SUCCESS, FAILURE, RUNNING, PENDING
 
 random.seed(0)
@@ -84,7 +81,7 @@ class SimpleProc(Proc):
         pass
 
 
-def test_simple_process(mongo, worker, no_warn):
+def test_simple_process(worker, mongo, no_warn):
     p = SimpleProc.create()
     p.process()
     p.block_until_complete()
@@ -99,7 +96,7 @@ class TaskInProc(Proc):
 
 
 @pytest.mark.timeout(5)
-def test_task_as_proc(mongo, worker, no_warn):
+def test_task_as_proc(worker, mongo, no_warn):
     p = TaskInProc.create()
     p.process()
     p.block_until_complete()
@@ -118,46 +115,8 @@ class ProcInProc(Proc):
         pass
 
 
-def test_fail_on_proc_in_proc(mongo, worker):
+def test_fail_on_proc_in_proc(worker, mongo):
     p = ProcInProc.create()
     p.one()
     p.block_until_complete()
     assert_proc(p, 'one', FAILURE, 1)
-
-
-class ParentProc(Chord):
-
-    @process
-    @task
-    def spawn_children(self):
-        count = 23
-        for _ in range(0, count):
-            ChildProc.create(parent=self).process()
-
-        self.spwaned_childred(count)
-
-    @task
-    def join(self):
-        pass
-
-
-class ChildProc(Proc):
-    parent = ReferenceField(ParentProc)
-
-    @process
-    @task
-    def process(self):
-        time.sleep(random.uniform(0, 0.1))
-        self.parent.completed_child()
-
-
-@pytest.mark.timeout(10)
-def test_counter(mongo, worker, no_warn):
-    p = ParentProc.create()
-    p.spawn_children()
-    p.block_until_complete()
-
-    p = ParentProc.get(p.id)
-    assert_proc(p, 'join')
-    # TODO there seems to be a bug, that makes this fail from time to time.
-    # assert p.joined
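
A note on the reordered test signatures above: pytest instantiates same-scoped fixtures in the order they are requested, and a fixture's own dependencies always come first. With mongo now a dependency of worker (see conftest.py above), listing worker first presumably guarantees a database exists before the celery worker handles its first task:

    def test_simple_process(worker, mongo, no_warn):  # signature as in this patch
        ...
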
The test worker provided by -the celery pytest plugin is currently not working. It results on a timeout when -reading from the redis result backend, even though all task apperently ended successfully. -""" - from typing import Generator import pytest from datetime import datetime @@ -127,7 +121,7 @@ def test_process_non_existing(proc_infra, test_user, with_error): @pytest.mark.parametrize('task', ['extracting', 'parse_all', 'cleanup', 'parsing']) @pytest.mark.timeout(10) -def test_task_failure(monkeypatch, uploaded, worker, task, proc_infra, test_user, with_error): +def test_task_failure(monkeypatch, uploaded, task, proc_infra, test_user, with_error): # mock the task method to through exceptions if hasattr(Upload, task): cls = Upload @@ -163,3 +157,26 @@ def test_task_failure(monkeypatch, uploaded, worker, task, proc_infra, test_user assert calc.tasks_status == FAILURE assert calc.current_task == 'parsing' assert len(calc.errors) > 0 + +# TODO timeout +# consume_ram, segfault, and exit are not testable with the celery test worker +@pytest.mark.parametrize('failure', ['exception']) +def test_malicious_parser_task_failure(proc_infra, failure, test_user): + example_file = 'tests/data/proc/chaos_%s.zip' % failure + example_upload_id = os.path.basename(example_file).replace('.zip', '') + upload_files = ArchiveBasedStagingUploadFiles(example_upload_id, create=True) + shutil.copyfile(example_file, upload_files.upload_file_os_path) + + upload = run_processing(example_upload_id, test_user) + + assert not upload.tasks_running + assert upload.current_task == 'cleanup' + assert len(upload.errors) == 0 + assert upload.tasks_status == SUCCESS + + calcs = Calc.objects(upload_id=upload.upload_id) + assert calcs.count() == 1 + calc = next(calcs) + assert not calc.tasks_running + assert calc.tasks_status == FAILURE + assert len(calc.errors) == 1