Commit d8d18623 authored by Markus Scheidgen

Refactored tests and their fixtures.

parent dd4919e7
@@ -23,7 +23,7 @@ import celery
from celery import Celery, Task
from celery.signals import after_setup_task_logger, after_setup_logger, worker_process_init
from mongoengine import Document, StringField, ListField, DateTimeField, IntField, \
-    ReferenceField, connect, ValidationError
+    ReferenceField, connect, ValidationError, BooleanField, EmbeddedDocument
from mongoengine.connection import MongoEngineConnectionError
from mongoengine.base.metaclasses import TopLevelDocumentMetaclass
from pymongo import ReturnDocument
@@ -252,6 +252,105 @@ class Proc(Document, metaclass=ProcMetaclass):
        self.save()
        self.get_logger().debug('completed process')

+    def block_until_complete(self, interval=0.01):
+        """
+        Reloads the process constantly until it sees a completed process. Should be
+        used with care as it can block indefinitely. Just intended for testing purposes.
+        """
+        while not self.completed:
+            time.sleep(interval)
+            self.reload()
+
+    @property
+    def json_dict(self) -> dict:
+        """ A json serializable dictionary representation. """
+        data = {
+            'tasks': getattr(self.__class__, 'tasks'),
+            'current_task': self.current_task,
+            'status': self.status,
+            'completed': self.completed,
+            'errors': self.errors,
+            'warnings': self.warnings,
+            'create_time': self.create_time.isoformat() if self.create_time is not None else None,
+            'complete_time': self.complete_time.isoformat() if self.complete_time is not None else None,
+            '_async_status': self._async_status
+        }
+        return {key: value for key, value in data.items() if value is not None}
+
+
+class InvalidChordUsage(Exception): pass
+
+
+class Chord(Proc):
+    """
+    A special Proc base class that manages a chord of child processes. It saves some
+    additional state to track child processes and provides methods to control that
+    state.
+
+    It uses a counter approach with atomic updates to track the number of processed
+    children.
+
+    TODO the joined attribute is not strictly necessary and only serves debugging purposes.
+    Maybe it should be removed, since it also requires another save.
+
+    TODO it is vital that subclasses and children don't miss any calls. This might
+    not be practical, because in reality processes might even fail to fail.
+
+    Attributes:
+        total_children (int): the number of spawned children, -1 denotes that the number was not
+            saved yet
+        completed_children (int): the number of completed child procs
+        joined (bool): true if all children are completed and the join method was already called
+    """
+    total_children = IntField(default=-1)
+    completed_children = IntField(default=0)
+    joined = BooleanField(default=False)
+
+    meta = {
+        'abstract': True
+    }
+
+    def spwaned_childred(self, total_children=1):
+        """
+        Subclasses must call this method after all children have been spawned.
+
+        Arguments:
+            total_children (int): the number of spawned children
+        """
+        self.total_children = total_children
+        self.save()
+
+        self._check_join(children=0)
+
+    def completed_child(self):
+        """ Children must call this when they completed processing. """
+        self._check_join(children=1)
+
+    def _check_join(self, children):
+        # increment the counter and get the reference values atomically
+        completed_children, others = self.incr_counter(
+            'completed_children', children, ['total_children', 'joined'])
+        total_children, joined = others
+
+        self.get_logger().debug(
+            'Check for join', total_children=total_children,
+            completed_children=completed_children, joined=joined)
+
+        # check the join condition and raise errors if the chord is in a bad state
+        if completed_children == total_children:
+            if not joined:
+                self.join()
+                self.joined = True
+                self.save()
+                self.get_logger().debug('Chord is joined')
+            else:
+                raise InvalidChordUsage('Chord cannot be joined twice.')
+        elif completed_children > total_children and total_children != -1:
+            raise InvalidChordUsage('Chord counter is out of limits.')
+
+    def join(self):
+        """ Subclasses might overwrite to do something after all children have completed. """
+        pass
    def incr_counter(self, field, value=1, other_fields=None):
        """
        Atomically increases the given field by value and return the new value.
@@ -281,31 +380,6 @@ class Proc(Document, metaclass=ProcMetaclass):
        else:
            return updated_raw[field], [updated_raw[field] for field in other_fields]
-    def block_until_complete(self, interval=0.01):
-        """
-        Reloads the process constrantly until it sees a completed process. Should be
-        used with care as it can block indefinetly. Just intended for testing purposes.
-        """
-        while not self.completed:
-            time.sleep(interval)
-            self.reload()
-
-    @property
-    def json_dict(self) -> dict:
-        """ A json serializable dictionary representation. """
-        data = {
-            'tasks': getattr(self.__class__, 'tasks'),
-            'current_task': self.current_task,
-            'status': self.status,
-            'completed': self.completed,
-            'errors': self.errors,
-            'warnings': self.warnings,
-            'create_time': self.create_time.isoformat() if self.create_time is not None else None,
-            'complete_time': self.complete_time.isoformat() if self.complete_time is not None else None,
-            '_async_status': self._async_status
-        }
-        return {key: value for key, value in data.items() if value is not None}
def task(func):
    """
@@ -334,6 +408,15 @@ def task(func):
    return wrapper
+def all_subclasses(cls):
+    """ Helper method to calculate the set of all subclasses of a given class. """
+    return set(cls.__subclasses__()).union(
+        [s for c in cls.__subclasses__() for s in all_subclasses(c)])
+
+
+all_proc_cls = {cls.__name__: cls for cls in all_subclasses(Proc)}
+""" Name dictionary for all Proc classes. """
@app.task(bind=True, ignore_results=True, max_retries=3)
def proc_task(task, cls_name, self_id, func_attr):
    """
@@ -345,31 +428,41 @@ def proc_task(task, cls_name, self_id, func_attr):
    """
    logger = utils.get_logger(__name__, cls=cls_name, id=self_id, func=func_attr)

+    # get the process class
    logger.debug('received process function call')
-    all_cls = Proc.__subclasses__()
-    cls = next((cls for cls in all_cls if cls.__name__ == cls_name), None)
+    global all_proc_cls
+    cls = all_proc_cls.get(cls_name, None)
+    if cls is None:
+        # refind all Proc classes, since more modules might have been imported by now
+        all_proc_cls = {cls.__name__: cls for cls in all_subclasses(Proc)}
+        cls = all_proc_cls.get(cls_name, None)
+
    if cls is None:
        logger.error('document not a subclass of Proc')
        raise ProcNotRegistered('document %s not a subclass of Proc' % cls_name)

+    # get the process instance
    try:
        self = cls.get(self_id)
    except KeyError as e:
        logger.warning('called object is missing')
        raise task.retry(exc=e, countdown=3)

+    # get the process function
    func = getattr(self, func_attr, None)
    if func is None:
        logger.error('called function not a function of proc class')
        self.fail('called function %s is not a function of proc class %s' % (func_attr, cls_name))
        return

+    # unwrap the process decorator
    func = getattr(func, '__process_unwrapped', None)
    if func is None:
        logger.error('called function was not decorated with @process')
        self.fail('called function %s was not decorated with @process' % func_attr)
        return

+    # call the process function
    try:
        self._async_status = 'RECEIVED-%s' % func.__name__
        func(self)
...
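To make the new machinery concrete: a parent proc spawns its children, reports the total via spwaned_childred, each child reports back via completed_child, and join runs exactly once when the two counters meet. Below is a minimal usage sketch, not part of the commit; the class and method names (ParentWork, ChildWork, do_work) are illustrative, and the reworked ParentProc/ChildProc test classes further down show the real thing.

# Minimal usage sketch of the Chord protocol; names are illustrative only.
from mongoengine import ReferenceField
from nomad.processing.base import Chord, Proc, process, task


class ParentWork(Chord):

    @process
    @task
    def spawn(self):
        count = 10
        for _ in range(count):
            ChildWork.create(parent=self).do_work()
        # report the total so the chord knows when it can join
        self.spwaned_childred(count)

    def join(self):
        # runs exactly once, after the last child reported completion
        pass


class ChildWork(Proc):
    parent = ReferenceField(ParentWork)

    @process
    @task
    def do_work(self):
        try:
            pass  # the actual child processing
        finally:
            # always report back, even on failure; otherwise the chord never joins
            self.parent.completed_child()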
@@ -40,7 +40,7 @@ import logging
from nomad import config, files, utils
from nomad.repo import RepoCalc
from nomad.user import User, me
-from nomad.processing.base import Proc, process, task, PENDING, SUCCESS, FAILURE, RUNNING
+from nomad.processing.base import Proc, Chord, process, task, PENDING, SUCCESS, FAILURE, RUNNING
from nomad.parsing import LocalBackend, parsers, parser_dict
from nomad.normalizing import normalizers
from nomad.utils import get_logger, lnr
@@ -139,7 +139,7 @@ class Calc(Proc):
            self.normalizing()
            self.archiving()
        finally:
-            self._upload.calc_proc_completed()
+            self._upload.completed_child()

    @task
    def parsing(self):
@@ -182,7 +182,7 @@ class Calc(Proc):
            self._parser_backend.write_json(out, pretty=True)

-class Upload(Proc):
+class Upload(Chord):
    """
    Represents uploads in the databases. Provides persistence access to the files storage,
    and processing state.
@@ -377,10 +377,11 @@ class Upload(Proc):
                    'exception while matching pot. mainfile',
                    mainfile=filename, exc_info=e)

-        # have to save the total_calcs information
-        self._initiated_parsers = total_calcs
-        self.save()
-        self.calc_proc_completed()
+        # have to save the total_calcs information for chord management
+        self.spwaned_childred(total_calcs)
+
+    def join(self):
+        self.cleanup()

    @task
    def cleanup(self):
@@ -393,10 +394,6 @@
            upload.close()
            self.get_logger().debug('closed upload')

-    def calc_proc_completed(self):
-        if self._initiated_parsers >= 0 and self.processed_calcs >= self.total_calcs and self.current_task == 'parse_all':
-            self.cleanup()

    @property
    def processed_calcs(self):
        return Calc.objects(upload_id=self.upload_id, status__in=[SUCCESS, FAILURE]).count()
...
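Both Chord._check_join and Upload's new child bookkeeping rest on incr_counter, whose body is largely elided in this diff. For orientation only: judging from the ReturnDocument import and the updated_raw handling shown above, the operation boils down to a single atomic find-and-modify. A self-contained sketch with pymongo follows; the collection handling and field names are assumptions, not the literal implementation.

# Rough sketch of the atomic counter update (assumption, not the committed code).
from pymongo import ReturnDocument
from pymongo.collection import Collection


def incr_completed_children(collection: Collection, proc_id, value=1):
    updated_raw = collection.find_one_and_update(
        {'_id': proc_id},
        {'$inc': {'completed_children': value}},            # atomic increment
        return_document=ReturnDocument.AFTER,               # read the post-update values
        projection=['completed_children', 'total_children', 'joined'])

    # mirrors incr_counter's (value, other_values) return shape
    return updated_raw['completed_children'], [
        updated_raw['total_children'], updated_raw['joined']]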
@@ -13,7 +13,7 @@
# limitations under the License.

from datetime import datetime
-from threading import Thread
+from threading import Thread, Event

from nomad import files, utils
@@ -21,7 +21,7 @@ from nomad.processing.data import Upload
from nomad.utils import get_logger, lnr

-def handle_uploads(quit=False):
+def handle_uploads(ready=None, quit=False):
    """
    Starts a daemon that will listen to files for new uploads. For each new
    upload it will initiate the processing and save the task in the upload user data,
@@ -29,6 +29,7 @@ def handle_uploads(quit=False):
    user data.

    Arguments:
+        ready (Event): optional, will be set when the thread is ready
        quit: If true, will only handle one event and stop. Otherwise run forever.
    """
@@ -61,11 +62,15 @@ def handle_uploads(quit=False):
            raise StopIteration

    utils.get_logger(__name__).debug('Start upload put notification handler.')
+    if ready is not None:
+        ready.set()
+
    handle_upload_put(received_upload_id='provided by decorator')


def handle_uploads_thread(quit=True):
    """ Same as :func:`handle_uploads` but run in a separate thread. """
-    thread = Thread(target=lambda: handle_uploads(quit))
+    ready = Event()
+    thread = Thread(target=lambda: handle_uploads(ready=ready, quit=quit))
    thread.start()
+    ready.wait()
    return thread
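The ready Event threaded through handle_uploads and handle_uploads_thread is the usual handshake for letting the caller block until a worker thread has finished its setup, so tests cannot race ahead of the watcher. A generic, standalone sketch of the pattern (function names here are not the nomad ones):

# Generic ready-Event handshake between a caller and a worker thread.
from threading import Thread, Event


def run(ready: Event):
    # ... perform setup (open watches, connect, etc.) ...
    ready.set()          # signal: initialization done, the caller may proceed
    # ... enter the actual event loop ...


def start_in_thread():
    ready = Event()
    thread = Thread(target=run, args=(ready,))
    thread.start()
    ready.wait()         # block until the thread reported readiness
    return thread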
import pytest
+import logging
from mongoengine import connect
from mongoengine.connection import disconnect
@@ -17,17 +18,17 @@ def celery_config():
    }

-@pytest.fixture(scope='function')
-def purged_queue(celery_app):
+@pytest.fixture(scope='session')
+def purged_app(celery_session_app):
    """
    Purges all pending tasks of the celery app before test. This is necessary to
    remove tasks from the queue that might be 'left over' from prior tests.
    """
-    celery_app.control.purge()
-    yield
+    celery_session_app.control.purge()
+    yield celery_session_app

-@pytest.fixture(scope='function')
+@pytest.fixture()
def patched_celery(monkeypatch):
    # There is a bug in celery, which prevents to use the celery_worker for multiple
    # tests: https://github.com/celery/celery/issues/4088
@@ -45,16 +46,27 @@ def patched_celery(monkeypatch):
    yield

-@pytest.fixture(scope='function')
-def worker(patched_celery, purged_queue, celery_worker):
+@pytest.fixture(scope='session')
+def celery_inspect(purged_app):
+    yield purged_app.control.inspect()
+
+
+@pytest.fixture()
+def worker(patched_celery, celery_inspect, celery_session_worker):
    """
-    Extension of the celery_worker fixture that ensures a clean task queue before yielding.
+    Extension of the celery_session_worker fixture that ensures a clean task queue.
    """
+    # This won't work with the session_worker, it will already have old/unexecuted tasks
+    # taken from the queue and might resubmit them. Therefore, purging the queue won't
+    # help much.
    yield
+
+    # wait until there are no more active tasks, to leave clean workers and queues for
+    # the next test.
+    while True:
+        empty = True
+        for value in celery_inspect.active().values():
+            empty = empty and len(value) == 0
+        if empty:
+            break
@pytest.fixture(scope='function', autouse=True)
def mongomock(monkeypatch):
@@ -84,3 +96,22 @@ def mocksearch(monkeypatch):
    monkeypatch.setattr('nomad.repo.RepoCalc.create_from_backend', create_from_backend)
    monkeypatch.setattr('nomad.repo.RepoCalc.upload_exists', upload_exists)

+@pytest.fixture(scope='function')
+def no_warn(caplog):
+    yield caplog
+
+    for record in caplog.records:
+        if record.levelname in ['WARNING', 'ERROR', 'CRITICAL']:
+            assert False, record.msg
+
+
+@pytest.fixture(scope='function')
+def one_error(caplog):
+    yield caplog
+
+    count = 0
+    for record in caplog.records:
+        if record.levelname in ['ERROR', 'CRITICAL']:
+            count += 1
+            if count > 1:
+                assert False, "Too many errors"
\ No newline at end of file
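One caveat on the worker fixture's teardown loop above: celery's inspect().active() returns a dict of worker name to task list, but None when no worker replies, and the loop as written polls without pausing. A slightly more defensive variant, offered as a suggestion rather than as what the commit does:

# Suggested defensive variant of the task-drain loop (not the committed code).
import time


def wait_for_idle(inspect, poll_interval=0.1):
    while True:
        active = inspect.active()           # dict of worker -> task list, or None
        if active is not None and all(len(tasks) == 0 for tasks in active.values()):
            break
        time.sleep(poll_interval)           # avoid hammering the broker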
import pytest
-from mongoengine import connect, IntField, ReferenceField
+from mongoengine import connect, IntField, ReferenceField, BooleanField, EmbeddedDocumentField
from mongoengine.connection import disconnect
import time
import logging
import json
+import random
+import time

from nomad import config
-from nomad.processing.base import Proc, process, task, SUCCESS, FAILURE, RUNNING, PENDING
+from nomad.processing.base import Proc, Chord, process, task, SUCCESS, FAILURE, RUNNING, PENDING

+random.seed(0)

def assert_proc(proc, current_task, status=SUCCESS, errors=0, warnings=0):
@@ -54,14 +58,13 @@ class FailTasks(Proc):
        self.fail('fail fail fail')

-def test_fail(caplog):
-    caplog.set_level(logging.CRITICAL, logger='nomad.processing.base')
+def test_fail(one_error):
    p = FailTasks.create()
    p.will_fail()

    assert_proc(p, 'will_fail', FAILURE, errors=1)
    has_log = False
-    for record in caplog.records:
+    for record in one_error.records:
        if record.levelname == 'ERROR':
            has_log = True
            assert json.loads(record.msg)['event'] == 'task failed'
@@ -83,7 +86,7 @@ class SimpleProc(Proc):
    pass

-def test_simple_process(worker):
+def test_simple_process(worker, no_warn):
    p = SimpleProc.create()
    p.process()
    p.block_until_complete()
@@ -97,31 +100,29 @@ class TaskInProc(Proc):
    pass

-# @pytest.mark.timeout(5)
-def test_task_as_proc(worker):
+@pytest.mark.timeout(5)
+def test_task_as_proc(worker, no_warn):
    p = TaskInProc.create()
    p.process()
    p.block_until_complete()
    assert_proc(p, 'process')

-class ParentProc(Proc):
-    children = IntField(default=0)
+class ParentProc(Chord):

    @process
    @task
    def spawn_children(self):
-        ChildProc.create(parent=self).process()
+        count = 23
+        for _ in range(0, count):
+            ChildProc.create(parent=self).process()
+        self.spwaned_childred(count)

-    @process
    @task
-    def after_children(self):
+    def join(self):
        pass

-    def on_child_complete(self):
-        if self.incr_counter('children') == 1:
-            self.after_children()

class ChildProc(Proc):
    parent = ReferenceField(ParentProc)
@@ -129,15 +130,15 @@ class ChildProc(Proc):
    @process
    @task
    def process(self):