Commit e51e8754 authored by Martin Reinecke's avatar Martin Reinecke
Browse files

fix Field.vdot according to issue #206

parent c83661d0
Pipeline #23324 passed with stage
in 4 minutes and 44 seconds
......@@ -277,10 +277,7 @@ class Field(object):
if out is not self:
if spaces is None:
spaces = range(len(self.domain))
spaces = utilities.cast_iseq_to_tuple(spaces)
spaces = utilities.parse_spaces(spaces, len(self.domain))
fct = 1.
for ind in spaces:
......@@ -308,44 +305,41 @@ class Field(object):
x : Field
The domain of x must contain `self.domain`
x must live on the same domain as `self`.
spaces : tuple of ints
If the domain of `self` and `x` are not the same, `spaces` defines
which domains of `x` are mapped to those of `self`.
spaces : None, int or tuple of ints (default: None)
The dot product is only carried out over the sub-domains in this
tuple. If None, it is carried out over all sub-domains.
out : float, complex, either scalar or Field
out : float, complex, either scalar (for full dot products)
or Field (for partial dot products)
if not isinstance(x, Field):
raise ValueError("The dot-partner must be an instance of " +
"the NIFTy field class")
# Compute the dot respecting the fact of discrete/continuous spaces
tmp = self.scalar_weight(spaces)
if tmp is None:
fct = 1.
y = self.weight(power=1)
y = self
fct = tmp
if x.domain != self.domain:
raise ValueError("Domain mismatch")
if spaces is None:
return fct*dobj.vdot(y.val, x.val)
ndom = len(self.domain)
spaces = utilities.parse_spaces(spaces, ndom)
if len(spaces) == ndom:
tmp = self.scalar_weight(spaces)
if tmp is None:
fct = 1.
y = self.weight(power=1)
y = self
fct = tmp
spaces = utilities.cast_iseq_to_tuple(spaces)
if spaces == tuple(range(len(self.domain))): # full contraction
return fct*dobj.vdot(y.val, x.val)
raise NotImplementedError("special case for vdot not yet implemented")
active_axes = []
for i in spaces:
active_axes += self.domain.axes[i]
res = 0.
for sl in utilities.get_slice_list(self.shape, active_axes):
res += dobj.vdot(y.val, x.val[sl])
return res*fct
# If we arrive here, we have to do a partial dot product.
# For the moment, do this the explicit, non-optimized way
return (self.conjugate()*x).integrate(spaces=spaces)
def norm(self):
""" Computes the L2-norm of the field values.
......@@ -380,8 +374,8 @@ class Field(object):
def _contraction_helper(self, op, spaces):
if spaces is None:
return getattr(self.val, op)()
spaces = utilities.cast_iseq_to_tuple(spaces)
spaces = utilities.parse_spaces(spaces, len(self.domain))
axes_list = tuple(self.domain.axes[sp_index] for sp_index in spaces)
......@@ -21,7 +21,7 @@ import numpy as np
from ..field import Field
from ..domain_tuple import DomainTuple
from .endomorphic_operator import EndomorphicOperator
from ..utilities import cast_iseq_to_tuple
from .. import utilities
from .. import dobj
......@@ -79,14 +79,9 @@ class DiagonalOperator(EndomorphicOperator):
if diagonal.domain != self._domain:
raise ValueError("domain mismatch")
self._spaces = cast_iseq_to_tuple(spaces)
nspc = len(self._spaces)
if nspc != len(diagonal.domain):
self._spaces = utilities.parse_spaces(spaces, len(self._domain))
if len(self._spaces) != len(diagonal.domain):
raise ValueError("spaces and domain must have the same length")
if nspc > len(self._domain):
raise ValueError("too many spaces")
if nspc > len(set(self._spaces)):
raise ValueError("non-unique space indices")
# if nspc==len(self.diagonal.domain),
# we could do some optimization
for i, j in enumerate(self._spaces):
......@@ -86,10 +86,7 @@ def power_analyze(field, spaces=None, binbounds=None,
dobj.mprint("WARNING: Field has a space in `domain` which is "
"neither harmonic nor a PowerSpace.")
if spaces is None:
spaces = range(len(field.domain))
spaces = utilities.cast_iseq_to_tuple(spaces)
spaces = utilities.parse_spaces(spaces, len(field.domain))
if len(spaces) == 0:
raise ValueError("No space for analysis specified.")
......@@ -117,8 +114,7 @@ def power_analyze(field, spaces=None, binbounds=None,
def power_synthesize_nonrandom(field, spaces=None):
spaces = range(len(field.domain)) if spaces is None \
else utilities.cast_iseq_to_tuple(spaces)
spaces = utilities.parse_spaces(spaces, len(field.domain))
result_domain = list(field.domain)
spec = sqrt(field)
......@@ -65,12 +65,27 @@ def get_slice_list(shape, axes):
yield [slice(None, None)]
def cast_iseq_to_tuple(seq):
if seq is None:
return None
if np.isscalar(seq):
return (int(seq),)
return tuple(int(item) for item in seq)
def safe_cast(tfunc, val):
tmp = tfunc(val)
if val != tmp:
raise ValueError("value changed during cast")
return tmp
def parse_spaces(spaces, maxidx):
maxidx = safe_cast(int, maxidx)
if spaces is None:
return tuple(range(maxidx))
elif np.isscalar(spaces):
spaces = (safe_cast(int, spaces),)
spaces = tuple(safe_cast(int, item) for item in spaces)
tmp = tuple(set(spaces))
if tmp[0] < 0 or tmp[-1] >= maxidx:
raise ValueError("space index out of range")
if len(tmp) != len(spaces):
raise ValueError("multiply defined space indices")
return spaces
def infer_space(domain, space):
......@@ -129,3 +129,11 @@ class Test_Functionality(unittest.TestCase):
f2 = ift.Field.from_random("normal", domain=s, dtype=np.complex128)
assert_allclose(f1.vdot(f2), f1.vdot(f2, spaces=0))
assert_allclose(f1.vdot(f2), np.conj(f2.vdot(f1)))
def test_vdot2(self):
x1 = ift.RGSpace((200,))
x2 = ift.RGSpace((150,))
m = ift.Field((x1, x2), val=.5)
res = m.vdot(m, spaces=1)
ift.dobj.to_global_data(ift.Field(x1, val=.25).val))
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment