Commit 11d6a263 authored by Sebastian Hutschenreuter's avatar Sebastian Hutschenreuter
Browse files

added outer product operator, added outer method to field and linearisation,...

added outer product operator, added outer method to field and linearisation, added integrate method to linearisation,
parent b55598c9
......@@ -43,8 +43,9 @@ from .operators.slope_operator import SlopeOperator
from .operators.smoothness_operator import SmoothnessOperator
from .operators.symmetrizing_operator import SymmetrizingOperator
from .operators.block_diagonal_operator import BlockDiagonalOperator
from .operators.outer_product_operator import OuterProduct
from .operators.simple_linear_operators import (
VdotOperator, SumReductionOperator, ConjugationOperator, Realizer,
VdotOperator, SumReductionOperator, IntegralReductionOperator, ConjugationOperator, Realizer,
FieldAdapter, GeometryRemover, NullOperator)
from .operators.energy_operators import (
EnergyOperator, GaussianEnergy, PoissonianEnergy, InverseGammaLikelihood,
......
......@@ -327,6 +327,23 @@ class Field(object):
return Field.from_local_data(self._domain, aout)
def outer(self, x):
""" Computes the outer product of 'self' with x.
Parameters
----------
x : Field
Returns
----------
Field, lives on the product space of self.domain and x.domain
"""
if not isinstance(x, Field):
raise TypeError("The multiplier must be an instance of " +
"the NIFTy field class")
from .operators.outer_product_operator import OuterProduct
return OuterProduct(self, x.domain)(x)
def vdot(self, x=None, spaces=None):
""" Computes the dot product of 'self' with x.
......@@ -460,7 +477,7 @@ class Field(object):
swgt = self.scalar_weight(spaces)
if swgt is not None:
res = self.sum(spaces)
res = res*swgt
res = res * swgt
return res
tmp = self.weight(1, spaces=spaces)
return tmp.sum(spaces)
......
......@@ -93,17 +93,6 @@ class Linearization(object):
def __rsub__(self, other):
return (-self).__add__(other)
def __truediv__(self, other):
if isinstance(other, Linearization):
return self.__mul__(other.inverse())
return self.__mul__(1./other)
def __rtruediv__(self, other):
return self.inverse().__mul__(other)
def inverse(self):
return self.new(1./self._val, makeOp(-1./(self._val**2))(self._jac))
def __mul__(self, other):
from .sugar import makeOp
if isinstance(other, Linearization):
......@@ -126,6 +115,22 @@ class Linearization(object):
def __rmul__(self, other):
return self.__mul__(other)
def outer(self, other):
from .operators.outer_product_operator import OuterProduct
if isinstance(other, Linearization):
return self.new(
OuterProduct(self._val, other._val.domain)(other._val),
OuterProduct(other._val, self._jac.domain)(self._jac)._myadd(
OuterProduct(self._val, other._jac.domain)(other._jac), False))
if np.isscalar(other):
if other == 1:
return self
met = None if self._metric is None else self._metric.scale(other)
return self.new(self._val*other, self._jac.scale(other), met)
if isinstance(other, (Field, MultiField)):
return self.new(OuterProduct(self._val, other._val.domain)(other._val),
OuterProduct(other._val, self._jac.domain)(self._jac))
def vdot(self, other):
from .operators.simple_linear_operators import VdotOperator
if isinstance(other, (Field, MultiField)):
......@@ -137,11 +142,27 @@ class Linearization(object):
VdotOperator(self._val)(other._jac) +
VdotOperator(other._val)(self._jac))
def sum(self):
def sum(self, spaces=None):
from .operators.simple_linear_operators import SumReductionOperator
return self.new(
Field.scalar(self._val.sum()),
SumReductionOperator(self._jac.target)(self._jac))
if spaces is None:
return self.new(
Field.scalar(self._val.sum()),
SumReductionOperator(self._jac.target, None)(self._jac))
else:
return self.new(
self._val.sum(spaces),
SumReductionOperator(self._jac.target, spaces)(self._jac))
def integrate(self, spaces=None):
from .operators.simple_linear_operators import IntegralReductionOperator
if spaces is None:
return self.new(
Field.scalar(self._val.integrate()),
IntegralReductionOperator(self._jac.target, None)(self._jac))
else:
return self.new(
self._val.integrate(spaces),
IntegralReductionOperator(self._jac.target, spaces)(self._jac))
def exp(self):
tmp = self._val.exp()
......
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2018 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes.
from __future__ import absolute_import, division, print_function
import itertools
import numpy as np
from .. import dobj, utilities
from ..compat import *
from ..domain_tuple import DomainTuple
from ..domains.rg_space import RGSpace
from ..multi_field import MultiField, MultiDomain
from ..field import Field
from .linear_operator import LinearOperator
import operator
class OuterProduct(LinearOperator):
"""Performs the pointwise outer product of two fields.
Parameters
---------
field: Field,
domain: DomainTuple, the domain of the input field
---------
"""
def __init__(self, field, domain):
if not isinstance(field, Field):
raise TypeError('field needs to be a Nifty Field instance')
if not isinstance(domain, DomainTuple):
raise TypeError('field needs to be a Nifty Field instance')
self._domain = domain
self._field = field
self._target = DomainTuple.make(tuple(sub_d for sub_d in field.domain._dom + domain._dom))
self._capability = self.TIMES | self.ADJOINT_TIMES
def apply(self, x, mode):
self._check_input(x, mode)
if mode == self.TIMES:
return Field(self._target, np.multiply.outer(self._field.to_global_data(), x.to_global_data()))
axes = len(self._field.shape)
return Field(self._domain, val=np.tensordot(self._field.to_global_data(), x.to_global_data(), axes))
......@@ -24,9 +24,10 @@ from ..domains.unstructured_domain import UnstructuredDomain
from ..field import Field
from ..multi_domain import MultiDomain
from ..multi_field import MultiField
from ..sugar import full
from ..sugar import full, makeDomain
from .endomorphic_operator import EndomorphicOperator
from .linear_operator import LinearOperator
from .domain_tuple_field_inserter import DomainTupleFieldInserter
class VdotOperator(LinearOperator):
......@@ -44,16 +45,77 @@ class VdotOperator(LinearOperator):
class SumReductionOperator(LinearOperator):
def __init__(self, domain):
self._domain = DomainTuple.make(domain)
self._target = DomainTuple.scalar_domain()
def __init__(self, domain, spaces=None):
self._spaces = spaces
self._domain = domain
if spaces is None:
self._target = DomainTuple.scalar_domain()
else:
self._target = makeDomain(tuple(dom for i, dom in enumerate(self._domain) if not(i == spaces)))
self._marg_space = makeDomain(tuple(dom for i, dom in enumerate(self._domain) if (i == spaces)))
self._capability = self.TIMES | self.ADJOINT_TIMES
def apply(self, x, mode):
self._check_input(x, mode)
if mode == self.TIMES:
return Field.scalar(x.sum())
return full(self._domain, x.local_data[()])
if self._spaces is None:
return Field.scalar(x.sum())
else:
return x.sum(self._spaces)
if self._spaces is None:
return full(self._domain, x.local_data[()])
else:
if isinstance(self._spaces, int):
sp = (self._spaces, )
else:
sp = self._spaces
for i in sp:
ns = self._domain._dom[i]
ps = tuple(i - 1 for i in ns.shape)
dtfi = DomainTupleFieldInserter(domain=self._target, new_space=ns, index=i, position=ps)
x = dtfi(x)
return x*self._marg_space.size
class IntegralReductionOperator(LinearOperator):
def __init__(self, domain, spaces=None):
self._spaces = spaces
self._domain = domain
if spaces is None:
self._target = DomainTuple.scalar_domain()
else:
self._target = makeDomain(tuple(dom for i, dom in enumerate(self._domain) if not(i == spaces)))
self._marg_space = makeDomain(tuple(dom for i, dom in enumerate(self._domain) if (i == spaces)))
self._capability = self.TIMES | self.ADJOINT_TIMES
def apply(self, x, mode):
self._check_input(x, mode)
vol = 1.
if mode == self.TIMES:
if self._spaces is None:
return Field.scalar(x.integrate())
else:
return x.integrate(self._spaces)
if self._spaces is None:
for d in self._domain._dom:
for dis in d.distances:
vol *= dis
return full(self._domain, x.local_data[()]*vol)
else:
for d in self._marg_space._dom:
for dis in d.distances:
vol *= dis
if isinstance(self._spaces, int):
sp = (self._spaces, )
else:
sp = self._spaces
for i in sp:
ns = self._domain._dom[i]
ps = tuple(i - 1 for i in ns.shape)
dtfi = DomainTupleFieldInserter(domain=self._target, new_space=ns, index=i, position=ps)
x = dtfi(x)
return x*self._marg_space.size*vol
class ConjugationOperator(EndomorphicOperator):
......
......@@ -136,6 +136,14 @@ class Test_Functionality(unittest.TestCase):
res = m.vdot(m, spaces=1)
assert_allclose(res.local_data, 37.5)
def test_outer(self):
x1 = ift.RGSpace((9,))
x2 = ift.RGSpace((3,))
m1 = ift.Field.full(x1, .5)
m2 = ift.Field.full(x2, 3.)
res = m1.outer(m2)
assert_allclose(res.local_data, np.full((9, 3,), 1.5))
def test_dataconv(self):
s1 = ift.RGSpace((10,))
ld = np.arange(ift.dobj.local_shape(s1.shape)[0])
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment