Commit c91ee857 authored by Martin Reinecke's avatar Martin Reinecke
Browse files

Merge branch 'outer_product' into 'NIFTy_5'

Outer product

See merge request ift/nifty-dev!106
parents 71cc7162 9986da5c
......@@ -43,8 +43,9 @@ from .operators.slope_operator import SlopeOperator
from .operators.smoothness_operator import SmoothnessOperator
from .operators.symmetrizing_operator import SymmetrizingOperator
from .operators.block_diagonal_operator import BlockDiagonalOperator
from .operators.outer_product_operator import OuterProduct
from .operators.simple_linear_operators import (
VdotOperator, SumReductionOperator, ConjugationOperator, Realizer,
VdotOperator, ConjugationOperator, Realizer,
FieldAdapter, GeometryRemover, NullOperator)
from .operators.energy_operators import (
EnergyOperator, GaussianEnergy, PoissonianEnergy, InverseGammaLikelihood,
......
......@@ -327,6 +327,23 @@ class Field(object):
return Field.from_local_data(self._domain, aout)
def outer(self, x):
""" Computes the outer product of 'self' with x.
Parameters
----------
x : Field
Returns
----------
Field, lives on the product space of self.domain and x.domain
"""
if not isinstance(x, Field):
raise TypeError("The multiplier must be an instance of " +
"the NIFTy field class")
from .operators.outer_product_operator import OuterProduct
return OuterProduct(self, x.domain)(x)
def vdot(self, x=None, spaces=None):
""" Computes the dot product of 'self' with x.
......
......@@ -67,8 +67,8 @@ def MfCorrelatedField(s_space_spatial, s_space_energy, amplitude_model_spatial,
pd_energy = PowerDistributor(pd_spatial.domain, p_space_energy, 1)
pd = pd_spatial(pd_energy)
dom_distr_spatial = ContractionOperator(pd.domain, 0).adjoint
dom_distr_energy = ContractionOperator(pd.domain, 1).adjoint
dom_distr_spatial = ContractionOperator(pd.domain, 1).adjoint
dom_distr_energy = ContractionOperator(pd.domain, 0).adjoint
a_spatial = dom_distr_spatial(amplitude_model_spatial)
a_energy = dom_distr_energy(amplitude_model_energy)
......
......@@ -126,6 +126,19 @@ class Linearization(object):
def __rmul__(self, other):
return self.__mul__(other)
def outer(self, other):
from .operators.outer_product_operator import OuterProduct
if isinstance(other, Linearization):
return self.new(
OuterProduct(self._val, other.target)(other._val),
OuterProduct(self._jac(self._val), other.target)._myadd(
OuterProduct(self._val, other.target)(other._jac), False))
if np.isscalar(other):
return self.__mul__(other)
if isinstance(other, (Field, MultiField)):
return self.new(OuterProduct(self._val, other.domain)(other),
OuterProduct(self._jac(self._val), other.domain))
def vdot(self, other):
from .operators.simple_linear_operators import VdotOperator
if isinstance(other, (Field, MultiField)):
......@@ -137,11 +150,27 @@ class Linearization(object):
VdotOperator(self._val)(other._jac) +
VdotOperator(other._val)(self._jac))
def sum(self):
from .operators.simple_linear_operators import SumReductionOperator
return self.new(
Field.scalar(self._val.sum()),
SumReductionOperator(self._jac.target)(self._jac))
def sum(self, spaces=None):
from .operators.contraction_operator import ContractionOperator
if spaces is None:
return self.new(
Field.scalar(self._val.sum()),
ContractionOperator(self._jac.target, None)(self._jac))
else:
return self.new(
self._val.sum(spaces),
ContractionOperator(self._jac.target, spaces)(self._jac))
def integrate(self, spaces=None):
from .operators.contraction_operator import ContractionOperator
if spaces is None:
return self.new(
Field.scalar(self._val.integrate()),
ContractionOperator(self._jac.target, None, 1)(self._jac))
else:
return self.new(
self._val.integrate(spaces),
ContractionOperator(self._jac.target, spaces, 1)(self._jac))
def exp(self):
tmp = self._val.exp()
......
......@@ -37,28 +37,37 @@ class ContractionOperator(LinearOperator):
----------
domain : Domain, tuple of Domain or DomainTuple
spaces : int or tuple of int
The elements of "domain" which are taken as target.
The elements of "domain" which are contracted.
weight : int, default=0
if nonzero, the fields living on self.domain are weighted with the
specified power.
"""
def __init__(self, domain, spaces):
def __init__(self, domain, spaces, weight=0):
self._domain = DomainTuple.make(domain)
self._spaces = utilities.parse_spaces(spaces, len(self._domain))
self._target = [
dom for i, dom in enumerate(self._domain) if i in self._spaces
dom for i, dom in enumerate(self._domain) if i not in self._spaces
]
self._target = DomainTuple.make(self._target)
self._weight = weight
self._capability = self.TIMES | self.ADJOINT_TIMES
def apply(self, x, mode):
self._check_input(x, mode)
if mode == self.ADJOINT_TIMES:
ldat = x.local_data if 0 in self._spaces else x.to_global_data()
ldat = x.to_global_data() if 0 in self._spaces else x.local_data
shp = []
for i, dom in enumerate(self._domain):
tmp = dom.shape if i > 0 else dom.local_shape
shp += tmp if i in self._spaces else (1,)*len(dom.shape)
shp += tmp if i not in self._spaces else (1,)*len(dom.shape)
ldat = np.broadcast_to(ldat.reshape(shp), self._domain.local_shape)
return Field.from_local_data(self._domain, ldat)
res = Field.from_local_data(self._domain, ldat)
if self._weight != 0:
res = res.weight(self._weight, spaces=self._spaces)
return res
else:
return x.sum(
[s for s in range(len(x.domain)) if s not in self._spaces])
if self._weight != 0:
x = x.weight(self._weight, spaces=self._spaces)
res = x.sum(self._spaces)
return res if isinstance(res, Field) else Field.scalar(res)
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2018 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes.
from __future__ import absolute_import, division, print_function
import itertools
import numpy as np
from .. import dobj, utilities
from ..compat import *
from ..domain_tuple import DomainTuple
from ..domains.rg_space import RGSpace
from ..multi_field import MultiField, MultiDomain
from ..field import Field
from .linear_operator import LinearOperator
import operator
class OuterProduct(LinearOperator):
"""Performs the pointwise outer product of two fields.
Parameters
---------
field: Field,
domain: DomainTuple, the domain of the input field
---------
"""
def __init__(self, field, domain):
self._domain = domain
self._field = field
self._target = DomainTuple.make(
tuple(sub_d for sub_d in field.domain._dom + domain._dom))
self._capability = self.TIMES | self.ADJOINT_TIMES
def apply(self, x, mode):
self._check_input(x, mode)
if mode == self.TIMES:
return Field.from_global_data(
self._target, np.multiply.outer(
self._field.to_global_data(), x.to_global_data()))
axes = len(self._field.shape)
return Field.from_global_data(
self._domain, np.tensordot(
self._field.to_global_data(), x.to_global_data(), axes))
......@@ -43,19 +43,6 @@ class VdotOperator(LinearOperator):
return self._field*x.local_data[()]
class SumReductionOperator(LinearOperator):
def __init__(self, domain):
self._domain = DomainTuple.make(domain)
self._target = DomainTuple.scalar_domain()
self._capability = self.TIMES | self.ADJOINT_TIMES
def apply(self, x, mode):
self._check_input(x, mode)
if mode == self.TIMES:
return Field.scalar(x.sum())
return full(self._domain, x.local_data[()])
class ConjugationOperator(EndomorphicOperator):
def __init__(self, domain):
self._domain = DomainTuple.make(domain)
......
......@@ -136,6 +136,34 @@ class Test_Functionality(unittest.TestCase):
res = m.vdot(m, spaces=1)
assert_allclose(res.local_data, 37.5)
def test_outer(self):
x1 = ift.RGSpace((9,))
x2 = ift.RGSpace((3,))
m1 = ift.Field.full(x1, .5)
m2 = ift.Field.full(x2, 3.)
res = m1.outer(m2)
assert_allclose(res.to_global_data(), np.full((9, 3,), 1.5))
def test_sum(self):
x1 = ift.RGSpace((9,), distances=2.)
x2 = ift.RGSpace((2, 12,), distances=(0.3,))
m1 = ift.Field.from_global_data(ift.makeDomain(x1), np.arange(9))
m2 = ift.Field.full(ift.makeDomain((x1, x2)), 0.45)
res1 = m1.sum()
res2 = m2.sum(spaces=1)
assert_allclose(res1, 36)
assert_allclose(res2.to_global_data(), np.full(9, 2*12*0.45))
def test_integrate(self):
x1 = ift.RGSpace((9,), distances=2.)
x2 = ift.RGSpace((2, 12,), distances=(0.3,))
m1 = ift.Field.from_global_data(ift.makeDomain(x1), np.arange(9))
m2 = ift.Field.full(ift.makeDomain((x1, x2)), 0.45)
res1 = m1.integrate()
res2 = m2.integrate(spaces=1)
assert_allclose(res1, 36*2)
assert_allclose(res2.to_global_data(), np.full(9, 2*12*0.45*0.3**2))
def test_dataconv(self):
s1 = ift.RGSpace((10,))
ld = np.arange(ift.dobj.local_shape(s1.shape)[0])
......
......@@ -64,26 +64,29 @@ class Model_Tests(unittest.TestCase):
dom = ift.MultiDomain.union((dom1, dom2))
model = ift.FieldAdapter(dom, "s1")*ift.FieldAdapter(dom, "s2")
pos = ift.from_random("normal", dom)
ift.extra.check_value_gradient_consistency(model, pos)
ift.extra.check_value_gradient_consistency(model, pos, ntries=20)
model = ift.FieldAdapter(dom, "s1")+ift.FieldAdapter(dom, "s2")
pos = ift.from_random("normal", dom)
ift.extra.check_value_gradient_consistency(model, pos)
ift.extra.check_value_gradient_consistency(model, pos, ntries=20)
model = ift.FieldAdapter(dom, "s1").scale(3.)
pos = ift.from_random("normal", dom1)
ift.extra.check_value_gradient_consistency(model, pos)
ift.extra.check_value_gradient_consistency(model, pos, ntries=20)
model = ift.ScalingOperator(2.456, space)(
ift.FieldAdapter(dom, "s1")*ift.FieldAdapter(dom, "s2"))
pos = ift.from_random("normal", dom)
ift.extra.check_value_gradient_consistency(model, pos)
ift.extra.check_value_gradient_consistency(model, pos, ntries=20)
model = ift.positive_tanh(ift.ScalingOperator(2.456, space)(
ift.FieldAdapter(dom, "s1")*ift.FieldAdapter(dom, "s2")))
pos = ift.from_random("normal", dom)
ift.extra.check_value_gradient_consistency(model, pos)
ift.extra.check_value_gradient_consistency(model, pos, ntries=20)
pos = ift.from_random("normal", dom)
model = ift.OuterProduct(pos['s1'], ift.makeDomain(space))
ift.extra.check_value_gradient_consistency(model, pos['s2'], ntries=20)
if isinstance(space, ift.RGSpace):
model = ift.FFTOperator(space)(
ift.FieldAdapter(dom, "s1")*ift.FieldAdapter(dom, "s2"))
pos = ift.from_random("normal", dom)
ift.extra.check_value_gradient_consistency(model, pos)
ift.extra.check_value_gradient_consistency(model, pos, ntries=20)
@expand(product(
[ift.GLSpace(15),
......@@ -106,12 +109,12 @@ class Model_Tests(unittest.TestCase):
sv, im, iv)
S = ift.ScalingOperator(1., model.domain)
pos = S.draw_sample()
ift.extra.check_value_gradient_consistency(model, pos)
ift.extra.check_value_gradient_consistency(model, pos, ntries=20)
model2 = ift.CorrelatedField(space, model)
S = ift.ScalingOperator(1., model2.domain)
pos = S.draw_sample()
ift.extra.check_value_gradient_consistency(model2, pos)
ift.extra.check_value_gradient_consistency(model2, pos, ntries=20)
@expand(product(
[ift.GLSpace(15),
......@@ -125,7 +128,8 @@ class Model_Tests(unittest.TestCase):
q = 0.73
model = ift.InverseGammaModel(space, alpha, q)
# FIXME All those cdfs and ppfs are not very accurate
ift.extra.check_value_gradient_consistency(model, pos, tol=1e-2)
ift.extra.check_value_gradient_consistency(model, pos, tol=1e-2,
ntries=20)
# @expand(product(
# ['Variable', 'Constant'],
......
......@@ -65,12 +65,6 @@ class Consistency_Tests(unittest.TestCase):
dtype=dtype))
ift.extra.consistency_check(op, dtype, dtype)
@expand(product(_h_spaces + _p_spaces + _pow_spaces,
[np.float64, np.complex128]))
def testSumReductionOperator(self, sp, dtype):
op = ift.SumReductionOperator(sp)
ift.extra.consistency_check(op, dtype, dtype)
@expand(product([(ift.RGSpace(10, harmonic=True), 4, 0),
(ift.RGSpace((24, 31), distances=(0.4, 2.34),
harmonic=True), 3, 0),
......@@ -193,11 +187,11 @@ class Consistency_Tests(unittest.TestCase):
ift.extra.consistency_check(op, dtype, dtype)
@expand(product([0, 1, 2, 3, (0, 1), (0, 2), (0, 1, 2), (0, 2, 3), (1, 3)],
[np.float64, np.complex128]))
def testContractionOperator(self, spaces, dtype):
dom = (ift.RGSpace(10), ift.UnstructuredDomain(13), ift.GLSpace(5),
[0, 1, 2, -1], [np.float64, np.complex128]))
def testContractionOperator(self, spaces, wgt, dtype):
dom = (ift.RGSpace(10), ift.RGSpace(13), ift.GLSpace(5),
ift.HPSpace(4))
op = ift.ContractionOperator(dom, spaces)
op = ift.ContractionOperator(dom, spaces, wgt)
ift.extra.consistency_check(op, dtype, dtype)
def testDomainTupleFieldInserter(self):
......@@ -263,3 +257,17 @@ class Consistency_Tests(unittest.TestCase):
def testRegridding(self, domain, shape, space):
op = ift.RegriddingOperator(domain, shape, space)
ift.extra.consistency_check(op)
@expand(product([ift.DomainTuple.make((ift.RGSpace((3, 5, 4)),
ift.RGSpace((16,),
distances=(7.,))),),
ift.DomainTuple.make(ift.HPSpace(12),)],
[ift.DomainTuple.make((ift.RGSpace((2,)),
ift.GLSpace(10)),),
ift.DomainTuple.make(ift.RGSpace((10, 12),
distances=(0.1, 1.)),)]
))
def testOuter(self, fdomain, domain):
f = ift.from_random('normal', fdomain)
op = ift.OuterProduct(f, domain)
ift.extra.consistency_check(op)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment