Commit 7d2c751e authored by Martin Reinecke's avatar Martin Reinecke
Browse files

various tweaks

parent 8d067d40
...@@ -103,26 +103,19 @@ class Field(object): ...@@ -103,26 +103,19 @@ class Field(object):
else: else:
raise TypeError("unknown source type") raise TypeError("unknown source type")
def _parse_domain(self, domain, val=None): @staticmethod
def _parse_domain(domain, val=None):
if domain is None: if domain is None:
if isinstance(val, Field): if isinstance(val, Field):
return val.domain return val.domain
if np.isscalar(val): if np.isscalar(val):
return () # empty domain tuple return () # empty domain tuple
raise TypeError("could not infer domain from value") raise TypeError("could not infer domain from value")
if isinstance(domain, DomainObject):
return (domain,) return utilities.parse_domain(domain)
if not isinstance(domain, tuple): @staticmethod
domain = tuple(domain) def _get_axes_tuple(things_with_shape):
for d in domain:
if not isinstance(d, DomainObject):
raise TypeError(
"Given domain contains something that is not a "
"DomainObject instance.")
return domain
def _get_axes_tuple(self, things_with_shape):
i = 0 i = 0
axes_list = [] axes_list = []
for thing in things_with_shape: for thing in things_with_shape:
...@@ -131,7 +124,8 @@ class Field(object): ...@@ -131,7 +124,8 @@ class Field(object):
i += nax i += nax
return tuple(axes_list) return tuple(axes_list)
def _infer_dtype(self, dtype, val): @staticmethod
def _infer_dtype(dtype, val):
if val is None or dtype is not None: if val is None or dtype is not None:
return np.result_type(dtype, np.float64) return np.result_type(dtype, np.float64)
if isinstance(val, Field): if isinstance(val, Field):
...@@ -140,14 +134,12 @@ class Field(object): ...@@ -140,14 +134,12 @@ class Field(object):
# ---Factory methods--- # ---Factory methods---
@classmethod @staticmethod
def from_random(cls, random_type, domain, dtype=np.float64, **kwargs): def from_random(random_type, domain, dtype=np.float64, **kwargs):
""" Draws a random field with the given parameters. """ Draws a random field with the given parameters.
Parameters Parameters
---------- ----------
cls : class
random_type : String random_type : String
'pm1', 'normal', 'uniform' are the supported arguments for this 'pm1', 'normal', 'uniform' are the supported arguments for this
method. method.
...@@ -240,11 +232,7 @@ class Field(object): ...@@ -240,11 +232,7 @@ class Field(object):
raise ValueError("No space for analysis specified.") raise ValueError("No space for analysis specified.")
if keep_phase_information: if keep_phase_information:
parts_val = self._hermitian_decomposition( parts = self._hermitian_decomposition(self, False)
val=self.val,
preserve_gaussian_variance=False)
parts = [Field(self.domain, part_val, self.dtype, copy=False)
for part_val in parts_val]
else: else:
parts = [self] parts = [self]
...@@ -262,8 +250,8 @@ class Field(object): ...@@ -262,8 +250,8 @@ class Field(object):
else: else:
return parts[0] return parts[0]
@classmethod @staticmethod
def _single_power_analyze(cls, work_field, space_index, binbounds): def _single_power_analyze(work_field, space_index, binbounds):
if not work_field.domain[space_index].harmonic: if not work_field.domain[space_index].harmonic:
raise ValueError("The analyzed space must be harmonic.") raise ValueError("The analyzed space must be harmonic.")
...@@ -276,7 +264,7 @@ class Field(object): ...@@ -276,7 +264,7 @@ class Field(object):
harmonic_domain = work_field.domain[space_index] harmonic_domain = work_field.domain[space_index]
power_domain = PowerSpace(harmonic_partner=harmonic_domain, power_domain = PowerSpace(harmonic_partner=harmonic_domain,
binbounds=binbounds) binbounds=binbounds)
power_spectrum = cls._calculate_power_spectrum( power_spectrum = Field._calculate_power_spectrum(
field_val=work_field.val, field_val=work_field.val,
pdomain=power_domain, pdomain=power_domain,
axes=work_field.domain_axes[space_index]) axes=work_field.domain_axes[space_index])
...@@ -288,15 +276,12 @@ class Field(object): ...@@ -288,15 +276,12 @@ class Field(object):
return Field(domain=result_domain, val=power_spectrum, return Field(domain=result_domain, val=power_spectrum,
dtype=power_spectrum.dtype) dtype=power_spectrum.dtype)
@classmethod @staticmethod
def _calculate_power_spectrum(cls, field_val, pdomain, axes=None): def _calculate_power_spectrum(field_val, pdomain, axes=None):
pindex = pdomain.pindex pindex = pdomain.pindex
if axes is not None: if axes is not None:
pindex = cls._shape_up_pindex( pindex = Field._shape_up_pindex(pindex, field_val.shape, axes)
pindex=pindex,
target_shape=field_val.shape,
axes=axes)
power_spectrum = utilities.bincount_axis(pindex, weights=field_val, power_spectrum = utilities.bincount_axis(pindex, weights=field_val,
axis=axes) axis=axes)
...@@ -407,9 +392,7 @@ class Field(object): ...@@ -407,9 +392,7 @@ class Field(object):
result_list[1] *= spec.imag result_list[1] *= spec.imag
if real_signal: if real_signal:
result_list = [Field(i.domain, self._hermitian_decomposition( result_list = [self._hermitian_decomposition(i, True)[0]
i.val,
preserve_gaussian_variance=True)[0])
for i in result_list] for i in result_list]
if real_power: if real_power:
...@@ -418,13 +401,13 @@ class Field(object): ...@@ -418,13 +401,13 @@ class Field(object):
return result_list[0] + 1j*result_list[1] return result_list[0] + 1j*result_list[1]
@staticmethod @staticmethod
def _hermitian_decomposition(val, preserve_gaussian_variance=False): def _hermitian_decomposition(inp, preserve_gaussian_variance=False):
if preserve_gaussian_variance: if preserve_gaussian_variance:
if not issubclass(val.dtype.type, np.complexfloating): if not issubclass(inp.dtype.type, np.complexfloating):
raise TypeError("complex input field is needed here") raise TypeError("complex input field is needed here")
return (val.real*np.sqrt(2.), val.imag*np.sqrt(2.)) return (inp.real*np.sqrt(2.), inp.imag*np.sqrt(2.))
else: else:
return (val.real.copy(), val.imag.copy()) return (inp.real.copy(), inp.imag.copy())
def _spec_to_rescaler(self, spec, power_space_index): def _spec_to_rescaler(self, spec, power_space_index):
power_space = self.domain[power_space_index] power_space = self.domain[power_space_index]
...@@ -800,12 +783,12 @@ class Field(object): ...@@ -800,12 +783,12 @@ class Field(object):
return self._binary_helper(other, op='__gt__') return self._binary_helper(other, op='__gt__')
def __repr__(self): def __repr__(self):
return "<nifty_core.field>" return "<nifty2go.Field>"
def __str__(self): def __str__(self):
minmax = [self.min(), self.max()] minmax = [self.min(), self.max()]
mean = self.mean() mean = self.mean()
return "nifty_core.field instance\n- domain = " + \ return "nifty2go.Field instance\n- domain = " + \
repr(self.domain) + \ repr(self.domain) + \
"\n- val = " + repr(self.val) + \ "\n- val = " + repr(self.val) + \
"\n - min.,max. = " + str(minmax) + \ "\n - min.,max. = " + str(minmax) + \
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment