Commit 6386588d authored by Martin Reinecke's avatar Martin Reinecke
Browse files

cleanup

parent f9ce2a6e
......@@ -122,20 +122,12 @@ class Field(object):
return tuple(axes_list)
def _infer_dtype(self, dtype, val):
if val is None:
return np.float64 if dtype is None else dtype
if dtype is None:
try:
dtype = val.dtype
except AttributeError:
try:
if val is None:
raise TypeError
dtype = np.result_type(val)
except(TypeError):
dtype = np.dtype(np.float64)
else:
dtype = np.dtype(dtype)
return np.result_type(dtype, np.float)
if isinstance(val,Field):
return val.dtype
return np.result_type(val)
# ---Factory methods---
......@@ -169,17 +161,9 @@ class Field(object):
"""
# create a initially empty field
f = cls(domain=domain, dtype=dtype)
# extract the data from f and apply the appropriate
# random number generator to it
sample = f.get_val(copy=False)
generator_function = getattr(Random, random_type)
sample[()]=generator_function(dtype=f.dtype,
shape=sample.shape,
**kwargs)
f.val=generator_function(dtype=f.dtype, shape=f.shape, **kwargs)
return f
# ---Powerspectral methods---
......@@ -247,15 +231,11 @@ class Field(object):
spaces = list(range(len(self.domain)))
if len(spaces) == 0:
raise ValueError(
"No space for analysis specified.")
raise ValueError("No space for analysis specified.")
if keep_phase_information:
parts_val = self._hermitian_decomposition(
domain=self.domain,
val=self.val,
spaces=spaces,
domain_axes=self.domain_axes,
preserve_gaussian_variance=False)
parts = [self.copy_empty().set_val(part_val, copy=False)
for part_val in parts_val]
......@@ -336,14 +316,12 @@ class Field(object):
@staticmethod
def _shape_up_pindex(pindex, target_shape, axes):
semiscaled_local_shape = [1, ] * len(target_shape)
semiscaled_local_shape = [1] * len(target_shape)
for i in range(len(axes)):
semiscaled_local_shape[axes[i]] = pindex.shape[i]
local_data = pindex
semiscaled_local_data = local_data.reshape(semiscaled_local_shape)
semiscaled_local_data = pindex.reshape(semiscaled_local_shape)
result_obj = np.empty(target_shape, dtype=pindex.dtype)
result_obj[()] = semiscaled_local_data
return result_obj
def power_synthesize(self, spaces=None, real_power=True, real_signal=True,
......@@ -398,7 +376,6 @@ class Field(object):
# check if the `spaces` input is valid
spaces = utilities.cast_axis_to_tuple(spaces, len(self.domain))
if spaces is None:
spaces = list(range(len(self.domain)))
......@@ -417,55 +394,37 @@ class Field(object):
# create random samples: one or two, depending on whether the
# power spectrum is real or complex
if real_power:
result_list = [None]
else:
result_list = [None, None]
result_list = [self.__class__.from_random(
'normal',
mean=mean,
std=std,
domain=result_domain,
dtype=np.complex)
for x in result_list]
for x in range(1 if real_power else 2)]
# from now on extract the values from the random fields for further
# processing without killing the fields.
# if the signal-space field should be real, hermitianize the field
# components
spec = self.val.copy()
spec = np.sqrt(spec)
spec = np.sqrt(self.val)
for power_space_index in spaces:
spec = self._spec_to_rescaler(spec, result_list, power_space_index)
local_rescaler = spec
result_val_list = [x.val for x in result_list]
spec = self._spec_to_rescaler(spec, power_space_index)
# apply the rescaler to the random fields
result_val_list[0] *= local_rescaler.real
result_list[0].val *= spec.real
if not real_power:
result_val_list[1] *= local_rescaler.imag
result_list[1].val *= spec.imag
if real_signal:
result_val_list = [self._hermitian_decomposition(
result_domain,
result_val,
spaces,
result_list[0].domain_axes,
for i in result_list:
i.val = self._hermitian_decomposition(
i.val,
preserve_gaussian_variance=True)[0]
for result_val in result_val_list]
# store the result into the fields
[x.set_val(new_val=y, copy=False) for x, y in
zip(result_list, result_val_list)]
if real_power:
result = result_list[0]
if not issubclass(result_val_list[0].dtype.type,
if not issubclass(result_list[0].dtype.type,
np.complexfloating):
result = result.real
else:
......@@ -474,22 +433,15 @@ class Field(object):
return result
@staticmethod
def _hermitian_decomposition(domain, val, spaces, domain_axes,
preserve_gaussian_variance=False):
h = val.real.copy()
a = 1j * val.imag.copy()
# correct variance
def _hermitian_decomposition(val, preserve_gaussian_variance=False):
if preserve_gaussian_variance:
assert issubclass(val.dtype.type, np.complexfloating),\
"complex input field is needed here"
h *= np.sqrt(2)
a *= np.sqrt(2)
return (h, a)
if not issubclass(val.dtype.type, np.complexfloating):
raise TypeError("complex input field is needed here")
return (val.real*np.sqrt(2.), val.imag*np.sqrt(2.))
else:
return (val.real.copy(), val.imag.copy())
def _spec_to_rescaler(self, spec, result_list, power_space_index):
def _spec_to_rescaler(self, spec, power_space_index):
power_space = self.domain[power_space_index]
local_blow_up = [slice(None)]*len(spec.shape)
......
......@@ -70,17 +70,11 @@ class Test_Functionality(unittest.TestCase):
v = v + 1j*np.random.random(s1+s2)
f1 = Field(ra, val=v, copy=True)
f2 = Field((r1, r2), val=v, copy=True)
h1, a1 = Field._hermitian_decomposition((ra,), f1.val, (0,),
((0, 1,),), preserve)
h2, a2 = Field._hermitian_decomposition((r1, r2), f2.val, (0, 1),
((0,), (1,)), preserve)
h3, a3 = Field._hermitian_decomposition((r1, r2), f2.val, (1, 0),
((0,), (1,)), preserve)
h1, a1 = Field._hermitian_decomposition(f1.val, preserve)
h2, a2 = Field._hermitian_decomposition(f2.val, preserve)
assert_almost_equal(h1, h2)
assert_almost_equal(a1, a2)
assert_almost_equal(h1, h3)
assert_almost_equal(a1, a3)
@expand(product([RGSpace((8,), harmonic=True),
RGSpace((8, 8), harmonic=True, distances=0.123)],
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment