diff --git a/nifty_field.py b/nifty_field.py index eca32fe2e01de3bc37a47cf72b565e9d2bb2e7e1..0a56875d3c2ae516a38df7545a851cb33a524ad1 100644 --- a/nifty_field.py +++ b/nifty_field.py @@ -187,7 +187,7 @@ class field(object): field_type, datamodel, **kwargs): # check domain self.domain = self._parse_domain(domain=domain) - self.domain_axes_list = self._get_axes_list(self.domain) + self.domain_axes = self._get_axes_tuple(self.domain) # check codomain if codomain is None: @@ -196,7 +196,7 @@ class field(object): self.codomain = self._parse_codomain(codomain, self.domain) self.field_type = self._parse_field_type(field_type) - self.field_type_axes_list = self._get_axes_list(self.field_type) + self.field_type_axes = self._get_axes_tuple(self.field_type) if dtype is None: dtype = self._infer_dtype(domain=self.domain, @@ -237,7 +237,7 @@ class field(object): dtype = reduce(lambda x, y: np.result_type(x, y), dtype_tuple) return dtype - def _get_axes_list(self, things_with_shape): + def _get_axes_tuple(self, things_with_shape): i = 0 axes_list = [] for thing in things_with_shape: @@ -246,7 +246,7 @@ class field(object): l += [i] i += 1 axes_list += [tuple(l)] - return axes_list + return tuple(axes_list) def _parse_comm(self, comm): # check if comm is a string -> the name of comm is given @@ -512,11 +512,11 @@ class field(object): for ind, sp in enumerate(self.domain): casted_x = sp.complement_cast(casted_x, - axis=self.domain_axes_list[ind]) + axis=self.domain_axes[ind]) for ind, ft in enumerate(self.field_type): casted_x = ft.complement_cast(casted_x, - axis=self.field_type_axes_list[ind]) + axis=self.field_type_axes[ind]) return casted_x @@ -647,7 +647,7 @@ class field(object): for ind, sp in enumerate(self.domain): new_val = sp.calc_weight(new_val, power=power, - axes=self.domain_axes_list[ind]) + axes=self.domain_axes[ind]) new_field.set_val(new_val=new_val, copy=False) return new_field @@ -718,15 +718,15 @@ class field(object): dotted = x.conjugate() * y - for ind in range(-1, -len(self.field_type_axes_list)-1, -1): + for ind in range(-1, -len(self.field_type_axes)-1, -1): dotted = self.field_type[ind].dot_contraction( dotted, - axes=self.field_type_axes_list[ind]) + axes=self.field_type_axes[ind]) - for ind in range(-1, -len(self.domain_axes_list)-1, -1): + for ind in range(-1, -len(self.domain_axes)-1, -1): dotted = self.domain[ind].dot_contraction( dotted, - axes=self.domain_axes_list[ind]) + axes=self.domain_axes[ind]) return dotted def vdot(self, *args, **kwargs): @@ -792,8 +792,7 @@ class field(object): return work_field - def transform(self, new_domain=None, new_codomain=None, spaces=None, - **kwargs): + def transform(self, spaces=None, **kwargs): """ Computes the transform of the field using the appropriate conjugate transformation. @@ -818,37 +817,36 @@ class field(object): Otherwise, nothing is returned. """ - if new_domain is None: - new_domain = self.codomain - - # try to recycle the old domain - if new_codomain is None: - try: - new_codomain = self._parse_codomain(self.domain, new_domain) - except ValueError: - new_codomain = self._build_codomain(new_domain) - else: - new_codomain = self._parse_codomain(new_codomain, new_domain) try: - spaces_iterator = iter(spaces) + iter(spaces) except TypeError: if spaces is None: - spaces_iterator = xrange(len(self.shape)) + spaces = xrange(len(self.domain_axes)) else: - spaces_iterator = (spaces, ) + spaces = (spaces, ) new_val = self.get_val() - for ind in spaces_iterator: + new_domain = () + new_codomain = () + for ind in xrange(len(self.domain)): + if ind in spaces: sp = self.domain[ind] + cosp = self.codomain[ind] new_val = sp.calc_transform(new_val, - codomain=new_domain[ind], - axes=self.domain_axes_list[ind], + codomain=cosp, + axes=self.domain_axes[ind], **kwargs) + new_domain += (self.codomain[ind],) + new_codomain += (self.domain[ind],) + else: + new_domain += (self.domain[ind],) + new_codomain += (self.codomain[ind],) return_field = self.copy_empty(domain=new_domain, codomain=new_codomain) return_field.set_val(new_val=new_val, copy=False) + return return_field def smooth(self, sigma=0, spaces=None, **kwargs): @@ -882,7 +880,7 @@ class field(object): spaces_iterator = iter(spaces) except TypeError: if spaces is None: - spaces_iterator = xrange(len(self.shape)) + spaces_iterator = xrange(len(self.domain)) else: spaces_iterator = (spaces, ) @@ -891,7 +889,7 @@ class field(object): sp = self.domain[ind] new_val = sp.calc_smooth(new_val, sigma=sigma, - axes=self.domain_axes_list[ind], + axes=self.domain_axes[ind], **kwargs) new_field.set_val(new_val=new_val, copy=False)