From cd23003db14bc7dfaa8bf96b1127684fc938407f Mon Sep 17 00:00:00 2001
From: Theo Steininger <theo.steininger@ultimanet.de>
Date: Fri, 17 Nov 2017 01:06:33 +0100
Subject: [PATCH] Pipeline can now fix seed for MagneticFieldFactory

---
 .../magnetic_field/magnetic_field_factory.py               | 5 +++--
 imagine/pipeline.py                                        | 7 +++++--
 2 files changed, 8 insertions(+), 4 deletions(-)

diff --git a/imagine/magnetic_fields/magnetic_field/magnetic_field_factory.py b/imagine/magnetic_fields/magnetic_field/magnetic_field_factory.py
index 0ef68cd..4e36091 100644
--- a/imagine/magnetic_fields/magnetic_field/magnetic_field_factory.py
+++ b/imagine/magnetic_fields/magnetic_field/magnetic_field_factory.py
@@ -130,7 +130,7 @@ class MagneticFieldFactory(Loggable, object):
             parameter_dict[variable_name] = mapped_variable
         return parameter_dict
 
-    def generate(self, variables={}, ensemble_size=1):
+    def generate(self, variables={}, ensemble_size=1, random_seed=None):
         mapped_variables = self._map_variables_to_parameters(variables)
         work_parameters = self.parameter_defaults.copy()
         work_parameters.update(mapped_variables)
@@ -142,7 +142,8 @@ class MagneticFieldFactory(Loggable, object):
         result_magnetic_field = self.magnetic_field_class(
                                               domain=domain,
                                               parameters=work_parameters,
-                                              distribution_strategy='equal')
+                                              distribution_strategy='equal',
+                                              random_seed=random_seed)
         self.logger.debug("Generated magnetic field with work-parameters %s" %
                           work_parameters)
         return result_magnetic_field
diff --git a/imagine/pipeline.py b/imagine/pipeline.py
index fef2ec3..6da8f28 100644
--- a/imagine/pipeline.py
+++ b/imagine/pipeline.py
@@ -56,6 +56,8 @@ class Pipeline(Loggable, object):
 
         self.sample_callback = sample_callback
 
+        self.fixed_random_seed = None
+
     @property
     def observer(self):
         return self._observer
@@ -172,8 +174,9 @@ class Pipeline(Loggable, object):
         # create magnetic field
         self.logger.debug("Creating magnetic field.")
         b_field = self.magnetic_field_factory.generate(
-                                              variables=variables,
-                                              ensemble_size=self.ensemble_size)
+                                          variables=variables,
+                                          ensemble_size=self.ensemble_size,
+                                          random_seed=self.fixed_random_seed)
 
         # create observables
         self.logger.debug("Creating observables.")
-- 
GitLab