Commit f0ec3ead authored by theos's avatar theos
Browse files

Fixed domain, codomain handling in nifty.field.

parent 09f8ee21
......@@ -187,7 +187,7 @@ class field(object):
field_type, datamodel, **kwargs):
# check domain
self.domain = self._parse_domain(domain=domain)
self.domain_axes_list = self._get_axes_list(self.domain)
self.domain_axes = self._get_axes_tuple(self.domain)
# check codomain
if codomain is None:
......@@ -196,7 +196,7 @@ class field(object):
self.codomain = self._parse_codomain(codomain, self.domain)
self.field_type = self._parse_field_type(field_type)
self.field_type_axes_list = self._get_axes_list(self.field_type)
self.field_type_axes = self._get_axes_tuple(self.field_type)
if dtype is None:
dtype = self._infer_dtype(domain=self.domain,
......@@ -237,7 +237,7 @@ class field(object):
dtype = reduce(lambda x, y: np.result_type(x, y), dtype_tuple)
return dtype
def _get_axes_list(self, things_with_shape):
def _get_axes_tuple(self, things_with_shape):
i = 0
axes_list = []
for thing in things_with_shape:
......@@ -246,7 +246,7 @@ class field(object):
l += [i]
i += 1
axes_list += [tuple(l)]
return axes_list
return tuple(axes_list)
def _parse_comm(self, comm):
# check if comm is a string -> the name of comm is given
......@@ -512,11 +512,11 @@ class field(object):
for ind, sp in enumerate(self.domain):
casted_x = sp.complement_cast(casted_x,
axis=self.domain_axes_list[ind])
axis=self.domain_axes[ind])
for ind, ft in enumerate(self.field_type):
casted_x = ft.complement_cast(casted_x,
axis=self.field_type_axes_list[ind])
axis=self.field_type_axes[ind])
return casted_x
......@@ -647,7 +647,7 @@ class field(object):
for ind, sp in enumerate(self.domain):
new_val = sp.calc_weight(new_val,
power=power,
axes=self.domain_axes_list[ind])
axes=self.domain_axes[ind])
new_field.set_val(new_val=new_val, copy=False)
return new_field
......@@ -718,15 +718,15 @@ class field(object):
dotted = x.conjugate() * y
for ind in range(-1, -len(self.field_type_axes_list)-1, -1):
for ind in range(-1, -len(self.field_type_axes)-1, -1):
dotted = self.field_type[ind].dot_contraction(
dotted,
axes=self.field_type_axes_list[ind])
axes=self.field_type_axes[ind])
for ind in range(-1, -len(self.domain_axes_list)-1, -1):
for ind in range(-1, -len(self.domain_axes)-1, -1):
dotted = self.domain[ind].dot_contraction(
dotted,
axes=self.domain_axes_list[ind])
axes=self.domain_axes[ind])
return dotted
def vdot(self, *args, **kwargs):
......@@ -792,8 +792,7 @@ class field(object):
return work_field
def transform(self, new_domain=None, new_codomain=None, spaces=None,
**kwargs):
def transform(self, spaces=None, **kwargs):
"""
Computes the transform of the field using the appropriate conjugate
transformation.
......@@ -818,37 +817,36 @@ class field(object):
Otherwise, nothing is returned.
"""
if new_domain is None:
new_domain = self.codomain
# try to recycle the old domain
if new_codomain is None:
try:
new_codomain = self._parse_codomain(self.domain, new_domain)
except ValueError:
new_codomain = self._build_codomain(new_domain)
else:
new_codomain = self._parse_codomain(new_codomain, new_domain)
try:
spaces_iterator = iter(spaces)
iter(spaces)
except TypeError:
if spaces is None:
spaces_iterator = xrange(len(self.shape))
spaces = xrange(len(self.domain_axes))
else:
spaces_iterator = (spaces, )
spaces = (spaces, )
new_val = self.get_val()
for ind in spaces_iterator:
new_domain = ()
new_codomain = ()
for ind in xrange(len(self.domain)):
if ind in spaces:
sp = self.domain[ind]
cosp = self.codomain[ind]
new_val = sp.calc_transform(new_val,
codomain=new_domain[ind],
axes=self.domain_axes_list[ind],
codomain=cosp,
axes=self.domain_axes[ind],
**kwargs)
new_domain += (self.codomain[ind],)
new_codomain += (self.domain[ind],)
else:
new_domain += (self.domain[ind],)
new_codomain += (self.codomain[ind],)
return_field = self.copy_empty(domain=new_domain,
codomain=new_codomain)
return_field.set_val(new_val=new_val, copy=False)
return return_field
def smooth(self, sigma=0, spaces=None, **kwargs):
......@@ -882,7 +880,7 @@ class field(object):
spaces_iterator = iter(spaces)
except TypeError:
if spaces is None:
spaces_iterator = xrange(len(self.shape))
spaces_iterator = xrange(len(self.domain))
else:
spaces_iterator = (spaces, )
......@@ -891,7 +889,7 @@ class field(object):
sp = self.domain[ind]
new_val = sp.calc_smooth(new_val,
sigma=sigma,
axes=self.domain_axes_list[ind],
axes=self.domain_axes[ind],
**kwargs)
new_field.set_val(new_val=new_val, copy=False)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment