Commit f0ec3ead authored by theos's avatar theos
Browse files

Fixed domain, codomain handling in nifty.field.

parent 09f8ee21
...@@ -187,7 +187,7 @@ class field(object): ...@@ -187,7 +187,7 @@ class field(object):
field_type, datamodel, **kwargs): field_type, datamodel, **kwargs):
# check domain # check domain
self.domain = self._parse_domain(domain=domain) self.domain = self._parse_domain(domain=domain)
self.domain_axes_list = self._get_axes_list(self.domain) self.domain_axes = self._get_axes_tuple(self.domain)
# check codomain # check codomain
if codomain is None: if codomain is None:
...@@ -196,7 +196,7 @@ class field(object): ...@@ -196,7 +196,7 @@ class field(object):
self.codomain = self._parse_codomain(codomain, self.domain) self.codomain = self._parse_codomain(codomain, self.domain)
self.field_type = self._parse_field_type(field_type) self.field_type = self._parse_field_type(field_type)
self.field_type_axes_list = self._get_axes_list(self.field_type) self.field_type_axes = self._get_axes_tuple(self.field_type)
if dtype is None: if dtype is None:
dtype = self._infer_dtype(domain=self.domain, dtype = self._infer_dtype(domain=self.domain,
...@@ -237,7 +237,7 @@ class field(object): ...@@ -237,7 +237,7 @@ class field(object):
dtype = reduce(lambda x, y: np.result_type(x, y), dtype_tuple) dtype = reduce(lambda x, y: np.result_type(x, y), dtype_tuple)
return dtype return dtype
def _get_axes_list(self, things_with_shape): def _get_axes_tuple(self, things_with_shape):
i = 0 i = 0
axes_list = [] axes_list = []
for thing in things_with_shape: for thing in things_with_shape:
...@@ -246,7 +246,7 @@ class field(object): ...@@ -246,7 +246,7 @@ class field(object):
l += [i] l += [i]
i += 1 i += 1
axes_list += [tuple(l)] axes_list += [tuple(l)]
return axes_list return tuple(axes_list)
def _parse_comm(self, comm): def _parse_comm(self, comm):
# check if comm is a string -> the name of comm is given # check if comm is a string -> the name of comm is given
...@@ -512,11 +512,11 @@ class field(object): ...@@ -512,11 +512,11 @@ class field(object):
for ind, sp in enumerate(self.domain): for ind, sp in enumerate(self.domain):
casted_x = sp.complement_cast(casted_x, casted_x = sp.complement_cast(casted_x,
axis=self.domain_axes_list[ind]) axis=self.domain_axes[ind])
for ind, ft in enumerate(self.field_type): for ind, ft in enumerate(self.field_type):
casted_x = ft.complement_cast(casted_x, casted_x = ft.complement_cast(casted_x,
axis=self.field_type_axes_list[ind]) axis=self.field_type_axes[ind])
return casted_x return casted_x
...@@ -647,7 +647,7 @@ class field(object): ...@@ -647,7 +647,7 @@ class field(object):
for ind, sp in enumerate(self.domain): for ind, sp in enumerate(self.domain):
new_val = sp.calc_weight(new_val, new_val = sp.calc_weight(new_val,
power=power, power=power,
axes=self.domain_axes_list[ind]) axes=self.domain_axes[ind])
new_field.set_val(new_val=new_val, copy=False) new_field.set_val(new_val=new_val, copy=False)
return new_field return new_field
...@@ -718,15 +718,15 @@ class field(object): ...@@ -718,15 +718,15 @@ class field(object):
dotted = x.conjugate() * y dotted = x.conjugate() * y
for ind in range(-1, -len(self.field_type_axes_list)-1, -1): for ind in range(-1, -len(self.field_type_axes)-1, -1):
dotted = self.field_type[ind].dot_contraction( dotted = self.field_type[ind].dot_contraction(
dotted, dotted,
axes=self.field_type_axes_list[ind]) axes=self.field_type_axes[ind])
for ind in range(-1, -len(self.domain_axes_list)-1, -1): for ind in range(-1, -len(self.domain_axes)-1, -1):
dotted = self.domain[ind].dot_contraction( dotted = self.domain[ind].dot_contraction(
dotted, dotted,
axes=self.domain_axes_list[ind]) axes=self.domain_axes[ind])
return dotted return dotted
def vdot(self, *args, **kwargs): def vdot(self, *args, **kwargs):
...@@ -792,8 +792,7 @@ class field(object): ...@@ -792,8 +792,7 @@ class field(object):
return work_field return work_field
def transform(self, new_domain=None, new_codomain=None, spaces=None, def transform(self, spaces=None, **kwargs):
**kwargs):
""" """
Computes the transform of the field using the appropriate conjugate Computes the transform of the field using the appropriate conjugate
transformation. transformation.
...@@ -818,37 +817,36 @@ class field(object): ...@@ -818,37 +817,36 @@ class field(object):
Otherwise, nothing is returned. Otherwise, nothing is returned.
""" """
if new_domain is None:
new_domain = self.codomain
# try to recycle the old domain
if new_codomain is None:
try:
new_codomain = self._parse_codomain(self.domain, new_domain)
except ValueError:
new_codomain = self._build_codomain(new_domain)
else:
new_codomain = self._parse_codomain(new_codomain, new_domain)
try: try:
spaces_iterator = iter(spaces) iter(spaces)
except TypeError: except TypeError:
if spaces is None: if spaces is None:
spaces_iterator = xrange(len(self.shape)) spaces = xrange(len(self.domain_axes))
else: else:
spaces_iterator = (spaces, ) spaces = (spaces, )
new_val = self.get_val() new_val = self.get_val()
for ind in spaces_iterator: new_domain = ()
new_codomain = ()
for ind in xrange(len(self.domain)):
if ind in spaces:
sp = self.domain[ind] sp = self.domain[ind]
cosp = self.codomain[ind]
new_val = sp.calc_transform(new_val, new_val = sp.calc_transform(new_val,
codomain=new_domain[ind], codomain=cosp,
axes=self.domain_axes_list[ind], axes=self.domain_axes[ind],
**kwargs) **kwargs)
new_domain += (self.codomain[ind],)
new_codomain += (self.domain[ind],)
else:
new_domain += (self.domain[ind],)
new_codomain += (self.codomain[ind],)
return_field = self.copy_empty(domain=new_domain, return_field = self.copy_empty(domain=new_domain,
codomain=new_codomain) codomain=new_codomain)
return_field.set_val(new_val=new_val, copy=False) return_field.set_val(new_val=new_val, copy=False)
return return_field return return_field
def smooth(self, sigma=0, spaces=None, **kwargs): def smooth(self, sigma=0, spaces=None, **kwargs):
...@@ -882,7 +880,7 @@ class field(object): ...@@ -882,7 +880,7 @@ class field(object):
spaces_iterator = iter(spaces) spaces_iterator = iter(spaces)
except TypeError: except TypeError:
if spaces is None: if spaces is None:
spaces_iterator = xrange(len(self.shape)) spaces_iterator = xrange(len(self.domain))
else: else:
spaces_iterator = (spaces, ) spaces_iterator = (spaces, )
...@@ -891,7 +889,7 @@ class field(object): ...@@ -891,7 +889,7 @@ class field(object):
sp = self.domain[ind] sp = self.domain[ind]
new_val = sp.calc_smooth(new_val, new_val = sp.calc_smooth(new_val,
sigma=sigma, sigma=sigma,
axes=self.domain_axes_list[ind], axes=self.domain_axes[ind],
**kwargs) **kwargs)
new_field.set_val(new_val=new_val, copy=False) new_field.set_val(new_val=new_val, copy=False)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment