Commit d659da88 authored by Martin Reinecke's avatar Martin Reinecke
Browse files

Merge branch 'integration_operator' into 'NIFTy_6'

Integration operator

See merge request !464
parents 4d8c1460 31932e59
Pipeline #75205 passed with stages
in 8 minutes and 22 seconds
...@@ -26,7 +26,7 @@ from .operators.diagonal_operator import DiagonalOperator ...@@ -26,7 +26,7 @@ from .operators.diagonal_operator import DiagonalOperator
from .operators.distributors import DOFDistributor, PowerDistributor from .operators.distributors import DOFDistributor, PowerDistributor
from .operators.domain_tuple_field_inserter import DomainTupleFieldInserter from .operators.domain_tuple_field_inserter import DomainTupleFieldInserter
from .operators.einsum import LinearEinsum, MultiLinearEinsum from .operators.einsum import LinearEinsum, MultiLinearEinsum
from .operators.contraction_operator import ContractionOperator from .operators.contraction_operator import ContractionOperator, IntegrationOperator
from .operators.linear_interpolation import LinearInterpolator from .operators.linear_interpolation import LinearInterpolator
from .operators.endomorphic_operator import EndomorphicOperator from .operators.endomorphic_operator import EndomorphicOperator
from .operators.harmonic_operators import ( from .operators.harmonic_operators import (
......
...@@ -11,7 +11,7 @@ ...@@ -11,7 +11,7 @@
# You should have received a copy of the GNU General Public License # You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>. # along with this program. If not, see <http://www.gnu.org/licenses/>.
# #
# Copyright(C) 2013-2019 Max-Planck-Society # Copyright(C) 2013-2020 Max-Planck-Society
# #
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik. # NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.
...@@ -270,10 +270,8 @@ class Linearization(Operator): ...@@ -270,10 +270,8 @@ class Linearization(Operator):
Linearization Linearization
the (partial) integral the (partial) integral
""" """
from .operators.contraction_operator import ContractionOperator from .operators.contraction_operator import IntegrationOperator
return self.new( return IntegrationOperator(self._target, spaces)(self)
self._val.integrate(spaces),
ContractionOperator(self._jac.target, spaces, 1)(self._jac))
def ptw(self, op, *args, **kwargs): def ptw(self, op, *args, **kwargs):
t1, t2 = self._val.ptw_with_deriv(op, *args, **kwargs) t1, t2 = self._val.ptw_with_deriv(op, *args, **kwargs)
......
...@@ -11,7 +11,7 @@ ...@@ -11,7 +11,7 @@
# You should have received a copy of the GNU General Public License # You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>. # along with this program. If not, see <http://www.gnu.org/licenses/>.
# #
# Copyright(C) 2013-2019 Max-Planck-Society # Copyright(C) 2013-2020 Max-Planck-Society
import numpy as np import numpy as np
import scipy.sparse.linalg as ssl import scipy.sparse.linalg as ssl
......
...@@ -11,7 +11,7 @@ ...@@ -11,7 +11,7 @@
# You should have received a copy of the GNU General Public License # You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>. # along with this program. If not, see <http://www.gnu.org/licenses/>.
# #
# Copyright(C) 2013-2019 Max-Planck-Society # Copyright(C) 2013-2020 Max-Planck-Society
# #
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik. # NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.
...@@ -27,7 +27,7 @@ class ContractionOperator(LinearOperator): ...@@ -27,7 +27,7 @@ class ContractionOperator(LinearOperator):
"""A :class:`LinearOperator` which sums up fields into the direction of """A :class:`LinearOperator` which sums up fields into the direction of
subspaces. subspaces.
This Operator sums up a field with is defined on a :class:`DomainTuple` This Operator sums up a field which is defined on a :class:`DomainTuple`
to a :class:`DomainTuple` which is a subset of the former. to a :class:`DomainTuple` which is a subset of the former.
Parameters Parameters
...@@ -36,19 +36,19 @@ class ContractionOperator(LinearOperator): ...@@ -36,19 +36,19 @@ class ContractionOperator(LinearOperator):
spaces : None, int or tuple of int spaces : None, int or tuple of int
The elements of "domain" which are contracted. The elements of "domain" which are contracted.
If `None`, everything is contracted If `None`, everything is contracted
weight : int, default=0 power : int, default=0
If nonzero, the fields defined on self.domain are weighted with the If nonzero, the fields defined on self.domain are weighted with the
specified power along the submdomains which are contracted. specified power along the submdomains which are contracted.
""" """
def __init__(self, domain, spaces, weight=0): def __init__(self, domain, spaces, power=0):
self._domain = DomainTuple.make(domain) self._domain = DomainTuple.make(domain)
self._spaces = utilities.parse_spaces(spaces, len(self._domain)) self._spaces = utilities.parse_spaces(spaces, len(self._domain))
self._target = [ self._target = [
dom for i, dom in enumerate(self._domain) if i not in self._spaces dom for i, dom in enumerate(self._domain) if i not in self._spaces
] ]
self._target = DomainTuple.make(self._target) self._target = DomainTuple.make(self._target)
self._weight = weight self._power = power
self._capability = self.TIMES | self.ADJOINT_TIMES self._capability = self.TIMES | self.ADJOINT_TIMES
def apply(self, x, mode): def apply(self, x, mode):
...@@ -61,11 +61,28 @@ class ContractionOperator(LinearOperator): ...@@ -61,11 +61,28 @@ class ContractionOperator(LinearOperator):
shp += tmp if i not in self._spaces else (1,)*len(dom.shape) shp += tmp if i not in self._spaces else (1,)*len(dom.shape)
ldat = np.broadcast_to(ldat.reshape(shp), self._domain.shape) ldat = np.broadcast_to(ldat.reshape(shp), self._domain.shape)
res = Field(self._domain, ldat) res = Field(self._domain, ldat)
if self._weight != 0: if self._power != 0:
res = res.weight(self._weight, spaces=self._spaces) res = res.weight(self._power, spaces=self._spaces)
return res return res
else: else:
if self._weight != 0: if self._power != 0:
x = x.weight(self._weight, spaces=self._spaces) x = x.weight(self._power, spaces=self._spaces)
res = x.sum(self._spaces) res = x.sum(self._spaces)
return res if isinstance(res, Field) else Field.scalar(res) return res if isinstance(res, Field) else Field.scalar(res)
def IntegrationOperator(domain, spaces):
"""A :class:`LinearOperator` which integrates fields into the direction
of subspaces.
This Operator integrates a field which is defined on a :class:`DomainTuple`
to a :class:`DomainTuple` which is a subset of the former.
Parameters
----------
domain : Domain, tuple of Domain or DomainTuple
spaces : None, int or tuple of int
The elements of "domain" which are contracted.
If `None`, everything is contracted
"""
return ContractionOperator(domain, spaces, 1)
...@@ -128,6 +128,10 @@ class Operator(metaclass=NiftyMeta): ...@@ -128,6 +128,10 @@ class Operator(metaclass=NiftyMeta):
from .contraction_operator import ContractionOperator from .contraction_operator import ContractionOperator
return ContractionOperator(self.target, spaces)(self) return ContractionOperator(self.target, spaces)(self)
def integrate(self, spaces=None):
from .contraction_operator import IntegrationOperator
return IntegrationOperator(self.target, spaces)(self)
def vdot(self, other): def vdot(self, other):
from ..sugar import makeOp from ..sugar import makeOp
if not isinstance(other, Operator): if not isinstance(other, Operator):
......
...@@ -11,7 +11,7 @@ ...@@ -11,7 +11,7 @@
# You should have received a copy of the GNU General Public License # You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>. # along with this program. If not, see <http://www.gnu.org/licenses/>.
# #
# Copyright(C) 2013-2019 Max-Planck-Society # Copyright(C) 2013-2020 Max-Planck-Society
# #
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik. # NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.
...@@ -140,11 +140,7 @@ def test_outer(): ...@@ -140,11 +140,7 @@ def test_outer():
def test_sum(): def test_sum():
x1 = ift.RGSpace((9,), distances=2.) x1 = ift.RGSpace((9,), distances=2.)
x2 = ift.RGSpace( x2 = ift.RGSpace((2, 12), distances=(0.3,))
(
2,
12,
), distances=(0.3,))
m1 = ift.Field(ift.makeDomain(x1), np.arange(9)) m1 = ift.Field(ift.makeDomain(x1), np.arange(9))
m2 = ift.Field.full(ift.makeDomain((x1, x2)), 0.45) m2 = ift.Field.full(ift.makeDomain((x1, x2)), 0.45)
res1 = m1.s_sum() res1 = m1.s_sum()
...@@ -162,11 +158,16 @@ def test_integrate(): ...@@ -162,11 +158,16 @@ def test_integrate():
res2 = m2.integrate(spaces=1) res2 = m2.integrate(spaces=1)
assert_allclose(res1, 36*2) assert_allclose(res1, 36*2)
assert_allclose(res2.val, np.full(9, 2*12*0.45*0.3**2)) assert_allclose(res2.val, np.full(9, 2*12*0.45*0.3**2))
for m in [m1, m2]:
res3 = m.integrate()
res4 = m.s_integrate()
assert_allclose(res3.val, res4)
dom = ift.HPSpace(3)
assert_allclose(ift.full(dom, 1).s_integrate(), 4*np.pi)
def test_dataconv(): def test_dataconv():
s1 = ift.RGSpace((10,)) s1 = ift.RGSpace((10,))
ld = np.arange(s1.shape[0])
gd = np.arange(s1.shape[0]) gd = np.arange(s1.shape[0])
assert_equal(gd, ift.makeField(s1, gd).val) assert_equal(gd, ift.makeField(s1, gd).val)
......
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2020 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.
import numpy as np
from numpy.testing import assert_allclose
import nifty6 as ift
from ..common import setup_function, teardown_function
def test_integration_operator():
x1 = ift.RGSpace((9,), distances=2.)
x2 = ift.RGSpace((2, 12), distances=(0.3,))
dom1 = ift.makeDomain(x1)
dom2 = ift.makeDomain((x1, x2))
f1 = ift.from_random('normal', dom1)
f2 = ift.from_random('normal', dom2)
op1 = ift.ScalingOperator(dom1, 1).integrate()
op2 = ift.ScalingOperator(dom2, 1).integrate()
op3 = ift.ScalingOperator(dom2, 1).integrate(spaces=1)
res1 = f1.integrate()
res2 = op1(f1)
assert_allclose(res1.val, res2.val)
res3 = f2.integrate()
res4 = op2(f2)
assert_allclose(res3.val, res4.val)
res5 = f2.integrate(spaces=1)
res6 = op3(f2)
assert_allclose(res5.val, res6.val)
for op in [op1, op2, op3]:
ift.extra.consistency_check(op, domain_dtype=np.float64,
target_dtype=np.float64)
ift.extra.consistency_check(op, domain_dtype=np.complex128,
target_dtype=np.complex128)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment