Commit c97e6814 authored by Martin Reinecke's avatar Martin Reinecke
Browse files

Merge branch 'nifty2go' into fun_with_operators

parents af4c67aa e51e8754
...@@ -277,10 +277,7 @@ class Field(object): ...@@ -277,10 +277,7 @@ class Field(object):
if out is not self: if out is not self:
out.copy_content_from(self) out.copy_content_from(self)
if spaces is None: spaces = utilities.parse_spaces(spaces, len(self.domain))
spaces = range(len(self.domain))
else:
spaces = utilities.cast_iseq_to_tuple(spaces)
fct = 1. fct = 1.
for ind in spaces: for ind in spaces:
...@@ -309,44 +306,41 @@ class Field(object): ...@@ -309,44 +306,41 @@ class Field(object):
Parameters Parameters
---------- ----------
x : Field x : Field
The domain of x must contain `self.domain` x must live on the same domain as `self`.
spaces : tuple of ints spaces : None, int or tuple of ints (default: None)
If the domain of `self` and `x` are not the same, `spaces` defines The dot product is only carried out over the sub-domains in this
which domains of `x` are mapped to those of `self`. tuple. If None, it is carried out over all sub-domains.
Returns Returns
------- -------
out : float, complex, either scalar or Field out : float, complex, either scalar (for full dot products)
or Field (for partial dot products)
""" """
if not isinstance(x, Field): if not isinstance(x, Field):
raise ValueError("The dot-partner must be an instance of " + raise ValueError("The dot-partner must be an instance of " +
"the NIFTy field class") "the NIFTy field class")
# Compute the dot respecting the fact of discrete/continuous spaces if x.domain != self.domain:
tmp = self.scalar_weight(spaces) raise ValueError("Domain mismatch")
if tmp is None:
fct = 1.
y = self.weight(power=1)
else:
y = self
fct = tmp
if spaces is None: ndom = len(self.domain)
return fct*dobj.vdot(y.val, x.val) spaces = utilities.parse_spaces(spaces, ndom)
if len(spaces) == ndom:
tmp = self.scalar_weight(spaces)
if tmp is None:
fct = 1.
y = self.weight(power=1)
else:
y = self
fct = tmp
spaces = utilities.cast_iseq_to_tuple(spaces)
if spaces == tuple(range(len(self.domain))): # full contraction
return fct*dobj.vdot(y.val, x.val) return fct*dobj.vdot(y.val, x.val)
raise NotImplementedError("special case for vdot not yet implemented") # If we arrive here, we have to do a partial dot product.
active_axes = [] # For the moment, do this the explicit, non-optimized way
for i in spaces: return (self.conjugate()*x).integrate(spaces=spaces)
active_axes += self.domain.axes[i]
res = 0.
for sl in utilities.get_slice_list(self.shape, active_axes):
res += dobj.vdot(y.val, x.val[sl])
return res*fct
def norm(self): def norm(self):
""" Computes the L2-norm of the field values. """ Computes the L2-norm of the field values.
...@@ -381,8 +375,8 @@ class Field(object): ...@@ -381,8 +375,8 @@ class Field(object):
def _contraction_helper(self, op, spaces): def _contraction_helper(self, op, spaces):
if spaces is None: if spaces is None:
return getattr(self.val, op)() return getattr(self.val, op)()
else:
spaces = utilities.cast_iseq_to_tuple(spaces) spaces = utilities.parse_spaces(spaces, len(self.domain))
axes_list = tuple(self.domain.axes[sp_index] for sp_index in spaces) axes_list = tuple(self.domain.axes[sp_index] for sp_index in spaces)
......
...@@ -21,7 +21,7 @@ import numpy as np ...@@ -21,7 +21,7 @@ import numpy as np
from ..field import Field from ..field import Field
from ..domain_tuple import DomainTuple from ..domain_tuple import DomainTuple
from .endomorphic_operator import EndomorphicOperator from .endomorphic_operator import EndomorphicOperator
from ..utilities import cast_iseq_to_tuple from .. import utilities
from .. import dobj from .. import dobj
...@@ -72,14 +72,9 @@ class DiagonalOperator(EndomorphicOperator): ...@@ -72,14 +72,9 @@ class DiagonalOperator(EndomorphicOperator):
if diagonal.domain != self._domain: if diagonal.domain != self._domain:
raise ValueError("domain mismatch") raise ValueError("domain mismatch")
else: else:
self._spaces = cast_iseq_to_tuple(spaces) self._spaces = utilities.parse_spaces(spaces, len(self._domain))
nspc = len(self._spaces) if len(self._spaces) != len(diagonal.domain):
if nspc != len(diagonal.domain):
raise ValueError("spaces and domain must have the same length") raise ValueError("spaces and domain must have the same length")
if nspc > len(self._domain):
raise ValueError("too many spaces")
if nspc > len(set(self._spaces)):
raise ValueError("non-unique space indices")
# if nspc==len(self.diagonal.domain), # if nspc==len(self.diagonal.domain),
# we could do some optimization # we could do some optimization
for i, j in enumerate(self._spaces): for i, j in enumerate(self._spaces):
......
...@@ -86,10 +86,7 @@ def power_analyze(field, spaces=None, binbounds=None, ...@@ -86,10 +86,7 @@ def power_analyze(field, spaces=None, binbounds=None,
dobj.mprint("WARNING: Field has a space in `domain` which is " dobj.mprint("WARNING: Field has a space in `domain` which is "
"neither harmonic nor a PowerSpace.") "neither harmonic nor a PowerSpace.")
if spaces is None: spaces = utilities.parse_spaces(spaces, len(field.domain))
spaces = range(len(field.domain))
else:
spaces = utilities.cast_iseq_to_tuple(spaces)
if len(spaces) == 0: if len(spaces) == 0:
raise ValueError("No space for analysis specified.") raise ValueError("No space for analysis specified.")
...@@ -117,8 +114,7 @@ def power_analyze(field, spaces=None, binbounds=None, ...@@ -117,8 +114,7 @@ def power_analyze(field, spaces=None, binbounds=None,
def power_synthesize_nonrandom(field, spaces=None): def power_synthesize_nonrandom(field, spaces=None):
spaces = range(len(field.domain)) if spaces is None \ spaces = utilities.parse_spaces(spaces, len(field.domain))
else utilities.cast_iseq_to_tuple(spaces)
result_domain = list(field.domain) result_domain = list(field.domain)
spec = sqrt(field) spec = sqrt(field)
......
...@@ -65,12 +65,27 @@ def get_slice_list(shape, axes): ...@@ -65,12 +65,27 @@ def get_slice_list(shape, axes):
yield [slice(None, None)] yield [slice(None, None)]
def cast_iseq_to_tuple(seq): def safe_cast(tfunc, val):
if seq is None: tmp = tfunc(val)
return None if val != tmp:
if np.isscalar(seq): raise ValueError("value changed during cast")
return (int(seq),) return tmp
return tuple(int(item) for item in seq)
def parse_spaces(spaces, maxidx):
maxidx = safe_cast(int, maxidx)
if spaces is None:
return tuple(range(maxidx))
elif np.isscalar(spaces):
spaces = (safe_cast(int, spaces),)
else:
spaces = tuple(safe_cast(int, item) for item in spaces)
tmp = tuple(set(spaces))
if tmp[0] < 0 or tmp[-1] >= maxidx:
raise ValueError("space index out of range")
if len(tmp) != len(spaces):
raise ValueError("multiply defined space indices")
return spaces
def infer_space(domain, space): def infer_space(domain, space):
......
...@@ -129,3 +129,11 @@ class Test_Functionality(unittest.TestCase): ...@@ -129,3 +129,11 @@ class Test_Functionality(unittest.TestCase):
f2 = ift.Field.from_random("normal", domain=s, dtype=np.complex128) f2 = ift.Field.from_random("normal", domain=s, dtype=np.complex128)
assert_allclose(f1.vdot(f2), f1.vdot(f2, spaces=0)) assert_allclose(f1.vdot(f2), f1.vdot(f2, spaces=0))
assert_allclose(f1.vdot(f2), np.conj(f2.vdot(f1))) assert_allclose(f1.vdot(f2), np.conj(f2.vdot(f1)))
def test_vdot2(self):
x1 = ift.RGSpace((200,))
x2 = ift.RGSpace((150,))
m = ift.Field((x1, x2), val=.5)
res = m.vdot(m, spaces=1)
assert_allclose(ift.dobj.to_global_data(res.val),
ift.dobj.to_global_data(ift.Field(x1, val=.25).val))
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment