Commit 12a5a3ac authored by Martin Reinecke's avatar Martin Reinecke

Merge branch 'master' into line_search

parents 70ef764a 1d10be46
Pipeline #15144 passed with stage
in 6 minutes and 22 seconds
......@@ -466,7 +466,7 @@ class Field(Loggable, Versionable, object):
return result_obj
def power_synthesize(self, spaces=None, real_power=True, real_signal=True,
mean=None, std=None):
mean=None, std=None, distribution_strategy=None):
""" Yields a sampled field with `self`**2 as its power spectrum.
This method draws a Gaussian random field in the harmonic partner
......@@ -541,13 +541,16 @@ class Field(Loggable, Versionable, object):
else:
result_list = [None, None]
if distribution_strategy is None:
distribution_strategy = gc['default_distribution_strategy']
result_list = [self.__class__.from_random(
'normal',
mean=mean,
std=std,
domain=result_domain,
dtype=np.complex,
distribution_strategy=self.distribution_strategy)
distribution_strategy=distribution_strategy)
for x in result_list]
# from now on extract the values from the random fields for further
......
......@@ -373,6 +373,10 @@ class MPIFFT(Transform):
original_shape = inp.shape
inp = inp.reshape(inp.shape[0], 1)
axes = (0, )
if original_shape[0]%2!=0:
raise AttributeError("MPI-FFTs of onedimensional arrays "
"with odd length are currently not supported due to a "
"bug in FFTW. Please use a grid with even length.")
if current_info is None:
transform_shape = list(inp.shape)
......
......@@ -71,7 +71,8 @@ class InvertibleOperatorMixin(object):
def _times(self, x, spaces, x0=None):
if x0 is None:
x0 = Field(self.target, val=0., dtype=x.dtype)
x0 = Field(self.target, val=0., dtype=x.dtype,
distribution_strategy=x.distribution_strategy)
(result, convergence) = self.__inverter(A=self.inverse_times,
b=x,
......@@ -80,7 +81,8 @@ class InvertibleOperatorMixin(object):
def _adjoint_times(self, x, spaces, x0=None):
if x0 is None:
x0 = Field(self.domain, val=0., dtype=x.dtype)
x0 = Field(self.domain, val=0., dtype=x.dtype,
distribution_strategy=x.distribution_strategy)
(result, convergence) = self.__inverter(A=self.adjoint_inverse_times,
b=x,
......@@ -89,7 +91,8 @@ class InvertibleOperatorMixin(object):
def _inverse_times(self, x, spaces, x0=None):
if x0 is None:
x0 = Field(self.domain, val=0., dtype=x.dtype)
x0 = Field(self.domain, val=0., dtype=x.dtype,
distribution_strategy=x.distribution_strategy)
(result, convergence) = self.__inverter(A=self.times,
b=x,
......@@ -98,7 +101,8 @@ class InvertibleOperatorMixin(object):
def _adjoint_inverse_times(self, x, spaces, x0=None):
if x0 is None:
x0 = Field(self.target, val=0., dtype=x.dtype)
x0 = Field(self.target, val=0., dtype=x.dtype,
distribution_strategy=x.distribution_strategy)
(result, convergence) = self.__inverter(A=self.adjoint_times,
b=x,
......
......@@ -37,7 +37,8 @@ class Prober(object):
"""
def __init__(self, domain=None, distribution_strategy=None, probe_count=8,
random_type='pm1', compute_variance=False):
random_type='pm1', probe_dtype=np.float,
compute_variance=False):
self._domain = utilities.parse_domain(domain)
self._distribution_strategy = \
......@@ -45,6 +46,7 @@ class Prober(object):
self._probe_count = self._parse_probe_count(probe_count)
self._random_type = self._parse_random_type(random_type)
self.compute_variance = bool(compute_variance)
self.probe_dtype = np.dtype(probe_dtype)
# ---Properties---
......@@ -104,6 +106,7 @@ class Prober(object):
""" a random-probe generator """
f = Field.from_random(random_type=self.random_type,
domain=self.domain,
dtype=self.probe_dtype,
distribution_strategy=self.distribution_strategy)
uid = np.random.randint(1e18)
return (uid, f)
......
......@@ -63,7 +63,7 @@ class FFTOperatorTests(unittest.TestCase):
assert_equal(res[zc1 * (dim1 // 2), zc2 * (dim2 // 2)], 0.)
@expand(product(["numpy", "fftw", "fftw_mpi"],
[10, 11], [False, True], [False, True],
[12, ], [False, True], [False, True],
[0.1, 1, 3.7],
[np.float64, np.complex128, np.float32, np.complex64]))
def test_fft1D(self, module, dim1, zc1, zc2, d, itp):
......@@ -86,7 +86,7 @@ class FFTOperatorTests(unittest.TestCase):
rtol=tol, atol=tol)
@expand(product(["numpy", "fftw", "fftw_mpi"],
[10, 11], [9, 12], [False, True],
[12, 15], [9, 12], [False, True],
[False, True], [False, True], [False, True], [0.1, 1, 3.7],
[0.4, 1, 2.7],
[np.float64, np.complex128, np.float32, np.complex64]))
......
......@@ -39,4 +39,7 @@ class SpaceSerializationTests(unittest.TestCase):
repo.commit()
assert_equal(space, repo.get('space'))
assert_equal(field, repo.get('field'))
os.remove('test.h5')
try:
os.remove('test.h5')
except OSError:
pass
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment