From cd23003db14bc7dfaa8bf96b1127684fc938407f Mon Sep 17 00:00:00 2001 From: Theo Steininger <theo.steininger@ultimanet.de> Date: Fri, 17 Nov 2017 01:06:33 +0100 Subject: [PATCH] Pipeline can now fix seed for MagneticFieldFactory --- .../magnetic_field/magnetic_field_factory.py | 5 +++-- imagine/pipeline.py | 7 +++++-- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/imagine/magnetic_fields/magnetic_field/magnetic_field_factory.py b/imagine/magnetic_fields/magnetic_field/magnetic_field_factory.py index 0ef68cd..4e36091 100644 --- a/imagine/magnetic_fields/magnetic_field/magnetic_field_factory.py +++ b/imagine/magnetic_fields/magnetic_field/magnetic_field_factory.py @@ -130,7 +130,7 @@ class MagneticFieldFactory(Loggable, object): parameter_dict[variable_name] = mapped_variable return parameter_dict - def generate(self, variables={}, ensemble_size=1): + def generate(self, variables={}, ensemble_size=1, random_seed=None): mapped_variables = self._map_variables_to_parameters(variables) work_parameters = self.parameter_defaults.copy() work_parameters.update(mapped_variables) @@ -142,7 +142,8 @@ class MagneticFieldFactory(Loggable, object): result_magnetic_field = self.magnetic_field_class( domain=domain, parameters=work_parameters, - distribution_strategy='equal') + distribution_strategy='equal', + random_seed=random_seed) self.logger.debug("Generated magnetic field with work-parameters %s" % work_parameters) return result_magnetic_field diff --git a/imagine/pipeline.py b/imagine/pipeline.py index fef2ec3..6da8f28 100644 --- a/imagine/pipeline.py +++ b/imagine/pipeline.py @@ -56,6 +56,8 @@ class Pipeline(Loggable, object): self.sample_callback = sample_callback + self.fixed_random_seed = None + @property def observer(self): return self._observer @@ -172,8 +174,9 @@ class Pipeline(Loggable, object): # create magnetic field self.logger.debug("Creating magnetic field.") b_field = self.magnetic_field_factory.generate( - variables=variables, - ensemble_size=self.ensemble_size) + variables=variables, + ensemble_size=self.ensemble_size, + random_seed=self.fixed_random_seed) # create observables self.logger.debug("Creating observables.") -- GitLab