Commit bfeb39ae authored by Martin Reinecke's avatar Martin Reinecke
Browse files

big progress

parent c10334fd
Pipeline #26497 passed with stage
in 5 minutes and 9 seconds
......@@ -70,7 +70,7 @@ class NonlinearPowerEnergy(Energy):
if samples is None or samples == 0:
xi_sample_list = [xi]
else:
xi_sample_list = [D.draw_inverse_sample() + xi
xi_sample_list = [D.inverse_draw_sample() + xi
for _ in range(samples)]
self.xi_sample_list = xi_sample_list
self.inverter = inverter
......
......@@ -16,12 +16,12 @@
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes.
from ..operators.endomorphic_operator import EndomorphicOperator
from ..operators.sandwich_operator import SandwichOperator
from ..operators.inversion_enabler import InversionEnabler
import numpy as np
class WienerFilterCurvature(EndomorphicOperator):
def WienerFilterCurvature(R, N, S, inverter):
"""The curvature of the WienerFilterEnergy.
This operator implements the second derivative of the
......@@ -40,28 +40,5 @@ class WienerFilterCurvature(EndomorphicOperator):
inverter : Minimizer
The minimizer to use during numerical inversion
"""
def __init__(self, R, N, S, inverter):
super(WienerFilterCurvature, self).__init__()
self.R = R
self.N = N
self.S = S
op = R.adjoint*N.inverse*R + S.inverse
self._op = InversionEnabler(op, inverter, S.times)
@property
def domain(self):
return self._op.domain
@property
def capability(self):
return self._op.capability
def apply(self, x, mode):
return self._op.apply(x, mode)
def draw_sample(self, dtype=np.float64):
n = self.N.inverse_draw_sample(dtype)
s = self.S.inverse_draw_sample(dtype)
return s - self.R.adjoint_times(n)
op = SandwichOperator(R, N.inverse) + S.inverse
return InversionEnabler(op, inverter, S.times)
......@@ -147,7 +147,8 @@ class DiagonalOperator(EndomorphicOperator):
return res
def draw_sample(self, dtype=np.float64):
if np.issubdtype(self._ldiag.dtype, np.complexfloating) or (self._ldiag <= 0.).any():
if (np.issubdtype(self._ldiag.dtype, np.complexfloating) or
(self._ldiag <= 0.).any()):
raise ValueError("operator not positive definite")
res = Field.from_random(random_type="normal", domain=self._domain,
dtype=dtype)
......@@ -155,7 +156,8 @@ class DiagonalOperator(EndomorphicOperator):
return res
def inverse_draw_sample(self, dtype=np.float64):
if np.issubdtype(self._ldiag.dtype, np.complexfloating) or (self._ldiag <= 0.).any():
if (np.issubdtype(self._ldiag.dtype, np.complexfloating) or
(self._ldiag <= 0.).any()):
raise ValueError("operator not positive definite")
res = Field.from_random(random_type="normal", domain=self._domain,
......
......@@ -60,7 +60,7 @@ class EndomorphicOperator(LinearOperator):
A sample from the Gaussian of given covariance
"""
if self.capability & self.INVERSE_TIMES:
x = self.draw_sample(dtype=dtype)
x = self.draw_sample(dtype)
return self.inverse_times(x)
else:
raise NotImplementedError
......@@ -17,6 +17,7 @@
# and financially supported by the Studienstiftung des deutschen Volkes.
from .linear_operator import LinearOperator
import numpy as np
class InverseOperator(LinearOperator):
......@@ -44,3 +45,9 @@ class InverseOperator(LinearOperator):
def apply(self, x, mode):
return self._op.apply(x, self._inverseMode[mode])
def draw_sample(self, dtype=np.float64):
return self._op.inverse_draw_sample(dtype)
def inverse_draw_sample(self, dtype=np.float64):
return self._op.draw_sample(dtype)
......@@ -21,6 +21,7 @@ from ..minimization.iteration_controller import IterationController
from ..field import Field
from ..logger import logger
from .linear_operator import LinearOperator
import numpy as np
class InversionEnabler(LinearOperator):
......@@ -74,3 +75,15 @@ class InversionEnabler(LinearOperator):
if stat != IterationController.CONVERGED:
logger.warning("Error detected during operator inversion")
return r.position
def draw_sample(self, dtype=np.float64):
try:
return self._op.draw_sample(dtype)
except:
return self(self._op.inverse_draw_sample(dtype))
def inverse_draw_sample(self, dtype=np.float64):
try:
return self._op.inverse_draw_sample(dtype)
except:
return self.inverse_times(self._op.draw_sample(dtype))
......@@ -17,6 +17,7 @@
# and financially supported by the Studienstiftung des deutschen Volkes.
from .endomorphic_operator import EndomorphicOperator
import numpy as np
class SandwichOperator(EndomorphicOperator):
......@@ -47,5 +48,5 @@ class SandwichOperator(EndomorphicOperator):
def apply(self, x, mode):
return self._op.apply(x, mode)
def draw_sample(self):
return self._bun.adjoint_times(self._cheese.draw_sample())
def draw_sample(self, dtype=np.float64):
return self._bun.adjoint_times(self._cheese.draw_sample(dtype))
......@@ -17,6 +17,7 @@
# and financially supported by the Studienstiftung des deutschen Volkes.
from .linear_operator import LinearOperator
import numpy as np
class SumOperator(LinearOperator):
......@@ -143,3 +144,9 @@ class SumOperator(LinearOperator):
else:
res += op.apply(x, mode)
return res
def draw_sample(self, dtype=np.float64):
res = self._ops[0].draw_sample(dtype)
for op in self._ops[1:]:
res += op.draw_sample(dtype)
return res
......@@ -51,7 +51,7 @@ class StatCalculator(object):
def probe_with_posterior_samples(op, post_op, nprobes):
sc = StatCalculator()
for i in range(nprobes):
sample = post_op(op.draw_inverse_sample())
sample = post_op(op.inverse_draw_sample())
sc.add(sample)
if nprobes == 1:
......
......@@ -84,7 +84,7 @@ class Noise_Energy_Tests(unittest.TestCase):
S=S,
inverter=inverter).curvature
res_sample_list = [d - R(f(ht(C.draw_inverse_sample() + xi)))
res_sample_list = [d - R(f(ht(C.inverse_draw_sample() + xi)))
for _ in range(10)]
energy0 = ift.library.NoiseEnergy(eta0, alpha, q, res_sample_list)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment