Commit d8d18623 authored by Markus Scheidgen's avatar Markus Scheidgen
Browse files

Refactored tests and their fixtures.

parent dd4919e7
......@@ -23,7 +23,7 @@ import celery
from celery import Celery, Task
from celery.signals import after_setup_task_logger, after_setup_logger, worker_process_init
from mongoengine import Document, StringField, ListField, DateTimeField, IntField, \
ReferenceField, connect, ValidationError
ReferenceField, connect, ValidationError, BooleanField, EmbeddedDocument
from mongoengine.connection import MongoEngineConnectionError
from mongoengine.base.metaclasses import TopLevelDocumentMetaclass
from pymongo import ReturnDocument
......@@ -252,6 +252,105 @@ class Proc(Document, metaclass=ProcMetaclass):
self.save()
self.get_logger().debug('completed process')
def block_until_complete(self, interval=0.01):
"""
Reloads the process constantly until it sees a completed process. Should be
used with care as it can block indefinitely. Intended for testing purposes only.
"""
while not self.completed:
time.sleep(interval)
self.reload()
@property
def json_dict(self) -> dict:
""" A json serializable dictionary representation. """
data = {
'tasks': getattr(self.__class__, 'tasks'),
'current_task': self.current_task,
'status': self.status,
'completed': self.completed,
'errors': self.errors,
'warnings': self.warnings,
'create_time': self.create_time.isoformat() if self.create_time is not None else None,
'complete_time': self.complete_time.isoformat() if self.complete_time is not None else None,
'_async_status': self._async_status
}
return {key: value for key, value in data.items() if value is not None}
class InvalidChordUsage(Exception): pass
class Chord(Proc):
"""
A special Proc base class that manages a chord of child processes. It saves some
additional state to track child processes and provides methods to control that
state.
It uses a counter approach with atomic updates to track the number of processed
children.
TODO the joined attribute is not strictly necessary and only serves debugging purposes.
Maybe it should be removed, since it also requires another save.
TODO it is vital that subclasses and children don't miss any calls. This might
not be practical, because in reality processes might even fail to fail.
Attributes:
total_children (int): the number of spawned children; -1 denotes that the number
has not been saved yet
completed_children (int): the number of completed child procs
joined (bool): true if all children are completed and the join method was already called
"""
total_children = IntField(default=-1)
completed_children = IntField(default=0)
joined = BooleanField(default=False)
meta = {
'abstract': True
}
def spwaned_childred(self, total_children=1):
"""
Subclasses must call this method after all children have been spawned.
Arguments:
total_children (int): the number of spawned children
"""
self.total_children = total_children
self.save()
self._check_join(children=0)
def completed_child(self):
""" Children must call this, when they completed processig. """
self._check_join(children=1)
def _check_join(self, children):
# incr the counter and get reference values atomically
completed_children, others = self.incr_counter(
'completed_children', children, ['total_children', 'joined'])
total_children, joined = others
self.get_logger().debug(
'Check for join', total_children=total_children,
completed_children=completed_children, joined=joined)
# check the join condition and raise errors if chord is in bad state
if completed_children == total_children:
if not joined:
self.join()
self.joined = True
self.save()
self.get_logger().debug('Chord is joined')
else:
raise InvalidChordUsage('Chord cannot be joined twice.')
elif completed_children > total_children and total_children != -1:
raise InvalidChordUsage('Chord counter is out of limits.')
def join(self):
""" Subclasses might overwrite to do something after all children have completed. """
pass
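# Minimal usage sketch (illustrative only, not part of this change; MyParent and
# MyChild are hypothetical names, the ParentProc/ChildProc classes in the tests
# below show the actual usage):
#
#   class MyParent(Chord):
#       @process
#       @task
#       def spawn(self):
#           count = 10
#           for _ in range(count):
#               MyChild.create(parent=self).process()
#           self.spwaned_childred(count)  # must be called once all children are spawned
#
#       def join(self):
#           pass  # runs exactly once, after the last child has reported completion
#
#   class MyChild(Proc):
#       parent = ReferenceField(MyParent)
#
#       @process
#       @task
#       def process(self):
#           self.parent.completed_child()  # atomically bumps the counter, may trigger join()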
def incr_counter(self, field, value=1, other_fields=None):
"""
Atomically increases the given field by value and returns the new value.
......@@ -281,31 +380,6 @@ class Proc(Document, metaclass=ProcMetaclass):
else:
return updated_raw[field], [updated_raw[field] for field in other_fields]
def block_until_complete(self, interval=0.01):
"""
Reloads the process constantly until it sees a completed process. Should be
used with care as it can block indefinitely. Intended for testing purposes only.
"""
while not self.completed:
time.sleep(interval)
self.reload()
@property
def json_dict(self) -> dict:
""" A json serializable dictionary representation. """
data = {
'tasks': getattr(self.__class__, 'tasks'),
'current_task': self.current_task,
'status': self.status,
'completed': self.completed,
'errors': self.errors,
'warnings': self.warnings,
'create_time': self.create_time.isoformat() if self.create_time is not None else None,
'complete_time': self.complete_time.isoformat() if self.complete_time is not None else None,
'_async_status': self._async_status
}
return {key: value for key, value in data.items() if value is not None}
def task(func):
"""
......@@ -334,6 +408,15 @@ def task(func):
return wrapper
def all_subclasses(cls):
""" Helper method to calculate set of all subclasses of a given class. """
return set(cls.__subclasses__()).union(
[s for c in cls.__subclasses__() for s in all_subclasses(c)])
all_proc_cls = {cls.__name__: cls for cls in all_subclasses(Proc)}
""" Name dictionary for all Proc classes. """
@app.task(bind=True, ignore_results=True, max_retries=3)
def proc_task(task, cls_name, self_id, func_attr):
"""
......@@ -345,31 +428,41 @@ def proc_task(task, cls_name, self_id, func_attr):
"""
logger = utils.get_logger(__name__, cls=cls_name, id=self_id, func=func_attr)
# get the process class
logger.debug('received process function call')
all_cls = Proc.__subclasses__()
cls = next((cls for cls in all_cls if cls.__name__ == cls_name), None)
global all_proc_cls
cls = all_proc_cls.get(cls_name, None)
if cls is None:
# re-scan all Proc classes, since more modules might have been imported by now
all_proc_cls = {cls.__name__: cls for cls in all_subclasses(Proc)}
cls = all_proc_cls.get(cls_name, None)
if cls is None:
logger.error('document not a subclass of Proc')
raise ProcNotRegistered('document %s not a subclass of Proc' % cls_name)
# get the process instance
try:
self = cls.get(self_id)
except KeyError as e:
logger.warning('called object is missing')
raise task.retry(exc=e, countdown=3)
# get the process function
func = getattr(self, func_attr, None)
if func is None:
logger.error('called function not a function of proc class')
self.fail('called function %s is not a function of proc class %s' % (func_attr, cls_name))
return
# unwrap the process decorator
func = getattr(func, '__process_unwrapped', None)
if func is None:
logger.error('called function was not decorated with @process')
self.fail('called function %s was not decorated with @process' % func_attr)
return
# call the process function
try:
self._async_status = 'RECEIVED-%s' % func.__name__
func(self)
......
......@@ -40,7 +40,7 @@ import logging
from nomad import config, files, utils
from nomad.repo import RepoCalc
from nomad.user import User, me
from nomad.processing.base import Proc, process, task, PENDING, SUCCESS, FAILURE, RUNNING
from nomad.processing.base import Proc, Chord, process, task, PENDING, SUCCESS, FAILURE, RUNNING
from nomad.parsing import LocalBackend, parsers, parser_dict
from nomad.normalizing import normalizers
from nomad.utils import get_logger, lnr
......@@ -139,7 +139,7 @@ class Calc(Proc):
self.normalizing()
self.archiving()
finally:
self._upload.calc_proc_completed()
self._upload.completed_child()
@task
def parsing(self):
......@@ -182,7 +182,7 @@ class Calc(Proc):
self._parser_backend.write_json(out, pretty=True)
class Upload(Proc):
class Upload(Chord):
"""
Represents uploads in the databases. Provides persistence access to the files storage,
and processing state.
......@@ -377,10 +377,11 @@ class Upload(Proc):
'exception while matching pot. mainfile',
mainfile=filename, exc_info=e)
# have to save the total_calcs information
self._initiated_parsers = total_calcs
self.save()
self.calc_proc_completed()
# have to save the total_calcs information for chord management
self.spwaned_childred(total_calcs)
def join(self):
self.cleanup()
@task
def cleanup(self):
......@@ -393,10 +394,6 @@ class Upload(Proc):
upload.close()
self.get_logger().debug('closed upload')
def calc_proc_completed(self):
if self._initiated_parsers >= 0 and self.processed_calcs >= self.total_calcs and self.current_task == 'parse_all':
self.cleanup()
@property
def processed_calcs(self):
return Calc.objects(upload_id=self.upload_id, status__in=[SUCCESS, FAILURE]).count()
......
......@@ -13,7 +13,7 @@
# limitations under the License.
from datetime import datetime
from threading import Thread
from threading import Thread, Event
from nomad import files, utils
......@@ -21,7 +21,7 @@ from nomad.processing.data import Upload
from nomad.utils import get_logger, lnr
def handle_uploads(quit=False):
def handle_uploads(ready=None, quit=False):
"""
Starts a daemon that will listen to files for new uploads. For each new
upload it will initiate the processing and save the task in the upload user data,
......@@ -29,6 +29,7 @@ def handle_uploads(quit=False):
user data.
Arguments:
ready (Event): optional, will be set when the thread is ready
quit: If True, only handle one event and stop. Otherwise run forever.
"""
......@@ -61,11 +62,15 @@ def handle_uploads(quit=False):
raise StopIteration
utils.get_logger(__name__).debug('Start upload put notification handler.')
if ready is not None:
ready.set()
handle_upload_put(received_upload_id='provided by decorator')
def handle_uploads_thread(quit=True):
""" Same as :func:`handle_uploads` but run in a separate thread. """
thread = Thread(target=lambda: handle_uploads(quit))
ready = Event()
thread = Thread(target=lambda: handle_uploads(ready=ready, quit=quit))
thread.start()
ready.wait()
return thread
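# Minimal usage sketch (assumption, not part of this change): a test can start the
# handler thread, trigger a single upload PUT, and wait for the handler to finish:
#
#   thread = handle_uploads_thread(quit=True)  # returns only after the ready event is set
#   # ... perform one upload so the handler receives a PUT notification ...
#   thread.join()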
import pytest
import logging
from mongoengine import connect
from mongoengine.connection import disconnect
......@@ -17,17 +18,17 @@ def celery_config():
}
@pytest.fixture(scope='function')
def purged_queue(celery_app):
@pytest.fixture(scope='session')
def purged_app(celery_session_app):
"""
Purges all pending tasks of the celery app before the test. This is necessary to
remove tasks from the queue that might be 'left over' from prior tests.
"""
celery_app.control.purge()
yield
celery_session_app.control.purge()
yield celery_session_app
@pytest.fixture(scope='function')
@pytest.fixture()
def patched_celery(monkeypatch):
# There is a bug in celery that prevents using the celery_worker fixture for multiple
# tests: https://github.com/celery/celery/issues/4088
......@@ -45,16 +46,27 @@ def patched_celery(monkeypatch):
yield
@pytest.fixture(scope='function')
def worker(patched_celery, purged_queue, celery_worker):
@pytest.fixture(scope='session')
def celery_inspect(purged_app):
yield purged_app.control.inspect()
@pytest.fixture()
def worker(patched_celery, celery_inspect, celery_session_worker):
"""
Extension of the celery_worker fixture that ensures a clean task queue before yielding.
Extension of the celery_session_worker fixture that ensures a clean task queue.
"""
# This won't work with the session worker: it will already have taken old/unexecuted
# tasks from the queue and might resubmit them. Therefore, purging the queue won't
# help much.
yield
# wait until there are no more active tasks, to leave a clean worker and queues for
# the next test.
while True:
empty = True
for value in celery_inspect.active().values():
empty = empty and len(value) == 0
if empty:
break
@pytest.fixture(scope='function', autouse=True)
def mongomock(monkeypatch):
......@@ -84,3 +96,22 @@ def mocksearch(monkeypatch):
monkeypatch.setattr('nomad.repo.RepoCalc.create_from_backend', create_from_backend)
monkeypatch.setattr('nomad.repo.RepoCalc.upload_exists', upload_exists)
@pytest.fixture(scope='function')
def no_warn(caplog):
yield caplog
for record in caplog.records:
if record.levelname in ['WARNING', 'ERROR', 'CRITICAL']:
assert False, record.msg
@pytest.fixture(scope='function')
def one_error(caplog):
yield caplog
count = 0
for record in caplog.records:
if record.levelname in ['ERROR', 'CRITICAL']:
count += 1
if count > 1:
assert False, "oo many errors"
\ No newline at end of file
import pytest
from mongoengine import connect, IntField, ReferenceField
from mongoengine import connect, IntField, ReferenceField, BooleanField, EmbeddedDocumentField
from mongoengine.connection import disconnect
import time
import logging
import json
import random
import time
from nomad import config
from nomad.processing.base import Proc, process, task, SUCCESS, FAILURE, RUNNING, PENDING
from nomad.processing.base import Proc, Chord, process, task, SUCCESS, FAILURE, RUNNING, PENDING
random.seed(0)
def assert_proc(proc, current_task, status=SUCCESS, errors=0, warnings=0):
......@@ -54,14 +58,13 @@ class FailTasks(Proc):
self.fail('fail fail fail')
def test_fail(caplog):
caplog.set_level(logging.CRITICAL, logger='nomad.processing.base')
def test_fail(one_error):
p = FailTasks.create()
p.will_fail()
assert_proc(p, 'will_fail', FAILURE, errors=1)
has_log = False
for record in caplog.records:
for record in one_error.records:
if record.levelname == 'ERROR':
has_log = True
assert json.loads(record.msg)['event'] == 'task failed'
......@@ -83,7 +86,7 @@ class SimpleProc(Proc):
pass
def test_simple_process(worker):
def test_simple_process(worker, no_warn):
p = SimpleProc.create()
p.process()
p.block_until_complete()
......@@ -97,31 +100,29 @@ class TaskInProc(Proc):
pass
# @pytest.mark.timeout(5)
def test_task_as_proc(worker):
@pytest.mark.timeout(5)
def test_task_as_proc(worker, no_warn):
p = TaskInProc.create()
p.process()
p.block_until_complete()
assert_proc(p, 'process')
class ParentProc(Proc):
children = IntField(default=0)
class ParentProc(Chord):
@process
@task
def spawn_children(self):
count = 23
for _ in range(0, count):
ChildProc.create(parent=self).process()
@process
self.spwaned_childred(count)
@task
def after_children(self):
def join(self):
pass
def on_child_complete(self):
if self.incr_counter('children') == 1:
self.after_children()
class ChildProc(Proc):
parent = ReferenceField(ParentProc)
......@@ -129,15 +130,15 @@ class ChildProc(Proc):
@process
@task
def process(self):
self.parent.on_child_complete()
time.sleep(random.uniform(0, 0.1))
self.parent.completed_child()
# @pytest.mark.timeout(5)
def test_counter(worker, caplog):
@pytest.mark.timeout(10)
def test_counter(worker, no_warn):
p = ParentProc.create()
p.spawn_children()
p.block_until_complete()
assert_proc(p, 'after_children')
for record in caplog.records:
assert record.levelname not in ['WARNING', 'ERROR', 'CRITICAL']
assert_proc(p, 'join')
assert p.joined
......@@ -20,7 +20,6 @@ reading from the redis result backend, even though all tasks apparently ended suc
from typing import Generator
import pytest
import logging
from datetime import datetime
from nomad import config, files
......@@ -79,14 +78,13 @@ def assert_processing(upload: Upload):
@pytest.mark.timeout(30)
def test_processing(uploaded_id, worker):
def test_processing(uploaded_id, worker, no_warn):
upload = run_processing(uploaded_id)
assert_processing(upload)
@pytest.mark.parametrize('uploaded_id', [example_files[1]], indirect=True)
def test_processing_doublets(uploaded_id, worker, caplog):
caplog.set_level(logging.CRITICAL)
def test_processing_doublets(uploaded_id, worker, one_error):
upload = run_processing(uploaded_id)
assert upload.status == 'SUCCESS'
......@@ -99,8 +97,7 @@ def test_processing_doublets(uploaded_id, worker, caplog):
@pytest.mark.timeout(30)
def test_process_non_existing(worker, caplog):
caplog.set_level(logging.CRITICAL)
def test_process_non_existing(worker, one_error):
upload = run_processing('__does_not_exist')
assert upload.completed
......@@ -111,9 +108,7 @@ def test_process_non_existing(worker, caplog):
@pytest.mark.parametrize('task', ['extracting', 'parse_all', 'cleanup', 'parsing'])
@pytest.mark.timeout(30)
def test_task_failure(monkeypatch, uploaded_id, worker, task, caplog):
caplog.set_level(logging.CRITICAL)
def test_task_failure(monkeypatch, uploaded_id, worker, task, one_error):
# mock the task method to throw exceptions
if hasattr(Upload, task):
cls = Upload
......
......@@ -23,8 +23,8 @@ from tests.test_files import assert_exists # noqa
# import fixtures
from tests.test_files import clear_files, archive_id # noqa pylint: disable=unused-import
from tests.test_normalizing import normalized_vasp_example # noqa pylint: disable=unused-import
from tests.test_parsing import parsed_vasp_example # noqa pylint: disable=unused-import
from tests.test_normalizing import normalized_template_example # noqa pylint: disable=unused-import
from tests.test_parsing import parsed_template_example # noqa pylint: disable=unused-import
from tests.test_repo import example_elastic_calc # noqa pylint: disable=unused-import
......@@ -78,14 +78,14 @@ def assert_upload(upload_json_str, id=None, **kwargs):
return data
def test_no_uploads(client, test_user_auth):
def test_no_uploads(client, test_user_auth, no_warn):
rv = client.get('/uploads', headers=test_user_auth)
assert rv.status_code == 200
assert_uploads(rv.data, count=0)
def test_not_existing_upload(client, test_user_auth):
def test_not_existing_upload(client, test_user_auth, no_warn):
rv = client.get('/uploads/123456789012123456789012', headers=test_user_auth)
assert rv.status_code == 404
......@@ -108,7 +108,7 @@ def test_stale_upload(client, test_user_auth):
assert_upload(rv.data, is_stale=True)
def test_create_upload(client, test_user_auth):
def test_create_upload(client, test_user_auth, no_warn):
rv = client.post('/uploads', headers=test_user_auth)
assert rv.status_code == 200
......@@ -123,7 +123,7 @@ def test_create_upload(client, test_user_auth):
assert_uploads(rv.data, count=1, id=upload_id)
def test_create_upload_with_name(client, test_user_auth):
def test_create_upload_with_name(client, test_user_auth, no_warn):
rv = client.post(
'/uploads', headers=test_user_auth,
data=json.dumps(dict(name='test_name')), content_type='application/json')
......@@ -132,7 +132,7 @@ def test_create_upload_with_name(client, test_user_auth):
assert upload['name'] == 'test_name'
def test_delete_empty_upload(client, test_user_auth):
def test_delete_empty_upload(client, test_user_auth, no_warn):
rv = client.post('/uploads', headers=test_user_auth)
assert rv.status_code == 200
......@@ -145,45 +145,15 @@ def test_delete_empty_upload(client, test_user_auth):
assert rv.status_code == 404
@pytest.mark.parametrize("file", example_files)
@pytest.mark.timeout(30)
def test_upload_to_upload(client, file, test_user_auth):
rv = client.post('/uploads', headers=test_user_auth)
assert rv.status_code == 200
upload = assert_upload(rv.data)
@files.upload_put_handler
def handle_upload_put(received_upload_id: str):
assert upload['upload_id'] == received_upload_id
raise StopIteration
def handle_uploads():
handle_upload_put(received_upload_id='provided by decorator')
handle_uploads_thread = Thread(target=handle_uploads)
handle_uploads_thread.start()
time.sleep(0.1)
upload_url = upload['presigned_url']
cmd = files.create_curl_upload_cmd(upload_url).replace('<ZIPFILE>', file)
subprocess.call(shlex.split(cmd))
handle_uploads_thread.join()
upload_id = upload['upload_id']