Commit 47567c7e authored by Theo Steininger's avatar Theo Steininger

Fixed random number generation for 'not'-distributed Fields. Added exception...

Fixed random number generation for 'not'-distributed Fields. Added exception to rg_transforms.py if FFTW changes shapes. Adapted tests.
parent 2573b236
Pipeline #15811 passed with stages
in 15 minutes and 26 seconds
......@@ -127,7 +127,6 @@ class Field(Loggable, Versionable, object):
else:
self.set_val(new_val=val, copy=copy)
def _parse_domain(self, domain, val=None):
if domain is None:
if isinstance(val, Field):
......@@ -240,6 +239,15 @@ class Field(Loggable, Versionable, object):
# random number generator to it
sample = f.get_val(copy=False)
generator_function = getattr(Random, random_type)
comm = sample.comm
size = comm.size
if (sample.distribution_strategy in DISTRIBUTION_STRATEGIES['not'] and
size > 1):
seed = np.random.randint(10000000)
seed = comm.bcast(seed, root=0)
np.random.seed(seed)
sample.apply_generator(
lambda shape: generator_function(dtype=f.dtype,
shape=shape,
......
......@@ -261,6 +261,10 @@ class MPIFFT(Transform):
if p.has_output:
result = p.output_array
if result.shape != val.shape:
raise ValueError("Output shape is different than input shape. "
"Maybe fftw tries to optimize the "
"bit-alignment? Try a different array-size.")
else:
return None
......
import unittest
import numpy as np
from numpy.testing import assert_approx_equal
from nifty import Field,\
......@@ -10,7 +11,7 @@ from itertools import product
from test.common import expand
class ResponseOperator_Tests(unittest.TestCase):
spaces = [RGSpace(100)]
spaces = [RGSpace(128)]
@expand(product(spaces, [0., 5., 1.], [0., 1., .33] ))
def test_property(self, space, sigma, exposure):
......
......@@ -37,7 +37,7 @@ def _get_rtol(tp):
return 1e-5
class SmoothingOperator_Tests(unittest.TestCase):
spaces = [RGSpace(100)]
spaces = [RGSpace(128)]
@expand(product(spaces, [0., .5, 5.], [True, False]))
def test_property(self, space, sigma, log_distances):
......@@ -83,7 +83,7 @@ class SmoothingOperator_Tests(unittest.TestCase):
tt2 = rand2.vdot(op.inverse_adjoint_times(rand1))
assert_approx_equal(tt1, tt2)
@expand(product([100, 200], [1, 0.4], [0., 1., 3.7],
@expand(product([128, 256], [1, 0.4], [0., 1., 3.7],
[np.float64, np.complex128]))
def test_smooth_regular1(self, sz, d, sigma, tp):
tol = _get_rtol(tp)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment