Commit 3c5d3287 authored by Martin Reinecke's avatar Martin Reinecke

use ContractionOperator for most of the work

parent d61dd5f4
......@@ -45,7 +45,7 @@ from .operators.symmetrizing_operator import SymmetrizingOperator
from .operators.block_diagonal_operator import BlockDiagonalOperator
from .operators.outer_product_operator import OuterProduct
from .operators.simple_linear_operators import (
VdotOperator, SumReductionOperator, IntegralReductionOperator, ConjugationOperator, Realizer,
VdotOperator, ConjugationOperator, Realizer,
FieldAdapter, GeometryRemover, NullOperator)
from .operators.energy_operators import (
EnergyOperator, GaussianEnergy, PoissonianEnergy, InverseGammaLikelihood,
......
......@@ -64,8 +64,8 @@ def MfCorrelatedField(s_space_spatial, s_space_energy, amplitude_model_spatial,
pd_energy = PowerDistributor(pd_spatial.domain, p_space_energy, 1)
pd = pd_spatial(pd_energy)
dom_distr_spatial = ContractionOperator(pd.domain, 0).adjoint
dom_distr_energy = ContractionOperator(pd.domain, 1).adjoint
dom_distr_spatial = ContractionOperator(pd.domain, 1).adjoint
dom_distr_energy = ContractionOperator(pd.domain, 0).adjoint
a_spatial = dom_distr_spatial(amplitude_model_spatial)
a_energy = dom_distr_energy(amplitude_model_energy)
......
......@@ -132,12 +132,14 @@ class Linearization(object):
return self.new(
OuterProduct(self._val, other._val.domain)(other._val),
OuterProduct(other._val, self._jac.domain)(self._jac)._myadd(
OuterProduct(self._val, other._jac.domain)(other._jac), False))
OuterProduct(
self._val, other._jac.domain)(other._jac), False))
if np.isscalar(other):
return self.__mul__(other)
if isinstance(other, (Field, MultiField)):
return self.new(OuterProduct(self._val, other._val.domain)(other._val),
OuterProduct(other._val, self._jac.domain)(self._jac))
return self.new(
OuterProduct(self._val, other._val.domain)(other._val),
OuterProduct(other._val, self._jac.domain)(self._jac))
def vdot(self, other):
from .operators.simple_linear_operators import VdotOperator
......@@ -151,26 +153,26 @@ class Linearization(object):
VdotOperator(other._val)(self._jac))
def sum(self, spaces=None):
from .operators.simple_linear_operators import SumReductionOperator
from .operators.contraction_operator import ContractionOperator
if spaces is None:
return self.new(
Field.scalar(self._val.sum()),
SumReductionOperator(self._jac.target, None)(self._jac))
ContractionOperator(self._jac.target, None)(self._jac))
else:
return self.new(
self._val.sum(spaces),
SumReductionOperator(self._jac.target, spaces)(self._jac))
ContractionOperator(self._jac.target, spaces)(self._jac))
def integrate(self, spaces=None):
from .operators.simple_linear_operators import IntegralReductionOperator
from .operators.contraction_operator import ContractionOperator
if spaces is None:
return self.new(
Field.scalar(self._val.integrate()),
IntegralReductionOperator(self._jac.target, None)(self._jac))
ContractionOperator(self._jac.target, None, 1)(self._jac))
else:
return self.new(
self._val.integrate(spaces),
IntegralReductionOperator(self._jac.target, spaces)(self._jac))
ContractionOperator(self._jac.target, spaces, 1)(self._jac))
def exp(self):
tmp = self._val.exp()
......
......@@ -37,28 +37,37 @@ class ContractionOperator(LinearOperator):
----------
domain : Domain, tuple of Domain or DomainTuple
spaces : int or tuple of int
The elements of "domain" which are taken as target.
The elements of "domain" which are contracted.
weight : int, default=0
if nonzero, the fields living on self.domain are weighted with the
specified power.
"""
def __init__(self, domain, spaces):
def __init__(self, domain, spaces, weight=0):
self._domain = DomainTuple.make(domain)
self._spaces = utilities.parse_spaces(spaces, len(self._domain))
self._target = [
dom for i, dom in enumerate(self._domain) if i in self._spaces
dom for i, dom in enumerate(self._domain) if i not in self._spaces
]
self._target = DomainTuple.make(self._target)
self._weight = weight
self._capability = self.TIMES | self.ADJOINT_TIMES
def apply(self, x, mode):
self._check_input(x, mode)
if mode == self.ADJOINT_TIMES:
ldat = x.local_data if 0 in self._spaces else x.to_global_data()
ldat = x.to_global_data() if 0 in self._spaces else x.local_data
shp = []
for i, dom in enumerate(self._domain):
tmp = dom.shape if i > 0 else dom.local_shape
shp += tmp if i in self._spaces else (1,)*len(dom.shape)
shp += tmp if i not in self._spaces else (1,)*len(dom.shape)
ldat = np.broadcast_to(ldat.reshape(shp), self._domain.local_shape)
return Field.from_local_data(self._domain, ldat)
res = Field.from_local_data(self._domain, ldat)
if self._weight != 0:
res = res.weight(self._weight, spaces=self._spaces)
return res
else:
return x.sum(
[s for s in range(len(x.domain)) if s not in self._spaces])
if self._weight != 0:
x = x.weight(self._weight, spaces=self._spaces)
res = x.sum(self._spaces)
return res if isinstance(res, Field) else Field.scalar(res)
......@@ -46,14 +46,19 @@ class OuterProduct(LinearOperator):
self._domain = domain
self._field = field
self._target = DomainTuple.make(tuple(sub_d for sub_d in field.domain._dom + domain._dom))
self._target = DomainTuple.make(
tuple(sub_d for sub_d in field.domain._dom + domain._dom))
self._capability = self.TIMES | self.ADJOINT_TIMES
def apply(self, x, mode):
self._check_input(x, mode)
if mode == self.TIMES:
return Field.from_global_data(self._target, np.multiply.outer(self._field.to_global_data(), x.to_global_data()))
return Field.from_global_data(
self._target, np.multiply.outer(
self._field.to_global_data(), x.to_global_data()))
axes = len(self._field.shape)
return Field.from_global_data(self._domain, np.tensordot(self._field.to_global_data(), x.to_global_data(), axes))
return Field.from_global_data(
self._domain, np.tensordot(
self._field.to_global_data(), x.to_global_data(), axes))
......@@ -18,19 +18,15 @@
from __future__ import absolute_import, division, print_function
import numpy as np
from ..compat import *
from ..domain_tuple import DomainTuple
from ..domains.unstructured_domain import UnstructuredDomain
from ..field import Field
from ..multi_domain import MultiDomain
from ..multi_field import MultiField
from ..sugar import full, makeDomain
from ..sugar import full
from .endomorphic_operator import EndomorphicOperator
from .linear_operator import LinearOperator
from .domain_tuple_field_inserter import DomainTupleFieldInserter
from .. import utilities
class VdotOperator(LinearOperator):
......@@ -47,81 +43,6 @@ class VdotOperator(LinearOperator):
return self._field*x.local_data[()]
class SumReductionOperator(LinearOperator):
def __init__(self, domain, spaces=None):
self._domain = DomainTuple.make(domain)
self._spaces = utilities.parse_spaces(spaces, len(self._domain))
if len(self._spaces) == len(self._domain):
self._spaces = None
if self._spaces is None:
self._target = DomainTuple.scalar_domain()
else:
self._target = makeDomain(tuple(dom for i, dom in enumerate(self._domain) if not(i in self._spaces)))
self._capability = self.TIMES | self.ADJOINT_TIMES
def apply(self, x, mode):
self._check_input(x, mode)
if mode == self.TIMES:
if self._spaces is None:
return Field.scalar(x.sum())
else:
return x.sum(self._spaces)
if self._spaces is None:
return full(self._domain, x.local_data[()])
else:
one = np.ones(self._domain.shape)
slice_list = [slice(None), ]*len(self._domain.shape)
p = 0
for i in range(len(self._domain)):
l = len(self._domain[i].shape)
if i in self._spaces:
slice_list[slice(p, p + l)] = (np.newaxis,)*l
p = p + l
return Field.from_global_data(self._domain, x.to_global_data()[tuple(slice_list)]*one)
class IntegralReductionOperator(LinearOperator):
def __init__(self, domain, spaces=None):
self._domain = DomainTuple.make(domain)
self._spaces = utilities.parse_spaces(spaces, len(self._domain))
if len(self._spaces) == len(self._domain):
self._spaces = None
if self._spaces is None:
self._target = DomainTuple.scalar_domain()
else:
self._target = makeDomain(tuple(dom for i, dom in enumerate(self._domain) if not(i in self._spaces)))
self._marg_space = makeDomain(tuple(dom for i, dom in enumerate(self._domain) if (i in self._spaces)))
self._capability = self.TIMES | self.ADJOINT_TIMES
def apply(self, x, mode):
self._check_input(x, mode)
vol = 1.
if mode == self.TIMES:
if self._spaces is None:
return Field.scalar(x.integrate())
else:
return x.integrate(self._spaces)
if self._spaces is None:
for d in self._domain._dom:
for dis in d.distances:
vol *= dis
return full(self._domain, x.local_data[()]*vol)
else:
for d in self._marg_space._dom:
for dis in d.distances:
vol *= dis
one = np.ones(self._domain.shape)
slice_list = [slice(None), ]*len(self._domain.shape)
p = 0
for i in range(len(self._domain)):
l = len(self._domain[i].shape)
if i in self._spaces:
slice_list[slice(p, p + l)] = (np.newaxis,)*l
p = p + l
return Field.from_global_data(self._domain, x.to_global_data()[tuple(slice_list)]*one*vol)
class ConjugationOperator(EndomorphicOperator):
def __init__(self, domain):
self._domain = DomainTuple.make(domain)
......
......@@ -65,12 +65,6 @@ class Consistency_Tests(unittest.TestCase):
dtype=dtype))
ift.extra.consistency_check(op, dtype, dtype)
@expand(product(_h_spaces + _p_spaces + _pow_spaces,
[np.float64, np.complex128]))
def testSumReductionOperator(self, sp, dtype):
op = ift.SumReductionOperator(sp)
ift.extra.consistency_check(op, dtype, dtype)
@expand(product([(ift.RGSpace(10, harmonic=True), 4, 0),
(ift.RGSpace((24, 31), distances=(0.4, 2.34),
harmonic=True), 3, 0),
......@@ -193,11 +187,11 @@ class Consistency_Tests(unittest.TestCase):
ift.extra.consistency_check(op, dtype, dtype)
@expand(product([0, 1, 2, 3, (0, 1), (0, 2), (0, 1, 2), (0, 2, 3), (1, 3)],
[np.float64, np.complex128]))
def testContractionOperator(self, spaces, dtype):
dom = (ift.RGSpace(10), ift.UnstructuredDomain(13), ift.GLSpace(5),
[0, 1, 2, -1], [np.float64, np.complex128]))
def testContractionOperator(self, spaces, wgt, dtype):
dom = (ift.RGSpace(10), ift.RGSpace(13), ift.GLSpace(5),
ift.HPSpace(4))
op = ift.ContractionOperator(dom, spaces)
op = ift.ContractionOperator(dom, spaces, wgt)
ift.extra.consistency_check(op, dtype, dtype)
def testDomainTupleFieldInserter(self):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment