From 56533e670e131db82ebf08148955f67ff9fa9d03 Mon Sep 17 00:00:00 2001
From: Lauri Himanen <lauri.himanen@gmail.com>
Date: Tue, 14 Apr 2020 14:58:35 +0300
Subject: [PATCH] First version of pipelines getting in shape.

---
 nomad/config.py                               |  11 +-
 nomad/processing/base.py                      |   5 +-
 nomad/processing/data.py                      | 105 ++++---
 nomad/processing/pipelines.py                 | 257 ++++++++++++++++++
 .../docker-compose.override.yml               |   6 +
 .../infrastructure/docker-compose.yml         |   9 +
 6 files changed, 355 insertions(+), 38 deletions(-)
 create mode 100644 nomad/processing/pipelines.py

diff --git a/nomad/config.py b/nomad/config.py
index 37f51ed54f..f096083cc6 100644
--- a/nomad/config.py
+++ b/nomad/config.py
@@ -1,4 +1,4 @@
-# Copyright 2018 Markus Scheidgen
+# Copyright 2018 Markus Scheidgen, Lauri Himanen
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -78,11 +78,20 @@ rabbitmq = NomadConfig(
     password='rabbitmq'
 )
 
+redis = NomadConfig(
+    host='localhost',
+    port=6379,
+)
+
 
 def rabbitmq_url():
     return 'pyamqp://%s:%s@%s//' % (rabbitmq.user, rabbitmq.password, rabbitmq.host)
 
 
+def redis_url():
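+    # With the defaults above this resolves to 'redis://localhost:6379/0'
+    # (database 0 on the local Redis instance).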
+    return 'redis://%s:%d/0' % (redis.host, redis.port)
+
+
 celery = NomadConfig(
     max_memory=64e6,  # 64 GB
     timeout=1800,  # 1/2 h
diff --git a/nomad/processing/base.py b/nomad/processing/base.py
index b36b2e2c5f..c16422f2e4 100644
--- a/nomad/processing/base.py
+++ b/nomad/processing/base.py
@@ -60,7 +60,10 @@ def capture_worker_name(sender, instance, **kwargs):
     worker_hostname = sender
 
 
-app = Celery('nomad.processing', broker=config.rabbitmq_url())
+# Celery is configured to use Redis as a results backend. Although the results
+# themselves are not forwarded within the processing pipeline, Celery requires
+# a results backend to be configured in order to use chained tasks.
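+# A results backend is also what lets a chord know that every task in its
+# header group has finished; as a rough sketch (hypothetical task names),
+#     chord(group(task_a.si(), task_b.si()), body=callback.si()).delay()
+# could never fire `callback` without a backend to track the group results.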
+app = Celery('nomad.processing', backend=config.redis_url(), broker=config.rabbitmq_url())
 app.conf.update(worker_hijack_root_logger=False)
 app.conf.update(worker_max_memory_per_child=config.celery.max_memory)
 if config.celery.routing == config.CELERY_WORKER_ROUTING:
diff --git a/nomad/processing/data.py b/nomad/processing/data.py
index 2166b041df..1b195c9e49 100644
--- a/nomad/processing/data.py
+++ b/nomad/processing/data.py
@@ -23,7 +23,6 @@ calculations, and files
 .. autoclass:: Upload
 
 '''
-
 from typing import cast, List, Any, Tuple, Iterator, Dict, cast, Iterable
 from mongoengine import StringField, DateTimeField, DictField, BooleanField, IntField
 import logging
@@ -37,9 +36,10 @@ from structlog.processors import StackInfoRenderer, format_exc_info, TimeStamper
 
 from nomad import utils, config, infrastructure, search, datamodel
 from nomad.files import PathObject, UploadFiles, ExtractError, ArchiveBasedStagingUploadFiles, PublicUploadFiles, StagingUploadFiles
-from nomad.processing.base import Proc, process, task, PENDING, SUCCESS, FAILURE
+from nomad.processing.base import Proc, process, task, PENDING, SUCCESS, FAILURE, PROCESS_CALLED, PROCESS_COMPLETED
 from nomad.parsing import parser_dict, match_parser, Backend
 from nomad.normalizing import normalizers
+from nomad.processing.pipelines import get_pipeline, run_pipelines, Pipeline
 
 
 def _pack_log_event(logger, method_name, event_dict):
@@ -294,7 +294,6 @@ class Calc(Proc):
             except Exception as e:
                 logger.error('could not unload processing results', exc_info=e)
 
-    @process
     def process_calc(self):
         '''
         Processes a new calculation that has no prior records in the mongo, elastic,
@@ -352,13 +351,14 @@ class Calc(Proc):
             self.get_logger().error(
                 'could not write archive after processing failure', exc_info=e)
 
-    def on_process_complete(self, process_name):
-        # the save might be necessary to correctly read the join condition from the db
-        self.save()
-        # in case of error, the process_name might be unknown
-        if process_name == 'process_calc' or process_name == 're_process_calc' or process_name is None:
-            self.upload.reload()
-            self.upload.check_join()
+    # def on_process_complete(self, process_name):
+        # # the save might be necessary to correctly read the join condition from the db
+        # self.save()
+        # # in case of error, the process_name might be unknown
+        # if process_name == 'process_calc' or process_name == 're_process_calc' or process_name is None:
+            # self.get_logger().warning("JOINING NOW")
+            # self.upload.reload()
+            # self.upload.check_join()
 
     @task
     def parsing(self):
@@ -872,34 +872,61 @@ class Upload(Proc):
                     self.staging_upload_files.raw_file_object(path).os_path,
                     self.staging_upload_files.raw_file_object(stripped_path).os_path))
 
-    def match_mainfiles(self) -> Iterator[Tuple[str, object]]:
-        '''
-        Generator function that matches all files in the upload to all parsers to
-        determine the upload's mainfiles.
+    # def match_mainfiles(self) -> Iterator[Tuple[str, object]]:
+        # '''
+        # Generator function that matches all files in the upload to all parsers to
+        # determine the upload's mainfiles.
+
+        # Returns:
+            # Tuples of mainfile, filename, and parsers
+        # '''
+        # directories_with_match: Dict[str, str] = dict()
+        # upload_files = self.staging_upload_files
+        # for filename in upload_files.raw_file_manifest():
+            # self._preprocess_files(filename)
+            # try:
+                # parser = match_parser(upload_files.raw_file_object(filename).os_path)
+                # if parser is not None:
+                    # directory = os.path.dirname(filename)
+                    # if directory in directories_with_match:
+                        # # TODO this might give us the chance to store directory based relationship
+                        # # between calcs for the future?
+                        # pass
+                    # else:
+                        # directories_with_match[directory] = filename
+
+                    # yield filename, parser
+            # except Exception as e:
+                # self.get_logger().error(
+                    # 'exception while matching pot. mainfile',
+                    # mainfile=filename, exc_info=e)
+
+    def match_mainfiles(self) -> Iterator[Tuple]:
+        """Generator function that iterates over files in an upload and returns
+        basic information for each found mainfile.
 
         Returns:
-            Tuples of mainfile, filename, and parsers
-        '''
+            Tuples of (filepath, parser, calc_id, worker_hostname, upload_id),
+            one for each matched mainfile.
+        """
         directories_with_match: Dict[str, str] = dict()
         upload_files = self.staging_upload_files
-        for filename in upload_files.raw_file_manifest():
-            self._preprocess_files(filename)
+        for filepath in upload_files.raw_file_manifest():
+            self._preprocess_files(filepath)
             try:
-                parser = match_parser(upload_files.raw_file_object(filename).os_path)
+                parser = match_parser(upload_files.raw_file_object(filepath).os_path)
                 if parser is not None:
-                    directory = os.path.dirname(filename)
+                    directory = os.path.dirname(filepath)
                     if directory in directories_with_match:
                         # TODO this might give us the chance to store directory based relationship
                         # between calcs for the future?
                         pass
                     else:
-                        directories_with_match[directory] = filename
-
-                    yield filename, parser
+                        directories_with_match[directory] = filepath
+                    yield filepath, parser, upload_files.calc_id(filepath), self.worker_hostname, self.upload_id
             except Exception as e:
                 self.get_logger().error(
                     'exception while matching pot. mainfile',
-                    mainfile=filename, exc_info=e)
+                    mainfile=filepath, exc_info=e)
 
     @task
     def parse_all(self):
@@ -912,18 +939,21 @@ class Upload(Proc):
         with utils.timer(
                 logger, 'upload extracted', step='matching',
                 upload_size=self.upload_files.size):
-            for filename, parser in self.match_mainfiles():
-                calc = Calc.create(
-                    calc_id=self.upload_files.calc_id(filename),
-                    mainfile=filename, parser=parser.name,
-                    worker_hostname=self.worker_hostname,
-                    upload_id=self.upload_id)
 
-                calc.process_calc()
+            # Tell Upload that a process has been started.
+            self.current_process = "process"
+            self.process_status = PROCESS_CALLED
+            self.save()
 
-    def on_process_complete(self, process_name):
-        if process_name == 'process_upload' or process_name == 're_process_upload':
-            self.check_join()
+            # Start running all pipelines
+            n_pipelines = run_pipelines(self.match_mainfiles())
+
+            # If the upload has not spawned any pipelines, tell it that it is
+            # finished and perform cleanup
+            if n_pipelines == 0:
+                self.process_status = PROCESS_COMPLETED
+                self.save()
+                self.cleanup()
 
     def check_join(self):
         '''
@@ -936,8 +966,11 @@ class Upload(Proc):
         '''
         total_calcs = self.total_calcs
         processed_calcs = self.processed_calcs
+        logger = self.get_logger()

-        self.get_logger().debug('check join', processed_calcs=processed_calcs, total_calcs=total_calcs)
+        logger.debug('check join', processed_calcs=processed_calcs, total_calcs=total_calcs)
         # check if process is not running anymore, i.e. not still spawning new processes to join
         # check the join condition, i.e. all calcs have been processed
         if not self.process_running and processed_calcs >= total_calcs:
@@ -946,7 +979,7 @@ class Upload(Proc):
                 {'_id': self.upload_id, 'joined': {'$ne': True}},
                 {'$set': {'joined': True}})
             if modified_upload is not None:
-                self.get_logger().debug('join')
+                logger.debug('join')
                 self.cleanup()
             else:
                 # the join was already done due to a prior call
diff --git a/nomad/processing/pipelines.py b/nomad/processing/pipelines.py
new file mode 100644
index 0000000000..f890514f58
--- /dev/null
+++ b/nomad/processing/pipelines.py
@@ -0,0 +1,257 @@
+# Copyright 2018 Markus Scheidgen
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+This module contains the objects and functions for defining and running the processing pipelines.
+"""
+from typing import List
+from collections import defaultdict
+import networkx as nx
+from celery import chain, group, chord
+from celery.exceptions import SoftTimeLimitExceeded
+
+from nomad.processing.base import app, NomadCeleryTask, PROCESS_COMPLETED
+import nomad.processing.data
+from nomad import config
+
+
+class Pipeline:
+    """A pipeline is an ordered list of stages that are run for a single
+    mainfile. The pipeline is complete when all of its stages have finished.
+    """
+    def __init__(self, filepath, parser, calc_id, worker_hostname, upload_id, stages=None):
+        self.filepath = filepath
+        self.parser = parser
+        self.parser_name = parser.name
+        self.calc_id = calc_id
+        self.worker_hostname = worker_hostname
+        self.upload_id = upload_id
+        if stages is None:
+            self.stages = []
+        else:
+            self.stages = stages
+
+    def add_stage(self, stage):
+        if len(self.stages) == 0:
+            if len(stage.dependencies) != 0:
+                raise ValueError(
+                    "The first stage in a pipeline must not have any dependencies."
+                )
+        stage._pipeline = self
+        stage.index = len(self.stages)
+        self.stages.append(stage)
+
+
+class Stage:
+    """A stage wraps a single Python function. The stage is complete once this
+    function has finished.
+    """
+    def __init__(self, name: str, function):
+        """
+        Args:
+            name: Name of the stage. The name is used in resolving stage
+                dependencies.
+            function: A regular Python function that will be executed during
+                this stage. The function should not return any values as all
+                communication happens through object persistence in MongoDB. The
+                function should accept the following arguments:
+
+                    - filepath: Path of the main file
+                    - parser_name: Name of the identified parser
+                    - calc_id: Calculation id in MongoDB
+                    - upload_id: Upload id in MongoDB
+                    - worker_hostname: Name of the host machine
+                    - stage_name: Name of the stage executing the function
+                    - i_stage: The index of this stage in the pipeline
+                    - n_stages: Number of stages in this pipeline
+        """
+        self.name = name
+        self._function = function
+        self._pipeline = None
+        self.index = None
+        self.dependencies: List[str] = []
+
+    def add_dependency(self, name):
+        self.dependencies.append(name)
+
+    def signature(self):
+        return wrapper.si(
+            self._function.__name__,
+            self._pipeline.filepath,
+            self._pipeline.parser_name,
+            self._pipeline.calc_id,
+            self._pipeline.upload_id,
+            self._pipeline.worker_hostname,
+            self.name,
+            self.index,
+            len(self._pipeline.stages)
+        )
+
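+# A minimal sketch (hypothetical parser object and ids) of how Pipeline and
+# Stage compose; the actual pipelines are assembled in get_pipeline() below:
+#
+#     pipeline = Pipeline(filepath, parser, calc_id, worker_hostname, upload_id)
+#     first = Stage("comp_process", comp_process)
+#     pipeline.add_stage(first)
+#     follow_up = Stage("comp_process_phonopy", comp_process)
+#     follow_up.add_dependency("comp_process")
+#     pipeline.add_stage(follow_up)
+#     signatures = [stage.signature() for stage in pipeline.stages]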
+
+# This function wraps the function calls made within the stages. Although the
+# results are not used in any way, ignore_result is set to False as documented
+# in https://docs.celeryproject.org/en/stable/userguide/canvas.html#important-notes
+@app.task(
+    bind=True, base=NomadCeleryTask, ignore_result=False, max_retries=3,
+    acks_late=config.celery.acks_late, soft_time_limit=config.celery.timeout,
+    time_limit=config.celery.timeout)
+def wrapper(task, function_name, filepath, parser_name, calc_id, upload_id, worker_hostname, stage_name, i_stage, n_stages):
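+    """Celery task that runs the Python function bound to a single pipeline
+    stage. For the last stage of a pipeline it additionally informs the upload
+    that this calculation has been processed so that the join can be checked.
+    """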
+
+    # Get the associated calculation
+    calc = nomad.processing.data.Calc.get(calc_id)
+    logger = calc.get_logger()
+
+    # Get the stage function. If it does not exist, fail the calculation and
+    # return without executing the stage.
+    function = globals().get(function_name, None)
+    if function is None:
+        calc.fail('Could not find the function associated with the stage.')
+        return
+
+    # Try to execute the stage.
+    deleted = False
+    try:
+        deleted = function(filepath, parser_name, calc_id, upload_id, worker_hostname, stage_name, i_stage, n_stages)
+    except SoftTimeLimitExceeded as e:
+        logger.error('exceeded the celery task soft time limit')
+        calc.fail(e)
+    except Exception as e:
+        calc.fail(e)
+    except SystemExit as e:
+        calc.fail(e)
+    finally:
+        if deleted is None or not deleted:
+            calc.save()
+
+    # For the last stage, inform the upload that this calculation is finished.
+    if i_stage == n_stages - 1:
+        # The save might be necessary to correctly read the join condition from the db
+        calc.save()
+
+        # Inform upload that we are finished
+        upload = calc.upload
+        upload.process_status = PROCESS_COMPLETED
+        upload.save()
+        upload.reload()
+        upload.check_join()
+
+
+@app.task(
+    bind=True, base=NomadCeleryTask, ignore_result=False, max_retries=3,
+    acks_late=config.celery.acks_late, soft_time_limit=config.celery.timeout,
+    time_limit=config.celery.timeout)
+def empty_task(task, *args, **kwargs):
+    """Empty dummy task.
+    """
+    pass
+
+
+def comp_process(filepath, parser_name, calc_id, upload_id, worker_hostname, stage_name, i_stage, n_stages):
+    """Function for processing computational entries: runs parsing and normalization.
+    """
+    # Process calculation
+    calc = nomad.processing.data.Calc.get(calc_id)
+    calc.process_calc()
+
+
+def get_pipeline(filepath, parser, calc_id, worker_hostname, upload_id):
+    """Used to fetch a pipeline for a mainfile that has been matched with a parser.
+
+    Args:
+    """
+    pipeline = Pipeline(filepath, parser, calc_id, worker_hostname, upload_id)
+
+    # Phonopy pipeline
+    if parser.name == "parsers/phonopy":
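+        # The stage below depends on "comp_process", so run_pipelines() will
+        # schedule it only after every comp_process stage in the upload has
+        # finished.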
+        stage1 = Stage("comp_process_phonopy", comp_process)
+        stage1.add_dependency("comp_process")
+        pipeline.add_stage(stage1)
+    # DFT pipeline
+    else:
+        stage1 = Stage("comp_process", comp_process)
+        pipeline.add_stage(stage1)
+
+    return pipeline
+
+
+def run_pipelines(mainfile_generator):
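+    """Resolve the pipelines for the given mainfiles into independent
+    dependency trees and start running the trees asynchronously and in
+    parallel.
+
+    Args:
+        mainfile_generator: Iterator over (filepath, parser, calc_id,
+            worker_hostname, upload_id) tuples, as yielded by
+            Upload.match_mainfiles().
+
+    Returns:
+        The number of pipelines that were started.
+    """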
+
+    # Resolve all pipelines into disconnected dependency trees and run
+    # each tree in parallel.
+    stage_dependencies = []
+    stages = defaultdict(list)
+    stage_names = set()
+    n_pipelines = 0
+    for mainfile_info in mainfile_generator:
+        pipeline = get_pipeline(*mainfile_info)
+        n_pipelines += 1
+        for i_stage, stage in enumerate(pipeline.stages):
+
+            # Store stage names to be used as nodes
+            stage_names.add(stage.name)
+
+            # Store dependencies to be used as edges
+            for dependency in stage.dependencies:
+                stage_dependencies.append((stage.name, dependency))
+            stages[stage.name].append(stage)
+
+            # The first stage never has dependencies (see Pipeline.add_stage):
+            # create the associated Calc object for it here; the stage itself
+            # is started later as part of the Celery canvas.
+            if i_stage == 0:
+
+                # Create the associated Calc object
+                nomad.processing.data.Calc.create(
+                    calc_id=pipeline.calc_id,
+                    mainfile=pipeline.filepath,
+                    parser=pipeline.parser_name,
+                    worker_hostname=pipeline.worker_hostname,
+                    upload_id=pipeline.upload_id
+                )
+
+    if n_pipelines != 0:
+        # Resolve all independent dependency trees
+        dependency_graph = nx.DiGraph()
+        dependency_graph.add_nodes_from(stage_names)
+        dependency_graph.add_edges_from(stage_dependencies)
+        dependency_trees = nx.weakly_connected_components(dependency_graph)
+
+        # Form a chain of stages for each independent tree.
+        chains = []
+        for tree_nodes in dependency_trees:
+            tree = dependency_graph.subgraph(tree_nodes).copy()
+            sorted_nodes = nx.topological_sort(tree)
+            groups = []
+            for node in reversed(list(sorted_nodes)):
+
+                # Group all tasks for a stage
+                tasks = stages[node]
+                task_signatures = []
+                for task in tasks:
+                    task_signatures.append(task.signature())
+                task_group = group(*task_signatures)
+
+                # Groups cannot be chained directly in Celery with the desired
+                # barrier semantics. Instead, each group is wrapped inside a
+                # chord, and the chords are chained. The chord callback is a
+                # dummy task that does nothing.
+                groups.append(chord(task_group, body=empty_task.si()))
+
+            # Form a chain of stages for this tree
+            stage_chain = chain(*groups)
+            chains.append(stage_chain)
+
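+        # For an upload containing both plain DFT and phonopy mainfiles, the
+        # canvas built above looks roughly like (signatures elided):
+        #     group(chain(chord(group(<comp_process tasks>), empty_task.si()),
+        #                 chord(group(<comp_process_phonopy tasks>), empty_task.si())))
+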
+        # This is the final group that starts executing all independent stage trees in parallel.
+        final_group = group(*chains)
+        final_group.delay()
+
+    return n_pipelines
diff --git a/ops/docker-compose/infrastructure/docker-compose.override.yml b/ops/docker-compose/infrastructure/docker-compose.override.yml
index cb6b7a753f..5d27a7a6be 100644
--- a/ops/docker-compose/infrastructure/docker-compose.override.yml
+++ b/ops/docker-compose/infrastructure/docker-compose.override.yml
@@ -26,6 +26,12 @@ services:
         ports:
             - 5672:5672
 
+    # results backend for Celery
+    redis:
+        restart: 'no'
+        ports:
+            - 6379:6379
+
     # the search engine
     elastic:
         restart: 'no'
diff --git a/ops/docker-compose/infrastructure/docker-compose.yml b/ops/docker-compose/infrastructure/docker-compose.yml
index 5912a3a259..7a00fb1537 100644
--- a/ops/docker-compose/infrastructure/docker-compose.yml
+++ b/ops/docker-compose/infrastructure/docker-compose.yml
@@ -46,6 +46,14 @@ services:
         volumes:
             - nomad_rabbitmq:/var/lib/rabbitmq
 
+    # results backend for Celery
+    redis:
+        restart: unless-stopped
+        image: redis:5.0.8-alpine
+        container_name: nomad_redis
+        volumes:
+            - nomad_redis:/data
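+        # Once the service is running, `docker exec nomad_redis redis-cli ping`
+        # should answer PONG: a quick check that the backend is reachable.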
+
     # the search engine
     elastic:
         restart: always
@@ -71,3 +79,4 @@ volumes:
     nomad_mongo:
     nomad_elastic:
     nomad_rabbitmq:
+    nomad_redis:
-- 
GitLab