Commit 75acf2e7 authored by csongor's avatar csongor
Browse files

Fixed map-ings in nifty_field

parent 080116da
...@@ -550,7 +550,7 @@ class field(object): ...@@ -550,7 +550,7 @@ class field(object):
self.codomain = new_codomain self.codomain = new_codomain
return self.codomain return self.codomain
def weight(self, new_val=None, power=1, overwrite=False): def weight(self, new_val=None, power=1, overwrite=False, spaces=None):
""" """
Returns the field values, weighted with the volume factors to a Returns the field values, weighted with the volume factors to a
given power. The field values will optionally be overwritten. given power. The field values will optionally be overwritten.
...@@ -579,9 +579,12 @@ class field(object): ...@@ -579,9 +579,12 @@ class field(object):
if new_val is None: if new_val is None:
new_val = self.get_val() new_val = self.get_val()
for ind, space in self.domain: if spaces is None:
new_val = space.calc_weigth(new_val, power=power, spaces = range(len(self.get_shape()))
axis=self._axis_list[ind]) for ind in spaces:
new_val = self.domain[ind].calc_weigth(new_val, power=power,
axis=self._axis_list[
ind])
new_field.set_val(new_val=new_val) new_field.set_val(new_val=new_val)
return new_field return new_field
...@@ -645,6 +648,8 @@ class field(object): ...@@ -645,6 +648,8 @@ class field(object):
# Case 3: x is something else # Case 3: x is something else
else: else:
# Cast the input in order to cure dtype and shape differences # Cast the input in order to cure dtype and shape differences
self.field_type_dot("dummy call, reverse spaces iteration")
casted_x = self.cast(x) casted_x = self.cast(x)
# Compute the dot respecting the fact of discrete/continous spaces # Compute the dot respecting the fact of discrete/continous spaces
if not (np.isreal(self.get_val()) or bare): if not (np.isreal(self.get_val()) or bare):
...@@ -652,6 +657,9 @@ class field(object): ...@@ -652,6 +657,9 @@ class field(object):
result = self.get_val().dot(casted_x) result = self.get_val().dot(casted_x)
return np.sum(result, axis=axis) return np.sum(result, axis=axis)
def field_type_dot(self,something):
pass
def vdot(self, *args, **kwargs): def vdot(self, *args, **kwargs):
return self.dot(*args, **kwargs) return self.dot(*args, **kwargs)
...@@ -718,7 +726,7 @@ class field(object): ...@@ -718,7 +726,7 @@ class field(object):
return work_field return work_field
def transform(self, new_domain=None, new_codomain=None, overwrite=False, def transform(self, new_domain=None, new_codomain=None, overwrite=False,
**kwargs): spaces=None, **kwargs):
""" """
Computes the transform of the field using the appropriate conjugate Computes the transform of the field using the appropriate conjugate
transformation. transformation.
...@@ -756,10 +764,14 @@ class field(object): ...@@ -756,10 +764,14 @@ class field(object):
assert (new_domain.check_codomain(new_codomain)) assert (new_domain.check_codomain(new_codomain))
new_val = self.get_val() new_val = self.get_val()
for ind, space in self.domain: if spaces is None:
new_val = space.calc_transform(new_val, codomain=new_domain, spaces = range(len(self.get_shape()))
axis=self._axis_list[ind], **kwargs) else:
for ind in spaces:
new_val = self.domain[ind].calc_transform(new_val,
codomain=new_domain,
axis=self._axis_list[
ind], **kwargs)
if overwrite: if overwrite:
return_field = self return_field = self
return_field.set_codomain(new_codomain=new_codomain, force=True) return_field.set_codomain(new_codomain=new_codomain, force=True)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment