From bc6cb4bc196074232cc3e019660a5b42b6ddd0ad Mon Sep 17 00:00:00 2001 From: Martin Reinecke <martin@mpa-garching.mpg.de> Date: Tue, 12 Sep 2017 16:38:18 +0200 Subject: [PATCH] field simplifications --- nifty2go/field.py | 65 +++++++++++++---------------------------------- 1 file changed, 17 insertions(+), 48 deletions(-) diff --git a/nifty2go/field.py b/nifty2go/field.py index becc8148b..38117375c 100644 --- a/nifty2go/field.py +++ b/nifty2go/field.py @@ -166,8 +166,6 @@ class Field(object): See Also -------- power_synthesize - - """ generator_function = getattr(Random, random_type) @@ -230,9 +228,8 @@ class Field(object): # power_space instances for sp in self.domain: if not sp.harmonic and not isinstance(sp, PowerSpace): - raise TypeError( - "Field has a space in `domain` which is neither " - "harmonic nor a PowerSpace.") + raise TypeError("Field has a space in `domain` which is " + "neither harmonic nor a PowerSpace.") # check if the `spaces` input is valid spaces = utilities.cast_axis_to_tuple(spaces, len(self.domain)) @@ -267,10 +264,8 @@ class Field(object): @classmethod def _single_power_analyze(cls, work_field, space_index, binbounds): - if not work_field.domain[space_index].harmonic: - raise ValueError( - "The analyzed space must be harmonic.") + raise ValueError("The analyzed space must be harmonic.") # Create the target PowerSpace instance: # If the associated signal-space field was real, we extract the @@ -319,9 +314,8 @@ class Field(object): semiscaled_local_shape = [1] * len(target_shape) for i in range(len(axes)): semiscaled_local_shape[axes[i]] = pindex.shape[i] - semiscaled_local_data = pindex.reshape(semiscaled_local_shape) result_obj = np.empty(target_shape, dtype=pindex.dtype) - result_obj[()] = semiscaled_local_data + result_obj[()] = pindex.reshape(semiscaled_local_shape) return result_obj def power_synthesize(self, spaces=None, real_power=True, real_signal=True, @@ -379,18 +373,15 @@ class Field(object): if spaces is None: spaces = list(range(len(self.domain))) - for power_space_index in spaces: - power_space = self.domain[power_space_index] - if not isinstance(power_space, PowerSpace): + for i in spaces: + if not isinstance(self.domain[i], PowerSpace): raise ValueError("A PowerSpace is needed for field " "synthetization.") # create the result domain result_domain = list(self.domain) - for power_space_index in spaces: - power_space = self.domain[power_space_index] - harmonic_domain = power_space.harmonic_partner - result_domain[power_space_index] = harmonic_domain + for i in spaces: + result_domain[i] = self.domain[i].harmonic_partner # create random samples: one or two, depending on whether the # power spectrum is real or complex @@ -422,14 +413,9 @@ class Field(object): for i in result_list] if real_power: - result = result_list[0] - if not issubclass(result_list[0].dtype.type, - np.complexfloating): - result = result.real + return result_list[0] else: - result = result_list[0] + 1j*result_list[1] - - return result + return result_list[0] + 1j*result_list[1] @staticmethod def _hermitian_decomposition(val, preserve_gaussian_variance=False): @@ -497,11 +483,10 @@ class Field(object): def total_volume(self): """ Returns the total volume of all spaces in the domain. """ - volume_tuple = tuple(sp.total_volume for sp in self.domain) - try: - return reduce(lambda x, y: x * y, volume_tuple) - except TypeError: + if len(self.domain) == 0: return 0. + volume_tuple = tuple(sp.total_volume for sp in self.domain) + return reduce(lambda x, y: x * y, volume_tuple) @property def real(self): @@ -517,31 +502,17 @@ class Field(object): # ---Special unary/binary operations--- - def copy(self, domain=None, dtype=None): + def copy(self): """ Returns a full copy of the Field. - If no keyword arguments are given, the returned object will be an - identical copy of the original Field. By explicit specification one is - able to define the domain and the dtype of the returned Field. - - Parameters - ---------- - domain : DomainObject - The new domain the Field shall have. - - dtype : type - The new dtype the Field shall have. + The returned object will be an identical copy of the original Field. Returns ------- out : Field The output object. An identical copy of 'self'. - """ - - if domain is None: - domain = self.domain - return Field(domain=domain, val=self._val, dtype=dtype, copy=True) + return Field(val=self, copy=True) def scalar_weight(self, spaces=None): if np.isscalar(spaces): @@ -699,10 +670,8 @@ class Field(object): axes_list = tuple(self.domain_axes[sp_index] for sp_index in spaces) - try: + if len(axes_list) > 0: axes_list = reduce(lambda x, y: x+y, axes_list) - except TypeError: - axes_list = () # perform the contraction on the data data = getattr(self.val, op)(axis=axes_list) -- GitLab