Commit e42a55b7 authored by Theo Steininger's avatar Theo Steininger

Leveraging the power of DiagonalOperator for

parent 30accef1
......@@ -556,10 +556,9 @@ class Field(Loggable, Versionable, object):
new_val = self.get_val(copy=False)
spaces = utilities.cast_axis_to_tuple(spaces, len(self.domain))
if spaces is None:
spaces = range(len(self.domain))
spaces = utilities.cast_axis_to_tuple(spaces, len(self.domain))
for ind, sp in enumerate(self.domain):
if ind in spaces:
......@@ -571,32 +570,33 @@ class Field(Loggable, Versionable, object):
new_field.set_val(new_val=new_val, copy=False)
return new_field
def dot(self, x=None, bare=False):
if isinstance(x, Field):
assert len(x.domain) == len(self.domain)
for index in xrange(len(self.domain)):
assert x.domain[index] == self.domain[index]
except AssertionError:
raise ValueError(
"domains are incompatible.")
# extract the data from x and try to dot with this
x = x.get_val(copy=False)
def dot(self, x=None, spaces=None, bare=False):
if not isinstance(x, Field):
raise ValueError("The dot-partner must be an instance of " +
"the NIFTy field class")
# Compute the dot respecting the fact of discrete/continous spaces
if bare:
y = self
y = self.weight(spaces=spaces, power=-1)
y = self.weight(power=1)
y = y.get_val(copy=False)
# Cast the input in order to cure dtype and shape differences
x = self.cast(x)
dotted = x.conjugate() * y
y = self
return dotted.sum()
if spaces is None:
x_val = x.get_val(copy=False)
y_val = y.get_val(copy=False)
result = (x_val.conjugate() * y_val).sum()
return result
# create a diagonal operator which is capable of taking care of the
# axes-matching
from nifty.operators.diagonal_operator import DiagonalOperator
diagonal = y.val.conjugate()
diagonalOperator = DiagonalOperator(domain=y.domain,
dotted = diagonalOperator(x, spaces=spaces)
return dotted.sum(spaces=spaces)
def norm(self, q=2):
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment