From cd23003db14bc7dfaa8bf96b1127684fc938407f Mon Sep 17 00:00:00 2001
From: Theo Steininger <theo.steininger@ultimanet.de>
Date: Fri, 17 Nov 2017 01:06:33 +0100
Subject: [PATCH] Pipeline can now fix seed for MagneticFieldFactory
---
.../magnetic_field/magnetic_field_factory.py | 5 +++--
imagine/pipeline.py | 7 +++++--
2 files changed, 8 insertions(+), 4 deletions(-)
diff --git a/imagine/magnetic_fields/magnetic_field/magnetic_field_factory.py b/imagine/magnetic_fields/magnetic_field/magnetic_field_factory.py
index 0ef68cd..4e36091 100644
--- a/imagine/magnetic_fields/magnetic_field/magnetic_field_factory.py
+++ b/imagine/magnetic_fields/magnetic_field/magnetic_field_factory.py
@@ -130,7 +130,7 @@ class MagneticFieldFactory(Loggable, object):
parameter_dict[variable_name] = mapped_variable
return parameter_dict
- def generate(self, variables={}, ensemble_size=1):
+ def generate(self, variables={}, ensemble_size=1, random_seed=None):
mapped_variables = self._map_variables_to_parameters(variables)
work_parameters = self.parameter_defaults.copy()
work_parameters.update(mapped_variables)
@@ -142,7 +142,8 @@ class MagneticFieldFactory(Loggable, object):
result_magnetic_field = self.magnetic_field_class(
domain=domain,
parameters=work_parameters,
- distribution_strategy='equal')
+ distribution_strategy='equal',
+ random_seed=random_seed)
self.logger.debug("Generated magnetic field with work-parameters %s" %
work_parameters)
return result_magnetic_field
diff --git a/imagine/pipeline.py b/imagine/pipeline.py
index fef2ec3..6da8f28 100644
--- a/imagine/pipeline.py
+++ b/imagine/pipeline.py
@@ -56,6 +56,8 @@ class Pipeline(Loggable, object):
self.sample_callback = sample_callback
+ self.fixed_random_seed = None
+
@property
def observer(self):
return self._observer
@@ -172,8 +174,9 @@ class Pipeline(Loggable, object):
# create magnetic field
self.logger.debug("Creating magnetic field.")
b_field = self.magnetic_field_factory.generate(
- variables=variables,
- ensemble_size=self.ensemble_size)
+ variables=variables,
+ ensemble_size=self.ensemble_size,
+ random_seed=self.fixed_random_seed)
# create observables
self.logger.debug("Creating observables.")
--
GitLab