Commit 1e685a3c authored by Lauri Himanen


Added PipelineContext object and corresponding custom serializer to make working with pipelines easier.
parent 56533e67
......@@ -26,6 +26,7 @@ class Parser(metaclass=ABCMeta):
'''
def __init__(self):
self.name = None
self.domain = 'dft'
self._metainfo_env: Environment = None
......
......@@ -16,7 +16,7 @@ from typing import List, Any, Dict
import logging
import time
import os
from celery import Celery, Task
from celery import Task
from celery.worker.request import Request
from celery.signals import after_setup_task_logger, after_setup_logger, worker_process_init, \
celeryd_after_setup
......@@ -30,6 +30,7 @@ from datetime import datetime
import functools
from nomad import config, utils, infrastructure
from nomad.processing.celeryapp import app
import nomad.patch # pylint: disable=unused-import
......@@ -60,17 +61,6 @@ def capture_worker_name(sender, instance, **kwargs):
worker_hostname = sender
# Celery is configured to use redis as a results backend. Although the results
# are not forwarded within the processing pipeline, celery requires the results
# backend to be configured in order to use chained tasks.
app = Celery('nomad.processing', backend=config.redis_url(), broker=config.rabbitmq_url(),)
app.conf.update(worker_hijack_root_logger=False)
app.conf.update(worker_max_memory_per_child=config.celery.max_memory)
if config.celery.routing == config.CELERY_WORKER_ROUTING:
app.conf.update(worker_direct=True)
app.conf.task_queue_max_priority = 10
CREATED = 'CREATED'
PENDING = 'PENDING'
RUNNING = 'RUNNING'
......
# Copyright 2018 Markus Scheidgen
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This module contains the Celery configuration.
"""
from celery import Celery
from nomad import config
from nomad.processing import celeryconfig
# Celery is configured to use redis as a results backend. Although the results
# are not forwarded within the processing pipeline, celery requires the results
# backend to be configured in order to use chained tasks.
app = Celery(
'nomad.processing',
backend=config.redis_url(),
broker=config.rabbitmq_url(),
)
# The config is loaded from a standard Celery config file
app.config_from_object(celeryconfig)
app.conf.update(worker_hijack_root_logger=False)
app.conf.update(worker_max_memory_per_child=config.celery.max_memory)
if config.celery.routing == config.CELERY_WORKER_ROUTING:
app.conf.update(worker_direct=True)
app.conf.task_queue_max_priority = 10
......@@ -23,7 +23,7 @@ calculations, and files
.. autoclass:: Upload
'''
from typing import cast, List, Any, Tuple, Iterator, Dict, cast, Iterable
from typing import cast, List, Any, Iterator, Dict, cast, Iterable
from mongoengine import StringField, DateTimeField, DictField, BooleanField, IntField
import logging
from structlog import wrap_logger
......@@ -39,7 +39,7 @@ from nomad.files import PathObject, UploadFiles, ExtractError, ArchiveBasedStagi
from nomad.processing.base import Proc, process, task, PENDING, SUCCESS, FAILURE, PROCESS_CALLED, PROCESS_COMPLETED
from nomad.parsing import parser_dict, match_parser, Backend
from nomad.normalizing import normalizers
from nomad.processing.pipelines import get_pipeline, run_pipelines, Pipeline
from nomad.processing.pipelines import run_pipelines, PipelineContext
def _pack_log_event(logger, method_name, event_dict):
......@@ -872,41 +872,12 @@ class Upload(Proc):
self.staging_upload_files.raw_file_object(path).os_path,
self.staging_upload_files.raw_file_object(stripped_path).os_path))
# def match_mainfiles(self) -> Iterator[Tuple[str, object]]:
# '''
# Generator function that matches all files in the upload to all parsers to
# determine the upload's mainfiles.
# Returns:
# Tuples of mainfile, filename, and parsers
# '''
# directories_with_match: Dict[str, str] = dict()
# upload_files = self.staging_upload_files
# for filename in upload_files.raw_file_manifest():
# self._preprocess_files(filename)
# try:
# parser = match_parser(upload_files.raw_file_object(filename).os_path)
# if parser is not None:
# directory = os.path.dirname(filename)
# if directory in directories_with_match:
# # TODO this might give us the chance to store directory based relationship
# # between calcs for the future?
# pass
# else:
# directories_with_match[directory] = filename
# yield filename, parser
# except Exception as e:
# self.get_logger().error(
# 'exception while matching pot. mainfile',
# mainfile=filename, exc_info=e)
def match_mainfiles(self) -> Iterator[Tuple]:
def match_mainfiles(self) -> Iterator[PipelineContext]:
"""Generator function that iterates over files in an upload and returns
basic information for each found mainfile.
Returns:
Tuple: (filepath, parser, calc_id, worker_hostname, upload_id)
PipelineContext
"""
directories_with_match: Dict[str, str] = dict()
upload_files = self.staging_upload_files
......@@ -922,7 +893,13 @@ class Upload(Proc):
pass
else:
directories_with_match[directory] = filepath
yield filepath, parser, upload_files.calc_id(filepath), self.worker_hostname, self.upload_id
yield PipelineContext(
filepath,
parser.name,
upload_files.calc_id(filepath),
self.upload_id,
self.worker_hostname
)
except Exception as e:
self.get_logger().error(
'exception while matching pot. mainfile',
......
......@@ -21,41 +21,77 @@ import networkx as nx
from celery import chain, group, chord
from celery.exceptions import SoftTimeLimitExceeded
from nomad.processing.base import app, NomadCeleryTask, PROCESS_COMPLETED
from nomad.processing.base import NomadCeleryTask, PROCESS_COMPLETED
from nomad.processing.celeryapp import app
import nomad.processing.data
from nomad import config
import json
from kombu.serialization import register
class Pipeline():
"""Pipeline consists of a list of stages. The pipeline is complete when all
stages are finished.
def my_dumps(obj):
"""Custom JSON encoder function for Celery tasks.
"""
class MyEncoder(json.JSONEncoder):
def default(self, obj): # pylint: disable=E0202
if isinstance(obj, PipelineContext):
return obj.encode()
else:
return json.JSONEncoder.default(self, obj)
return json.dumps(obj, cls=MyEncoder)
def my_loads(obj):
"""Custom JSON decoder function for Celery tasks.
"""
def my_decoder(obj):
if '__type__' in obj:
if obj['__type__'] == '__pipelinecontext__':
return PipelineContext.decode(obj)
return obj
return json.loads(obj, object_hook=my_decoder)
register("myjson", my_dumps, my_loads, content_type='application/x-myjson', content_encoding='utf-8')
class PipelineContext():
"""Convenience class for storing pipeline execution related information.
Provides custom encode/decode functions for JSON serialization with Celery.
"""
def __init__(self, filepath, parser, calc_id, worker_hostname, upload_id, stages=None):
def __init__(self, filepath, parser_name, calc_id, upload_id, worker_hostname):
self.filepath = filepath
self.parser = parser
self.parser_name = parser.name
self.parser_name = parser_name
self.calc_id = calc_id
self.worker_hostname = worker_hostname
self.upload_id = upload_id
if stages is None:
self.stages = []
else:
self.stages = stages
self.worker_hostname = worker_hostname
def add_stage(self, stage):
if len(self.stages) == 0:
if len(stage.dependencies) != 0:
raise ValueError(
"The first stage in a pipeline must not have any dependencies."
)
stage._pipeline = self
stage.index = len(self.stages)
self.stages.append(stage)
def encode(self):
return {
"__type__": "__pipelinecontext__",
"filepath": self.filepath,
"parser_name": self.parser_name,
"calc_id": self.calc_id,
"upload_id": self.upload_id,
"worker_hostname": self.worker_hostname,
}
@staticmethod
def decode(data):
return PipelineContext(
data["filepath"],
data["parser_name"],
data["calc_id"],
data["upload_id"],
data["worker_hostname"],
)
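For illustration, a PipelineContext makes a round trip through the serializer functions defined above intact; the values below are made up purely for the example:

    # Hypothetical values, only to demonstrate the encode/decode round trip.
    context = PipelineContext(
        'upload/vasprun.xml', 'parsers/vasp', 'calc_1', 'upload_1', 'worker_1')
    payload = my_dumps({'args': [context]})   # JSON string carrying the __type__ marker
    restored = my_loads(payload)['args'][0]   # reconstructed as a PipelineContext
    assert restored.parser_name == 'parsers/vasp'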
class Stage():
"""Stage comprises of a single python function. After this function is
completed the stage is completed.
completed, the stage is completed.
"""
def __init__(self, name: str, function):
"""
......@@ -67,19 +103,15 @@ class Stage():
communication happens through object persistence in MongoDB. The
function should accept the following arguments:
- filepath: Path of the main file
- parser_name: Name of the identified parser
- calc_id: Calculation id in MongoDB
- upload_id: Upload id in MongoDB
- worker_hostname: Name of the host machine
- context: PipelineContext object
- stage_name: Name of the stage executing the function
- i_stage: The index of this stage in the pipeline
- n_stages: Number of stages in this pipeline
"""
self.name = name
self._function = function
self._pipeline = None
self.index = None
self._pipeline: Pipeline = None
self.index: int = None
self.dependencies: List[Stage] = []
def add_dependency(self, name):
......@@ -91,17 +123,42 @@ class Stage():
def signature(self):
return wrapper.si(
self._function.__name__,
self._pipeline.filepath,
self._pipeline.parser_name,
self._pipeline.calc_id,
self._pipeline.upload_id,
self._pipeline.worker_hostname,
self._pipeline.context,
self.name,
self.index,
len(self._pipeline.stages)
)
class Pipeline():
"""Pipeline consists of a list of stages. The pipeline is complete when all
stages are finished.
"""
def __init__(self, context: PipelineContext):
"""
Args:
context: The working context for this pipeline.
"""
self.context = context
self.stages: List[Stage] = []
def add_stage(self, stage: Stage):
"""Adds a stage to this pipeline. The stages are executec in the order
they are added with this function.
Args:
stage: The stage to be added to this pipeline.
"""
if len(self.stages) == 0:
if len(stage.dependencies) != 0:
raise ValueError(
"The first stage in a pipeline must not have any dependencies."
)
stage._pipeline = self
stage.index = len(self.stages)
self.stages.append(stage)
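As a usage sketch, building a two-stage pipeline by hand could look like the following; the context values are made up, and comp_process is the stage function defined further below in this module:

    # Sketch only: construct a pipeline and enforce stage ordering via dependencies.
    context = PipelineContext(
        'upload/mainfile', 'parsers/vasp', 'calc_1', 'upload_1', 'worker_1')
    pipeline = Pipeline(context)
    first = Stage('comp_process', comp_process)
    pipeline.add_stage(first)                      # first stage: no dependencies allowed
    second = Stage('comp_process_phonopy', comp_process)
    second.add_dependency('comp_process')          # must run after 'comp_process'
    pipeline.add_stage(second)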
# This function wraps the function calls made within the stages. Although the
# results are not used in any way, ignore_results is set to False as documented
# in https://docs.celeryproject.org/en/stable/userguide/canvas.html#important-notes
......@@ -109,10 +166,10 @@ class Stage():
bind=True, base=NomadCeleryTask, ignore_results=False, max_retries=3,
acks_late=config.celery.acks_late, soft_time_limit=config.celery.timeout,
time_limit=config.celery.timeout)
def wrapper(task, function_name, filepath, parser_name, calc_id, upload_id, worker_hostname, stage_name, i_stage, n_stages):
def wrapper(task, function_name, context, stage_name, i_stage, n_stages):
# Get the associated calculation
calc = nomad.processing.data.Calc.get(calc_id)
calc = nomad.processing.data.Calc.get(context.calc_id)
logger = calc.get_logger()
# Get the defined function. If does not exist, log error and fail calculation.
......@@ -123,7 +180,7 @@ def wrapper(task, function_name, filepath, parser_name, calc_id, upload_id, work
# Try to execute the stage.
deleted = False
try:
deleted = function(filepath, parser_name, calc_id, upload_id, worker_hostname, stage_name, i_stage, n_stages)
deleted = function(context, stage_name, i_stage, n_stages)
except SoftTimeLimitExceeded as e:
logger.error('exceeded the celery task soft time limit')
calc.fail(e)
......@@ -158,23 +215,28 @@ def empty_task(task, *args, **kwargs):
pass
def comp_process(filepath, parser_name, calc_id, upload_id, worker_hostname, stage_name, i_stage, n_stages):
def comp_process(context, stage_name, i_stage, n_stages):
"""Function for processing computational entries: runs parsing and normalization.
"""
# Process calculation
calc = nomad.processing.data.Calc.get(calc_id)
calc = nomad.processing.data.Calc.get(context.calc_id)
calc.process_calc()
def get_pipeline(filepath, parser, calc_id, worker_hostname, upload_id):
"""Used to fetch a pipeline for a mainfile that has been matched with a parser.
def get_pipeline(context):
"""Used to fetch a pipeline based on a pipeline context. Typically chosen
simply based on a matched parser name that is stored in the context
Args:
context: The context based on which the pipeline is chosen.
Returns:
Pipeline: The pipeline to execute for the given context.
"""
pipeline = Pipeline(filepath, parser, calc_id, worker_hostname, upload_id)
pipeline = Pipeline(context)
# Phonopy pipeline
if parser.name == "parsers/phonopy":
if context.parser_name == "parsers/phonopy":
stage1 = Stage("comp_process_phonopy", comp_process)
stage1.add_dependency("comp_process")
pipeline.add_stage(stage1)
......@@ -186,16 +248,21 @@ def get_pipeline(filepath, parser, calc_id, worker_hostname, upload_id):
return pipeline
def run_pipelines(mainfile_generator):
def run_pipelines(context_generator) -> int:
"""Used to start running pipelines based on the PipelineContext objects
generated by the given generator.
Returns:
The number of pipelines that were started.
"""
# Resolve all pipelines into disconnected dependency trees and run
# each tree in parallel.
stage_dependencies = []
stages = defaultdict(list)
stages: defaultdict = defaultdict(list)
stage_names = set()
n_pipelines = 0
for mainfile_info in mainfile_generator:
pipeline = get_pipeline(*mainfile_info)
for context in context_generator:
pipeline = get_pipeline(context)
n_pipelines += 1
for i_stage, stage in enumerate(pipeline.stages):
......@@ -212,11 +279,11 @@ def run_pipelines(mainfile_generator):
# Create the associated Calc object
nomad.processing.data.Calc.create(
calc_id=pipeline.calc_id,
mainfile=pipeline.filepath,
parser=pipeline.parser_name,
worker_hostname=pipeline.worker_hostname,
upload_id=pipeline.upload_id
calc_id=pipeline.context.calc_id,
mainfile=pipeline.context.filepath,
parser=pipeline.context.parser_name,
worker_hostname=pipeline.context.worker_hostname,
upload_id=pipeline.context.upload_id
)
if n_pipelines != 0:
......
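The part of run_pipelines that actually dispatches the resolved trees with Celery canvas primitives is collapsed in this view. Under the simplifying assumption of a single linear pipeline with no cross-pipeline dependencies, dispatching could look like the sketch below; the real code of this commit resolves dependency trees across pipelines before dispatching:

    # Sketch only, not the code of this commit: chain the immutable signatures
    # of the stages in order and fire the chain asynchronously.
    chain(*[stage.signature() for stage in pipeline.stages]).delay()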