Commit 7b4b15b0 authored by Martin Reinecke's avatar Martin Reinecke

alternative solution

parent ff54ff45
Pipeline #29551 passed with stages
in 3 minutes and 53 seconds
...@@ -31,7 +31,7 @@ d = R(s_x) + n ...@@ -31,7 +31,7 @@ d = R(s_x) + n
R_p = R * FFT * A R_p = R * FFT * A
j = R_p.adjoint(N.inverse(d)) j = R_p.adjoint(N.inverse(d))
D_inv = ift.SandwichOperator(R_p, N.inverse) + S.inverse D_inv = ift.SandwichOperator.make(R_p, N.inverse) + S.inverse
N_samps = 200 N_samps = 200
......
...@@ -80,12 +80,12 @@ if __name__ == "__main__": ...@@ -80,12 +80,12 @@ if __name__ == "__main__":
IC = ift.GradientNormController(name="inverter", iteration_limit=500, IC = ift.GradientNormController(name="inverter", iteration_limit=500,
tol_abs_gradnorm=1e-3) tol_abs_gradnorm=1e-3)
inverter = ift.ConjugateGradient(controller=IC) inverter = ift.ConjugateGradient(controller=IC)
D = (ift.SandwichOperator(R, N.inverse) + Phi_h.inverse).inverse D = (ift.SandwichOperator.make(R, N.inverse) + Phi_h.inverse).inverse
D = ift.InversionEnabler(D, inverter, approximation=Phi_h) D = ift.InversionEnabler(D, inverter, approximation=Phi_h)
m = HT(D(j)) m = HT(D(j))
# Uncertainty # Uncertainty
D = ift.SandwichOperator(aHT, D) # real space propagator D = ift.SandwichOperator.make(aHT, D) # real space propagator
Dhat = ift.probe_with_posterior_samples(D.inverse, None, Dhat = ift.probe_with_posterior_samples(D.inverse, None,
nprobes=nprobes)[1] nprobes=nprobes)[1]
sig = ift.sqrt(Dhat) sig = ift.sqrt(Dhat)
......
...@@ -51,7 +51,7 @@ if __name__ == "__main__": ...@@ -51,7 +51,7 @@ if __name__ == "__main__":
IC = ift.GradientNormController(name="inverter", iteration_limit=500, IC = ift.GradientNormController(name="inverter", iteration_limit=500,
tol_abs_gradnorm=0.1) tol_abs_gradnorm=0.1)
inverter = ift.ConjugateGradient(controller=IC) inverter = ift.ConjugateGradient(controller=IC)
D = (ift.SandwichOperator(R, N.inverse) + Sh.inverse).inverse D = (ift.SandwichOperator.make(R, N.inverse) + Sh.inverse).inverse
D = ift.InversionEnabler(D, inverter, approximation=Sh) D = ift.InversionEnabler(D, inverter, approximation=Sh)
m = D(j) m = D(j)
......
...@@ -46,7 +46,7 @@ class PoissonEnergy(Energy): ...@@ -46,7 +46,7 @@ class PoissonEnergy(Energy):
R1 = Instrument*Rho*ht R1 = Instrument*Rho*ht
self._grad = (phipos + R1.adjoint_times((lam-d)/(lam+eps))).lock() self._grad = (phipos + R1.adjoint_times((lam-d)/(lam+eps))).lock()
self._curv = Phi_h.inverse + SandwichOperator(R1, W) self._curv = Phi_h.inverse + SandwichOperator.make(R1, W)
def at(self, position): def at(self, position):
return self.__class__(position, self._d, self._Instrument, return self.__class__(position, self._d, self._Instrument,
......
...@@ -39,5 +39,5 @@ def WienerFilterCurvature(R, N, S, inverter): ...@@ -39,5 +39,5 @@ def WienerFilterCurvature(R, N, S, inverter):
inverter : Minimizer inverter : Minimizer
The minimizer to use during numerical inversion The minimizer to use during numerical inversion
""" """
op = SandwichOperator(R, N.inverse) + S.inverse op = SandwichOperator.make(R, N.inverse) + S.inverse
return InversionEnabler(op, inverter, S.inverse) return InversionEnabler(op, inverter, S.inverse)
...@@ -25,24 +25,37 @@ from .scaling_operator import ScalingOperator ...@@ -25,24 +25,37 @@ from .scaling_operator import ScalingOperator
class SandwichOperator(EndomorphicOperator): class SandwichOperator(EndomorphicOperator):
"""Operator which is equivalent to the expression `bun.adjoint*cheese*bun`. """Operator which is equivalent to the expression `bun.adjoint*cheese*bun`.
Parameters
----------
bun: LinearOperator
the bun part
cheese: EndomorphicOperator
the cheese part
""" """
def __init__(self, bun, cheese=None): def __init__(self, bun, cheese, op, _callingfrommake=False):
if not _callingfrommake:
raise NotImplementedError
super(SandwichOperator, self).__init__() super(SandwichOperator, self).__init__()
self._bun = bun self._bun = bun
self._cheese = cheese
self._op = op
@staticmethod
def make(bun, cheese=None):
"""Build a SandwichOperator (or something simpler if possible)
Parameters
----------
bun: LinearOperator
the bun part
cheese: EndomorphicOperator
the cheese part
"""
if cheese is None: if cheese is None:
self._cheese = ScalingOperator(1., bun.target) cheese = ScalingOperator(1., bun.target)
self._op = bun.adjoint*bun op = bun.adjoint*bun
else: else:
self._cheese = cheese op = bun.adjoint*cheese*bun
self._op = bun.adjoint*cheese*bun
# if our sandwich is diagonal, we can return immediately
if isinstance(op, (ScalingOperator, DiagonalOperator)):
return op
return SandwichOperator(bun, cheese, op, _callingfrommake=True)
@property @property
def domain(self): def domain(self):
...@@ -56,10 +69,6 @@ class SandwichOperator(EndomorphicOperator): ...@@ -56,10 +69,6 @@ class SandwichOperator(EndomorphicOperator):
return self._op.apply(x, mode) return self._op.apply(x, mode)
def draw_sample(self, from_inverse=False, dtype=np.float64): def draw_sample(self, from_inverse=False, dtype=np.float64):
# Drawing samples from diagonal operators is easy (inverse is possible)
if isinstance(self._op, (ScalingOperator, DiagonalOperator)):
return self._op.draw_sample(from_inverse, dtype)
# Inverse samples from general sandwiches is not possible # Inverse samples from general sandwiches is not possible
if from_inverse: if from_inverse:
raise NotImplementedError( raise NotImplementedError(
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment