Commit 0fd750df authored by Sebastian Hutschenreuter's avatar Sebastian Hutschenreuter
Browse files

fixed sum and integral

parent c3dd0ba2
......@@ -44,11 +44,6 @@ class OuterProduct(LinearOperator):
def __init__(self, field, domain):
if not isinstance(field, Field):
raise TypeError('field needs to be a Nifty Field instance')
if not isinstance(domain, DomainTuple):
raise TypeError('field needs to be a Nifty Field instance')
self._domain = domain
self._field = field
self._target = DomainTuple.make(tuple(sub_d for sub_d in field.domain._dom + domain._dom))
......@@ -60,5 +55,5 @@ class OuterProduct(LinearOperator):
if mode == self.TIMES:
return Field.from_global_data(self._target, np.multiply.outer(self._field.to_global_data(), x.to_global_data()))
axes = len(self._field.shape)
return Field.from_global_data(self._domain, val=np.tensordot(self._field.to_global_data(), x.to_global_data(), axes))
return Field.from_global_data(self._domain, np.tensordot(self._field.to_global_data(), x.to_global_data(), axes))
......@@ -18,6 +18,8 @@
from __future__ import absolute_import, division, print_function
import numpy as np
from ..compat import *
from ..domain_tuple import DomainTuple
from ..domains.unstructured_domain import UnstructuredDomain
......@@ -55,7 +57,6 @@ class SumReductionOperator(LinearOperator):
self._target = DomainTuple.scalar_domain()
else:
self._target = makeDomain(tuple(dom for i, dom in enumerate(self._domain) if not(i in self._spaces)))
self._marg_space = makeDomain(tuple(dom for i, dom in enumerate(self._domain) if (i in self._spaces)))
self._capability = self.TIMES | self.ADJOINT_TIMES
def apply(self, x, mode):
......@@ -68,13 +69,15 @@ class SumReductionOperator(LinearOperator):
if self._spaces is None:
return full(self._domain, x.local_data[()])
else:
for i in self._spaces:
ns = self._domain._dom[i]
# FIXME: nested use of "i"
ps = tuple(i - 1 for i in ns.shape)
dtfi = DomainTupleFieldInserter(domain=self._target, new_space=ns, index=i, position=ps)
x = dtfi(x)
return x*self._marg_space.size
one = np.ones(self._domain.shape)
slice_list = [slice(None), ]*len(self._domain.shape)
p = 0
for i in range(len(self._domain)):
l = len(self._domain[i].shape)
if i in self._spaces:
slice_list[slice(p, p + l)] = (np.newaxis,)*l
p = p + l
return Field.from_global_data(self._domain, x.to_global_data()[tuple(slice_list)]*one)
class IntegralReductionOperator(LinearOperator):
......@@ -108,17 +111,15 @@ class IntegralReductionOperator(LinearOperator):
for d in self._marg_space._dom:
for dis in d.distances:
vol *= dis
if isinstance(self._spaces, int):
sp = (self._spaces, )
else:
sp = self._spaces
for i in sp:
ns = self._domain._dom[i]
# FIXME: nested use of "i"
ps = tuple(i - 1 for i in ns.shape)
dtfi = DomainTupleFieldInserter(domain=self._target, new_space=ns, index=i, position=ps)
x = dtfi(x)
return x*self._marg_space.size*vol
one = np.ones(self._domain.shape)
slice_list = [slice(None), ]*len(self._domain.shape)
p = 0
for i in range(len(self._domain)):
l = len(self._domain[i].shape)
if i in self._spaces:
slice_list[slice(p, p + l)] = (np.newaxis,)*l
p = p + l
return Field.from_global_data(self._domain, x.to_global_data()[tuple(slice_list)]*one*vol)
class ConjugationOperator(EndomorphicOperator):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment