Commit d659da88 authored by Martin Reinecke's avatar Martin Reinecke
Browse files

Merge branch 'integration_operator' into 'NIFTy_6'

Integration operator

See merge request !464
parents 4d8c1460 31932e59
Pipeline #75205 passed with stages
in 8 minutes and 22 seconds
......@@ -26,7 +26,7 @@ from .operators.diagonal_operator import DiagonalOperator
from .operators.distributors import DOFDistributor, PowerDistributor
from .operators.domain_tuple_field_inserter import DomainTupleFieldInserter
from .operators.einsum import LinearEinsum, MultiLinearEinsum
from .operators.contraction_operator import ContractionOperator
from .operators.contraction_operator import ContractionOperator, IntegrationOperator
from .operators.linear_interpolation import LinearInterpolator
from .operators.endomorphic_operator import EndomorphicOperator
from .operators.harmonic_operators import (
......
......@@ -11,7 +11,7 @@
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2019 Max-Planck-Society
# Copyright(C) 2013-2020 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.
......@@ -270,10 +270,8 @@ class Linearization(Operator):
Linearization
the (partial) integral
"""
from .operators.contraction_operator import ContractionOperator
return self.new(
self._val.integrate(spaces),
ContractionOperator(self._jac.target, spaces, 1)(self._jac))
from .operators.contraction_operator import IntegrationOperator
return IntegrationOperator(self._target, spaces)(self)
def ptw(self, op, *args, **kwargs):
t1, t2 = self._val.ptw_with_deriv(op, *args, **kwargs)
......
......@@ -11,7 +11,7 @@
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2019 Max-Planck-Society
# Copyright(C) 2013-2020 Max-Planck-Society
import numpy as np
import scipy.sparse.linalg as ssl
......
......@@ -11,7 +11,7 @@
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2019 Max-Planck-Society
# Copyright(C) 2013-2020 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.
......@@ -27,7 +27,7 @@ class ContractionOperator(LinearOperator):
"""A :class:`LinearOperator` which sums up fields into the direction of
subspaces.
This Operator sums up a field with is defined on a :class:`DomainTuple`
This Operator sums up a field which is defined on a :class:`DomainTuple`
to a :class:`DomainTuple` which is a subset of the former.
Parameters
......@@ -36,19 +36,19 @@ class ContractionOperator(LinearOperator):
spaces : None, int or tuple of int
The elements of "domain" which are contracted.
If `None`, everything is contracted
weight : int, default=0
power : int, default=0
If nonzero, the fields defined on self.domain are weighted with the
specified power along the submdomains which are contracted.
"""
def __init__(self, domain, spaces, weight=0):
def __init__(self, domain, spaces, power=0):
self._domain = DomainTuple.make(domain)
self._spaces = utilities.parse_spaces(spaces, len(self._domain))
self._target = [
dom for i, dom in enumerate(self._domain) if i not in self._spaces
]
self._target = DomainTuple.make(self._target)
self._weight = weight
self._power = power
self._capability = self.TIMES | self.ADJOINT_TIMES
def apply(self, x, mode):
......@@ -61,11 +61,28 @@ class ContractionOperator(LinearOperator):
shp += tmp if i not in self._spaces else (1,)*len(dom.shape)
ldat = np.broadcast_to(ldat.reshape(shp), self._domain.shape)
res = Field(self._domain, ldat)
if self._weight != 0:
res = res.weight(self._weight, spaces=self._spaces)
if self._power != 0:
res = res.weight(self._power, spaces=self._spaces)
return res
else:
if self._weight != 0:
x = x.weight(self._weight, spaces=self._spaces)
if self._power != 0:
x = x.weight(self._power, spaces=self._spaces)
res = x.sum(self._spaces)
return res if isinstance(res, Field) else Field.scalar(res)
def IntegrationOperator(domain, spaces):
"""A :class:`LinearOperator` which integrates fields into the direction
of subspaces.
This Operator integrates a field which is defined on a :class:`DomainTuple`
to a :class:`DomainTuple` which is a subset of the former.
Parameters
----------
domain : Domain, tuple of Domain or DomainTuple
spaces : None, int or tuple of int
The elements of "domain" which are contracted.
If `None`, everything is contracted
"""
return ContractionOperator(domain, spaces, 1)
......@@ -128,6 +128,10 @@ class Operator(metaclass=NiftyMeta):
from .contraction_operator import ContractionOperator
return ContractionOperator(self.target, spaces)(self)
def integrate(self, spaces=None):
from .contraction_operator import IntegrationOperator
return IntegrationOperator(self.target, spaces)(self)
def vdot(self, other):
from ..sugar import makeOp
if not isinstance(other, Operator):
......
......@@ -11,7 +11,7 @@
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2019 Max-Planck-Society
# Copyright(C) 2013-2020 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.
......@@ -140,11 +140,7 @@ def test_outer():
def test_sum():
x1 = ift.RGSpace((9,), distances=2.)
x2 = ift.RGSpace(
(
2,
12,
), distances=(0.3,))
x2 = ift.RGSpace((2, 12), distances=(0.3,))
m1 = ift.Field(ift.makeDomain(x1), np.arange(9))
m2 = ift.Field.full(ift.makeDomain((x1, x2)), 0.45)
res1 = m1.s_sum()
......@@ -162,11 +158,16 @@ def test_integrate():
res2 = m2.integrate(spaces=1)
assert_allclose(res1, 36*2)
assert_allclose(res2.val, np.full(9, 2*12*0.45*0.3**2))
for m in [m1, m2]:
res3 = m.integrate()
res4 = m.s_integrate()
assert_allclose(res3.val, res4)
dom = ift.HPSpace(3)
assert_allclose(ift.full(dom, 1).s_integrate(), 4*np.pi)
def test_dataconv():
s1 = ift.RGSpace((10,))
ld = np.arange(s1.shape[0])
gd = np.arange(s1.shape[0])
assert_equal(gd, ift.makeField(s1, gd).val)
......
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2020 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.
import numpy as np
from numpy.testing import assert_allclose
import nifty6 as ift
from ..common import setup_function, teardown_function
def test_integration_operator():
x1 = ift.RGSpace((9,), distances=2.)
x2 = ift.RGSpace((2, 12), distances=(0.3,))
dom1 = ift.makeDomain(x1)
dom2 = ift.makeDomain((x1, x2))
f1 = ift.from_random('normal', dom1)
f2 = ift.from_random('normal', dom2)
op1 = ift.ScalingOperator(dom1, 1).integrate()
op2 = ift.ScalingOperator(dom2, 1).integrate()
op3 = ift.ScalingOperator(dom2, 1).integrate(spaces=1)
res1 = f1.integrate()
res2 = op1(f1)
assert_allclose(res1.val, res2.val)
res3 = f2.integrate()
res4 = op2(f2)
assert_allclose(res3.val, res4.val)
res5 = f2.integrate(spaces=1)
res6 = op3(f2)
assert_allclose(res5.val, res6.val)
for op in [op1, op2, op3]:
ift.extra.consistency_check(op, domain_dtype=np.float64,
target_dtype=np.float64)
ift.extra.consistency_check(op, domain_dtype=np.complex128,
target_dtype=np.complex128)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment