Commit 662b29d7 authored by Martin Reinecke's avatar Martin Reinecke
Browse files

rework operators

parent c83661d0
Pipeline #23250 passed with stage
in 4 minutes and 49 seconds
...@@ -2,36 +2,6 @@ import numpy as np ...@@ -2,36 +2,6 @@ import numpy as np
import nifty2go as ift import nifty2go as ift
# Note that the constructor of PropagatorOperator takes as arguments the
# response R and noise covariance N operating on signal space and signal
# covariance operating on harmonic space.
class PropagatorOperator(ift.InversionEnabler, ift.EndomorphicOperator):
def __init__(self, R, N, Sh, inverter):
ift.InversionEnabler.__init__(self, inverter)
ift.EndomorphicOperator.__init__(self)
self.R = R
self.N = N
self.Sh = Sh
self.fft = ift.FFTOperator(R.domain, target=Sh.domain[0])
def _inverse_times(self, x):
return self.R.adjoint_times(self.N.inverse_times(self.R(x))) \
+ self.fft.adjoint_times(self.Sh.inverse_times(self.fft(x)))
@property
def domain(self):
return self.R.domain
@property
def unitary(self):
return False
@property
def self_adjoint(self):
return True
if __name__ == "__main__": if __name__ == "__main__":
# Set up physical constants # Set up physical constants
# Total length of interval or volume the field lives on, e.g. in meters # Total length of interval or volume the field lives on, e.g. in meters
...@@ -85,6 +55,6 @@ if __name__ == "__main__": ...@@ -85,6 +55,6 @@ if __name__ == "__main__":
j = R.adjoint_times(N.inverse_times(d)) j = R.adjoint_times(N.inverse_times(d))
IC = ift.GradientNormController(iteration_limit=500, tol_abs_gradnorm=0.1) IC = ift.GradientNormController(iteration_limit=500, tol_abs_gradnorm=0.1)
inverter = ift.ConjugateGradient(controller=IC) inverter = ift.ConjugateGradient(controller=IC)
D = PropagatorOperator(Sh=Sh, N=N, R=R, inverter=inverter) D = (R.adjoint*N.inverse*R + fft.adjoint*Sh.inverse*fft).inverse
D = ift.InversionEnabler(D, inverter)
m = D(j) m = D(j)
...@@ -71,7 +71,7 @@ if __name__ == "__main__": ...@@ -71,7 +71,7 @@ if __name__ == "__main__":
n_samples = 50 n_samples = 50
for i in range(n_samples): for i in range(n_samples):
sample = fft(D.generate_posterior_sample(m)) sample = fft(D.generate_posterior_sample() + m)
sample_variance += sample**2 sample_variance += sample**2
sample_mean += sample sample_mean += sample
sample_mean /= n_samples sample_mean /= n_samples
......
from ..operators import EndomorphicOperator, InversionEnabler, DiagonalOperator from ..operators import EndomorphicOperator, InversionEnabler, DiagonalOperator
class CriticalPowerCurvature(InversionEnabler, EndomorphicOperator): class CriticalPowerCurvature(EndomorphicOperator):
"""The curvature of the CriticalPowerEnergy. """The curvature of the CriticalPowerEnergy.
This operator implements the second derivative of the This operator implements the second derivative of the
...@@ -17,22 +17,17 @@ class CriticalPowerCurvature(InversionEnabler, EndomorphicOperator): ...@@ -17,22 +17,17 @@ class CriticalPowerCurvature(InversionEnabler, EndomorphicOperator):
""" """
def __init__(self, theta, T, inverter): def __init__(self, theta, T, inverter):
EndomorphicOperator.__init__(self) super(CriticalPowerCurvature, self).__init__()
self._theta = DiagonalOperator(theta) theta = DiagonalOperator(theta)
InversionEnabler.__init__(self, inverter, self._theta.inverse_times) self._op = InversionEnabler(T+theta, inverter, theta.inverse_times)
self._T = T
def _times(self, x):
return self._T(x) + self._theta(x)
@property @property
def domain(self): def capability(self):
return self._theta.domain return self._op.capability
@property def apply(self, x, mode):
def self_adjoint(self): return self._op.apply(x, mode)
return True
@property @property
def unitary(self): def domain(self):
return False return self._op.domain
...@@ -70,7 +70,7 @@ class CriticalPowerEnergy(Energy): ...@@ -70,7 +70,7 @@ class CriticalPowerEnergy(Energy):
if self.D is not None: if self.D is not None:
w = Field.zeros(self.position.domain, dtype=self.m.dtype) w = Field.zeros(self.position.domain, dtype=self.m.dtype)
for i in range(self.samples): for i in range(self.samples):
sample = self.D.generate_posterior_sample(self.m) sample = self.D.generate_posterior_sample() + self.m
w += P(abs(sample)**2) w += P(abs(sample)**2)
w *= 1./self.samples w *= 1./self.samples
......
...@@ -3,7 +3,7 @@ from ..utilities import memo ...@@ -3,7 +3,7 @@ from ..utilities import memo
from ..field import exp from ..field import exp
class LogNormalWienerFilterCurvature(InversionEnabler, EndomorphicOperator): class LogNormalWienerFilterCurvature(EndomorphicOperator):
"""The curvature of the LogNormalWienerFilterEnergy. """The curvature of the LogNormalWienerFilterEnergy.
This operator implements the second derivative of the This operator implements the second derivative of the
...@@ -21,33 +21,54 @@ class LogNormalWienerFilterCurvature(InversionEnabler, EndomorphicOperator): ...@@ -21,33 +21,54 @@ class LogNormalWienerFilterCurvature(InversionEnabler, EndomorphicOperator):
The prior signal covariance The prior signal covariance
""" """
class _Helper(EndomorphicOperator):
def __init__(self, R, N, S, position, fft4exp):
super(LogNormalWienerFilterCurvature._Helper, self).__init__()
self.R = R
self.N = N
self.S = S
self.position = position
self._fft = fft4exp
@property
def domain(self):
return self.S.domain
@property
def capability(self):
return self.TIMES | self.ADJOINT_TIMES
def apply(self, x, mode):
self._check_input(x, mode)
part1 = self.S.inverse_times(x)
part3 = self._fft.adjoint_times(self._expp_sspace * self._fft(x))
part3 = self._fft.adjoint_times(
self._expp_sspace *
self._fft(self.R.adjoint_times(
self.N.inverse_times(self.R(part3)))))
return part1 + part3
@property
@memo
def _expp_sspace(self):
return exp(self._fft(self.position))
def __init__(self, R, N, S, position, fft4exp, inverter): def __init__(self, R, N, S, position, fft4exp, inverter):
InversionEnabler.__init__(self, inverter) super(LogNormalWienerFilterCurvature, self).__init__()
EndomorphicOperator.__init__(self) self._op = self._Helper(R, N, S, position, fft4exp)
self.R = R self._op = InversionEnabler(self._op, inverter)
self.N = N
self.S = S
self.position = position
self._fft = fft4exp
@property @property
def domain(self): def domain(self):
return self.S.domain return self._op.domain
@property @property
def self_adjoint(self): def capability(self):
return True return self._op.capability
def _times(self, x):
part1 = self.S.inverse_times(x)
part3 = self._fft.adjoint_times(self._expp_sspace * self._fft(x))
part3 = self._fft.adjoint_times(
self._expp_sspace *
self._fft(self.R.adjoint_times(
self.N.inverse_times(self.R(part3)))))
return part1 + part3
@property @property
@memo
def _expp_sspace(self): def _expp_sspace(self):
return exp(self._fft(self.position)) return self._op._op._expp_sspace
def apply(self, x, mode):
return self._op.apply(x, mode)
...@@ -27,7 +27,7 @@ class NoiseEnergy(Energy): ...@@ -27,7 +27,7 @@ class NoiseEnergy(Energy):
if samples is None or samples == 0: if samples is None or samples == 0:
sample_list = [m] sample_list = [m]
else: else:
sample_list = [D.generate_posterior_sample(m) sample_list = [D.generate_posterior_sample() + m
for _ in range(samples)] for _ in range(samples)]
self.sample_list = sample_list self.sample_list = sample_list
self.inverter = inverter self.inverter = inverter
......
...@@ -2,42 +2,60 @@ from ..operators import EndomorphicOperator, InversionEnabler ...@@ -2,42 +2,60 @@ from ..operators import EndomorphicOperator, InversionEnabler
from .response_operators import LinearizedPowerResponse from .response_operators import LinearizedPowerResponse
class NonlinearPowerCurvature(InversionEnabler, EndomorphicOperator): class NonlinearPowerCurvature(EndomorphicOperator):
class _Helper(EndomorphicOperator):
def __init__(self, position, FFT, Instrument, nonlinearity,
Projection, N, T, sample_list):
super(NonlinearPowerCurvature._Helper, self).__init__()
self.N = N
self.FFT = FFT
self.Instrument = Instrument
self.T = T
self.sample_list = sample_list
self.position = position
self.Projection = Projection
self.nonlinearity = nonlinearity
@property
def domain(self):
return self.position.domain
@property
def capability(self):
return self.TIMES
def apply(self, x, mode):
self._check_input(x, mode)
result = None
for sample in self.sample_list:
if result is None:
result = self._sample_times(x, sample)
else:
result += self._sample_times(x, sample)
result *= 1./len(self.sample_list)
return result + self.T(x)
def _sample_times(self, x, sample):
LinearizedResponse = LinearizedPowerResponse(
self.Instrument, self.nonlinearity, self.FFT, self.Projection,
self.position, sample)
return LinearizedResponse.adjoint_times(
self.N.inverse_times(LinearizedResponse(x)))
def __init__(self, position, FFT, Instrument, nonlinearity, def __init__(self, position, FFT, Instrument, nonlinearity,
Projection, N, T, sample_list, inverter): Projection, N, T, sample_list, inverter):
InversionEnabler.__init__(self, inverter) super(NonlinearPowerCurvature, self).__init__()
EndomorphicOperator.__init__(self) self._op = self._Helper(position, FFT, Instrument, nonlinearity,
self.N = N Projection, N, T, sample_list)
self.FFT = FFT self._op = InversionEnabler(self._op, inverter)
self.Instrument = Instrument
self.T = T
self.sample_list = sample_list
self.position = position
self.Projection = Projection
self.nonlinearity = nonlinearity
@property @property
def domain(self): def domain(self):
return self.position.domain return self._op.domain
@property @property
def self_adjoint(self): def capability(self):
return True return self._op.capability
def _times(self, x): def apply(self, x, mode):
result = None return self._op.apply(x, mode)
for sample in self.sample_list:
if result is None:
result = self._sample_times(x, sample)
else:
result += self._sample_times(x, sample)
result *= 1./len(self.sample_list)
return result + self.T(x)
def _sample_times(self, x, sample):
LinearizedResponse = LinearizedPowerResponse(
self.Instrument, self.nonlinearity, self.FFT, self.Projection,
self.position, sample)
return LinearizedResponse.adjoint_times(
self.N.inverse_times(LinearizedResponse(x)))
...@@ -53,7 +53,7 @@ class NonlinearPowerEnergy(Energy): ...@@ -53,7 +53,7 @@ class NonlinearPowerEnergy(Energy):
if samples is None or samples == 0: if samples is None or samples == 0:
sample_list = [m] sample_list = [m]
else: else:
sample_list = [D.generate_posterior_sample(m) sample_list = [D.generate_posterior_sample() + m
for _ in range(samples)] for _ in range(samples)]
self.sample_list = sample_list self.sample_list = sample_list
self.inverter = inverter self.inverter = inverter
......
...@@ -31,6 +31,14 @@ class LinearizedSignalResponse(LinearOperator): ...@@ -31,6 +31,14 @@ class LinearizedSignalResponse(LinearOperator):
def target(self): def target(self):
return self.Instrument.target return self.Instrument.target
@property
def capability(self):
return self.TIMES | self.ADJOINT_TIMES
def apply(self, x, mode):
self._check_input(x, mode)
return self._times(x) if mode & self.TIMES else self._adjoint_times(x)
class LinearizedPowerResponse(LinearOperator): class LinearizedPowerResponse(LinearOperator):
def __init__(self, Instrument, nonlinearity, FFT, Projection, t, m): def __init__(self, Instrument, nonlinearity, FFT, Projection, t, m):
...@@ -70,3 +78,11 @@ class LinearizedPowerResponse(LinearOperator): ...@@ -70,3 +78,11 @@ class LinearizedPowerResponse(LinearOperator):
@property @property
def target(self): def target(self):
return self.Instrument.target return self.Instrument.target
@property
def capability(self):
return self.TIMES | self.ADJOINT_TIMES
def apply(self, x, mode):
self._check_input(x, mode)
return self._times(x) if mode & self.TIMES else self._adjoint_times(x)
...@@ -3,7 +3,7 @@ from ..field import Field, sqrt ...@@ -3,7 +3,7 @@ from ..field import Field, sqrt
from ..sugar import power_analyze, power_synthesize from ..sugar import power_analyze, power_synthesize
class WienerFilterCurvature(InversionEnabler, EndomorphicOperator): class WienerFilterCurvature(EndomorphicOperator):
"""The curvature of the WienerFilterEnergy. """The curvature of the WienerFilterEnergy.
This operator implements the second derivative of the This operator implements the second derivative of the
...@@ -22,26 +22,25 @@ class WienerFilterCurvature(InversionEnabler, EndomorphicOperator): ...@@ -22,26 +22,25 @@ class WienerFilterCurvature(InversionEnabler, EndomorphicOperator):
""" """
def __init__(self, R, N, S, inverter): def __init__(self, R, N, S, inverter):
EndomorphicOperator.__init__(self) super(WienerFilterCurvature, self).__init__()
InversionEnabler.__init__(self, inverter, S.times)
self.R = R self.R = R
self.N = N self.N = N
self.S = S self.S = S
op = R.adjoint*N.inverse*R + S.inverse
self._op = InversionEnabler(op, inverter, S.times)
@property @property
def domain(self): def domain(self):
return self.S.domain return self.S.domain
@property @property
def self_adjoint(self): def capability(self):
return True return self._op.capability
def _times(self, x): def apply(self, x, mode):
res = self.R.adjoint_times(self.N.inverse_times(self.R(x))) return self._op.apply(x, mode)
res += self.S.inverse_times(x)
return res
def generate_posterior_sample(self, mean): def generate_posterior_sample(self):
""" Generates a posterior sample from a Gaussian distribution with """ Generates a posterior sample from a Gaussian distribution with
given mean and covariance. given mean and covariance.
...@@ -49,11 +48,6 @@ class WienerFilterCurvature(InversionEnabler, EndomorphicOperator): ...@@ -49,11 +48,6 @@ class WienerFilterCurvature(InversionEnabler, EndomorphicOperator):
reconstruction of a mock signal in order to obtain residuals of the reconstruction of a mock signal in order to obtain residuals of the
right correlation which are added to the given mean. right correlation which are added to the given mean.
Parameters
----------
mean : Field
the mean of the posterior Gaussian distribution
Returns Returns
------- -------
sample : Field sample : Field
...@@ -74,5 +68,5 @@ class WienerFilterCurvature(InversionEnabler, EndomorphicOperator): ...@@ -74,5 +68,5 @@ class WienerFilterCurvature(InversionEnabler, EndomorphicOperator):
mock_j = self.R.adjoint_times(self.N.inverse_times(mock_data)) mock_j = self.R.adjoint_times(self.N.inverse_times(mock_data))
mock_m = self.inverse_times(mock_j) mock_m = self.inverse_times(mock_j)
sample = mock_signal - mock_m + mean sample = mock_signal - mock_m
return sample return sample
...@@ -27,7 +27,7 @@ class WienerFilterEnergy(Energy): ...@@ -27,7 +27,7 @@ class WienerFilterEnergy(Energy):
self.R = R self.R = R
self.N = N self.N = N
self.S = S self.S = S
self._curvature = WienerFilterCurvature(R, N, S, inverter=inverter) self._curvature = WienerFilterCurvature(R, N, S, inverter)
self._inverter = inverter self._inverter = inverter
if _j is None: if _j is None:
_j = self.R.adjoint_times(self.N.inverse_times(d)) _j = self.R.adjoint_times(self.N.inverse_times(d))
......
...@@ -12,3 +12,7 @@ from .laplace_operator import LaplaceOperator ...@@ -12,3 +12,7 @@ from .laplace_operator import LaplaceOperator
from .smoothness_operator import SmoothnessOperator from .smoothness_operator import SmoothnessOperator
from .power_projection_operator import PowerProjectionOperator from .power_projection_operator import PowerProjectionOperator
from .dof_projection_operator import DOFProjectionOperator from .dof_projection_operator import DOFProjectionOperator
from .chain_operator import ChainOperator
from .sum_operator import SumOperator
from .inverse_operator import InverseOperator
from .adjoint_operator import AdjointOperator
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2017 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes.
from .linear_operator import LinearOperator
class AdjointOperator(LinearOperator):
def __init__(self, op):
super(AdjointOperator, self).__init__()
self._op = op
@property
def domain(self):
return self._op.target
@property
def target(self):
return self._op.domain
@property
def capability(self):
return self._adjointCapability[self._op.capability]
def apply(self, x, mode):
return self._op.apply(x, self._adjointMode[mode])
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2017 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes.
from .linear_operator import LinearOperator
class ChainOperator(LinearOperator):
def __init__(self, op1, op2):
super(ChainOperator, self).__init__()
if op2.target != op1.domain:
raise ValueError("domain mismatch")
self._op1 = op1
self._op2 = op2
@property
def domain(self):
return self._op2.domain
@property
def target(self):
return self._op1.target
@property
def capability(self):
return self._op1.capability & self._op2.capability
def apply(self, x, mode):
self._check_mode(mode)
if mode == self.TIMES or mode == self.ADJOINT_INVERSE_TIMES:
return self._op1.apply(self._op2.apply(x, mode), mode)
return self._op2.apply(self._op1.apply(x, mode), mode)
...@@ -16,7 +16,6 @@ ...@@ -16,7 +16,6 @@