Commit 9f9b0ed3 authored by Martin Reinecke's avatar Martin Reinecke
Browse files

Merge branch 'NIFTy_4' into diag_hack

parents a038d7e4 6c61cbec
Pipeline #26506 passed with stage
in 5 minutes and 28 seconds
...@@ -51,7 +51,7 @@ if __name__ == "__main__": ...@@ -51,7 +51,7 @@ if __name__ == "__main__":
IC = ift.GradientNormController(name="inverter", iteration_limit=500, IC = ift.GradientNormController(name="inverter", iteration_limit=500,
tol_abs_gradnorm=0.1) tol_abs_gradnorm=0.1)
inverter = ift.ConjugateGradient(controller=IC) inverter = ift.ConjugateGradient(controller=IC)
D = (R.adjoint*N.inverse*R + Sh.inverse).inverse D = (ift.SandwichOperator(R, N.inverse) + Sh.inverse).inverse
# MR FIXME: we can/should provide a preconditioner here as well! # MR FIXME: we can/should provide a preconditioner here as well!
D = ift.InversionEnabler(D, inverter) D = ift.InversionEnabler(D, inverter)
m = D(j) m = D(j)
......
...@@ -70,7 +70,7 @@ class NonlinearPowerEnergy(Energy): ...@@ -70,7 +70,7 @@ class NonlinearPowerEnergy(Energy):
if samples is None or samples == 0: if samples is None or samples == 0:
xi_sample_list = [xi] xi_sample_list = [xi]
else: else:
xi_sample_list = [D.draw_sample() + xi xi_sample_list = [D.inverse_draw_sample() + xi
for _ in range(samples)] for _ in range(samples)]
self.xi_sample_list = xi_sample_list self.xi_sample_list = xi_sample_list
self.inverter = inverter self.inverter = inverter
......
...@@ -16,12 +16,12 @@ ...@@ -16,12 +16,12 @@
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik # NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes. # and financially supported by the Studienstiftung des deutschen Volkes.
from ..operators.endomorphic_operator import EndomorphicOperator from ..operators.sandwich_operator import SandwichOperator
from ..operators.inversion_enabler import InversionEnabler from ..operators.inversion_enabler import InversionEnabler
import numpy as np import numpy as np
class WienerFilterCurvature(EndomorphicOperator): def WienerFilterCurvature(R, N, S, inverter):
"""The curvature of the WienerFilterEnergy. """The curvature of the WienerFilterEnergy.
This operator implements the second derivative of the This operator implements the second derivative of the
...@@ -40,32 +40,5 @@ class WienerFilterCurvature(EndomorphicOperator): ...@@ -40,32 +40,5 @@ class WienerFilterCurvature(EndomorphicOperator):
inverter : Minimizer inverter : Minimizer
The minimizer to use during numerical inversion The minimizer to use during numerical inversion
""" """
op = SandwichOperator(R, N.inverse) + S.inverse
def __init__(self, R, N, S, inverter): return InversionEnabler(op, inverter, S.times)
super(WienerFilterCurvature, self).__init__()
self.R = R
self.N = N
self.S = S
op = R.adjoint*N.inverse*R + S.inverse
self._op = InversionEnabler(op, inverter, S.times)
@property
def domain(self):
return self._op.domain
@property
def capability(self):
return self._op.capability
def apply(self, x, mode):
return self._op.apply(x, mode)
def draw_sample(self, dtype=np.float64):
n = self.N.draw_sample(dtype)
s = self.S.draw_sample(dtype)
d = self.R(s) + n
j = self.R.adjoint_times(self.N.inverse_times(d))
m = self.inverse_times(j)
return s - m
...@@ -10,9 +10,10 @@ from .laplace_operator import LaplaceOperator ...@@ -10,9 +10,10 @@ from .laplace_operator import LaplaceOperator
from .smoothness_operator import SmoothnessOperator from .smoothness_operator import SmoothnessOperator
from .power_distributor import PowerDistributor from .power_distributor import PowerDistributor
from .inversion_enabler import InversionEnabler from .inversion_enabler import InversionEnabler
from .sandwich_operator import SandwichOperator
__all__ = ["LinearOperator", "EndomorphicOperator", "ScalingOperator", __all__ = ["LinearOperator", "EndomorphicOperator", "ScalingOperator",
"DiagonalOperator", "HarmonicTransformOperator", "FFTOperator", "DiagonalOperator", "HarmonicTransformOperator", "FFTOperator",
"FFTSmoothingOperator", "GeometryRemover", "FFTSmoothingOperator", "GeometryRemover",
"LaplaceOperator", "SmoothnessOperator", "PowerDistributor", "LaplaceOperator", "SmoothnessOperator", "PowerDistributor",
"InversionEnabler"] "InversionEnabler", "SandwichOperator"]
...@@ -181,10 +181,20 @@ class DiagonalOperator(EndomorphicOperator): ...@@ -181,10 +181,20 @@ class DiagonalOperator(EndomorphicOperator):
return res return res
def draw_sample(self, dtype=np.float64): def draw_sample(self, dtype=np.float64):
if np.issubdtype(self._ldiag.dtype, np.complexfloating): if (np.issubdtype(self._ldiag.dtype, np.complexfloating) or
raise ValueError("cannot draw sample from complex-valued operator") (self._ldiag <= 0.).any()):
raise ValueError("operator not positive definite")
res = Field.from_random(random_type="normal", domain=self._domain, res = Field.from_random(random_type="normal", domain=self._domain,
dtype=dtype) dtype=dtype)
res.local_data[()] *= np.sqrt(self._ldiag) res.local_data[()] *= np.sqrt(self._ldiag)
return res return res
def inverse_draw_sample(self, dtype=np.float64):
if (np.issubdtype(self._ldiag.dtype, np.complexfloating) or
(self._ldiag <= 0.).any()):
raise ValueError("operator not positive definite")
res = Field.from_random(random_type="normal", domain=self._domain,
dtype=dtype)
res.local_data[()] /= np.sqrt(self._ldiag)
return res
...@@ -17,6 +17,7 @@ ...@@ -17,6 +17,7 @@
# and financially supported by the Studienstiftung des deutschen Volkes. # and financially supported by the Studienstiftung des deutschen Volkes.
from .linear_operator import LinearOperator from .linear_operator import LinearOperator
import numpy as np
class EndomorphicOperator(LinearOperator): class EndomorphicOperator(LinearOperator):
...@@ -34,3 +35,32 @@ class EndomorphicOperator(LinearOperator): ...@@ -34,3 +35,32 @@ class EndomorphicOperator(LinearOperator):
Returns `self.domain`, because this is also the target domain Returns `self.domain`, because this is also the target domain
for endomorphic operators.""" for endomorphic operators."""
return self.domain return self.domain
def draw_sample(self, dtype=np.float64):
"""Generate a zero-mean sample
Generates a sample from a Gaussian distribution with zero mean and
covariance given by the operator.
Returns
-------
Field
A sample from the Gaussian of given covariance.
"""
raise NotImplementedError
def inverse_draw_sample(self, dtype=np.float64):
"""Generates a zero-mean sample
Generates a sample from a Gaussian distribution with zero mean and
covariance given by the inverse of the operator.
Returns
-------
A sample from the Gaussian of given covariance
"""
if self.capability & self.INVERSE_TIMES:
x = self.draw_sample(dtype)
return self.inverse_times(x)
else:
raise NotImplementedError
...@@ -17,6 +17,7 @@ ...@@ -17,6 +17,7 @@
# and financially supported by the Studienstiftung des deutschen Volkes. # and financially supported by the Studienstiftung des deutschen Volkes.
from .linear_operator import LinearOperator from .linear_operator import LinearOperator
import numpy as np
class InverseOperator(LinearOperator): class InverseOperator(LinearOperator):
...@@ -44,3 +45,9 @@ class InverseOperator(LinearOperator): ...@@ -44,3 +45,9 @@ class InverseOperator(LinearOperator):
def apply(self, x, mode): def apply(self, x, mode):
return self._op.apply(x, self._inverseMode[mode]) return self._op.apply(x, self._inverseMode[mode])
def draw_sample(self, dtype=np.float64):
return self._op.inverse_draw_sample(dtype)
def inverse_draw_sample(self, dtype=np.float64):
return self._op.draw_sample(dtype)
...@@ -21,6 +21,7 @@ from ..minimization.iteration_controller import IterationController ...@@ -21,6 +21,7 @@ from ..minimization.iteration_controller import IterationController
from ..field import Field from ..field import Field
from ..logger import logger from ..logger import logger
from .linear_operator import LinearOperator from .linear_operator import LinearOperator
import numpy as np
class InversionEnabler(LinearOperator): class InversionEnabler(LinearOperator):
...@@ -74,3 +75,15 @@ class InversionEnabler(LinearOperator): ...@@ -74,3 +75,15 @@ class InversionEnabler(LinearOperator):
if stat != IterationController.CONVERGED: if stat != IterationController.CONVERGED:
logger.warning("Error detected during operator inversion") logger.warning("Error detected during operator inversion")
return r.position return r.position
def draw_sample(self, dtype=np.float64):
try:
return self._op.draw_sample(dtype)
except:
return self(self._op.inverse_draw_sample(dtype))
def inverse_draw_sample(self, dtype=np.float64):
try:
return self._op.inverse_draw_sample(dtype)
except:
return self.inverse_times(self._op.draw_sample(dtype))
...@@ -264,16 +264,3 @@ class LinearOperator(NiftyMetaBase()): ...@@ -264,16 +264,3 @@ class LinearOperator(NiftyMetaBase()):
self._check_mode(mode) self._check_mode(mode)
if x.domain != self._dom(mode): if x.domain != self._dom(mode):
raise ValueError("The operator's and field's domains don't match.") raise ValueError("The operator's and field's domains don't match.")
def draw_sample(self):
"""Generate a zero-mean sample
Generates a sample from a Gaussian distribution with zero mean and
covariance given by the operator.
Returns
-------
Field
A sample from the Gaussian of given covariance.
"""
raise NotImplementedError
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2018 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes.
from .endomorphic_operator import EndomorphicOperator
import numpy as np
class SandwichOperator(EndomorphicOperator):
"""Operator which is equivalent to the expression `bun.adjoint*cheese*bun`.
Parameters
----------
bun: LinearOperator
the bun part
cheese: EndomorphicOperator
the cheese part
"""
def __init__(self, bun, cheese):
super(SandwichOperator, self).__init__()
self._bun = bun
self._cheese = cheese
self._op = bun.adjoint*cheese*bun
@property
def domain(self):
return self._op.domain
@property
def capability(self):
return self._op.capability
def apply(self, x, mode):
return self._op.apply(x, mode)
def draw_sample(self, dtype=np.float64):
return self._bun.adjoint_times(self._cheese.draw_sample(dtype))
...@@ -100,10 +100,14 @@ class ScalingOperator(EndomorphicOperator): ...@@ -100,10 +100,14 @@ class ScalingOperator(EndomorphicOperator):
raise ValueError("Operator not positive definite") raise ValueError("Operator not positive definite")
return sample * np.sqrt(self._factor) return sample * np.sqrt(self._factor)
def _sample_helper(self, fct, dtype):
if fct.imag != 0. or fct.real <= 0.:
raise ValueError("operator not positive definite")
return Field.from_random(
random_type="normal", domain=self._domain, std=fct, dtype=dtype)
def draw_sample(self, dtype=np.float64): def draw_sample(self, dtype=np.float64):
if self._factor.imag != 0. or self._factor.real <= 0.: return self._sample_helper(np.sqrt(self._factor), dtype)
raise ValueError("Operator not positive definite")
return Field.from_random(random_type="normal", def inverse_draw_sample(self, dtype=np.float64):
domain=self._domain, return self._sample_helper(1./np.sqrt(self._factor), dtype)
std=np.sqrt(self._factor),
dtype=dtype)
...@@ -17,6 +17,7 @@ ...@@ -17,6 +17,7 @@
# and financially supported by the Studienstiftung des deutschen Volkes. # and financially supported by the Studienstiftung des deutschen Volkes.
from .linear_operator import LinearOperator from .linear_operator import LinearOperator
import numpy as np
class SumOperator(LinearOperator): class SumOperator(LinearOperator):
...@@ -141,3 +142,9 @@ class SumOperator(LinearOperator): ...@@ -141,3 +142,9 @@ class SumOperator(LinearOperator):
else: else:
res += op.apply(x, mode) res += op.apply(x, mode)
return res return res
def draw_sample(self, dtype=np.float64):
res = self._ops[0].draw_sample(dtype)
for op in self._ops[1:]:
res += op.draw_sample(dtype)
return res
...@@ -51,7 +51,7 @@ class StatCalculator(object): ...@@ -51,7 +51,7 @@ class StatCalculator(object):
def probe_with_posterior_samples(op, post_op, nprobes): def probe_with_posterior_samples(op, post_op, nprobes):
sc = StatCalculator() sc = StatCalculator()
for i in range(nprobes): for i in range(nprobes):
sample = post_op(op.draw_sample()) sample = post_op(op.inverse_draw_sample())
sc.add(sample) sc.add(sample)
if nprobes == 1: if nprobes == 1:
......
...@@ -84,7 +84,7 @@ class Noise_Energy_Tests(unittest.TestCase): ...@@ -84,7 +84,7 @@ class Noise_Energy_Tests(unittest.TestCase):
S=S, S=S,
inverter=inverter).curvature inverter=inverter).curvature
res_sample_list = [d - R(f(ht(C.draw_sample() + xi))) res_sample_list = [d - R(f(ht(C.inverse_draw_sample() + xi)))
for _ in range(10)] for _ in range(10)]
energy0 = ift.library.NoiseEnergy(eta0, alpha, q, res_sample_list) energy0 = ift.library.NoiseEnergy(eta0, alpha, q, res_sample_list)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment