Commit 75acf2e7 authored by csongor's avatar csongor
Browse files

Fixed map-ings in nifty_field

parent 080116da
......@@ -550,7 +550,7 @@ class field(object):
self.codomain = new_codomain
return self.codomain
def weight(self, new_val=None, power=1, overwrite=False):
def weight(self, new_val=None, power=1, overwrite=False, spaces=None):
"""
Returns the field values, weighted with the volume factors to a
given power. The field values will optionally be overwritten.
......@@ -579,9 +579,12 @@ class field(object):
if new_val is None:
new_val = self.get_val()
for ind, space in self.domain:
new_val = space.calc_weigth(new_val, power=power,
axis=self._axis_list[ind])
if spaces is None:
spaces = range(len(self.get_shape()))
for ind in spaces:
new_val = self.domain[ind].calc_weigth(new_val, power=power,
axis=self._axis_list[
ind])
new_field.set_val(new_val=new_val)
return new_field
......@@ -645,6 +648,8 @@ class field(object):
# Case 3: x is something else
else:
# Cast the input in order to cure dtype and shape differences
self.field_type_dot("dummy call, reverse spaces iteration")
casted_x = self.cast(x)
# Compute the dot respecting the fact of discrete/continous spaces
if not (np.isreal(self.get_val()) or bare):
......@@ -652,6 +657,9 @@ class field(object):
result = self.get_val().dot(casted_x)
return np.sum(result, axis=axis)
def field_type_dot(self,something):
pass
def vdot(self, *args, **kwargs):
return self.dot(*args, **kwargs)
......@@ -718,7 +726,7 @@ class field(object):
return work_field
def transform(self, new_domain=None, new_codomain=None, overwrite=False,
**kwargs):
spaces=None, **kwargs):
"""
Computes the transform of the field using the appropriate conjugate
transformation.
......@@ -756,10 +764,14 @@ class field(object):
assert (new_domain.check_codomain(new_codomain))
new_val = self.get_val()
for ind, space in self.domain:
new_val = space.calc_transform(new_val, codomain=new_domain,
axis=self._axis_list[ind], **kwargs)
if spaces is None:
spaces = range(len(self.get_shape()))
else:
for ind in spaces:
new_val = self.domain[ind].calc_transform(new_val,
codomain=new_domain,
axis=self._axis_list[
ind], **kwargs)
if overwrite:
return_field = self
return_field.set_codomain(new_codomain=new_codomain, force=True)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment