Commit 7b4b15b0 authored by Martin Reinecke's avatar Martin Reinecke

alternative solution

parent ff54ff45
Pipeline #29551 passed with stages
in 3 minutes and 53 seconds
......@@ -31,7 +31,7 @@ d = R(s_x) + n
R_p = R * FFT * A
j = R_p.adjoint(N.inverse(d))
D_inv = ift.SandwichOperator(R_p, N.inverse) + S.inverse
D_inv = ift.SandwichOperator.make(R_p, N.inverse) + S.inverse
N_samps = 200
......
......@@ -80,12 +80,12 @@ if __name__ == "__main__":
IC = ift.GradientNormController(name="inverter", iteration_limit=500,
tol_abs_gradnorm=1e-3)
inverter = ift.ConjugateGradient(controller=IC)
D = (ift.SandwichOperator(R, N.inverse) + Phi_h.inverse).inverse
D = (ift.SandwichOperator.make(R, N.inverse) + Phi_h.inverse).inverse
D = ift.InversionEnabler(D, inverter, approximation=Phi_h)
m = HT(D(j))
# Uncertainty
D = ift.SandwichOperator(aHT, D) # real space propagator
D = ift.SandwichOperator.make(aHT, D) # real space propagator
Dhat = ift.probe_with_posterior_samples(D.inverse, None,
nprobes=nprobes)[1]
sig = ift.sqrt(Dhat)
......
......@@ -51,7 +51,7 @@ if __name__ == "__main__":
IC = ift.GradientNormController(name="inverter", iteration_limit=500,
tol_abs_gradnorm=0.1)
inverter = ift.ConjugateGradient(controller=IC)
D = (ift.SandwichOperator(R, N.inverse) + Sh.inverse).inverse
D = (ift.SandwichOperator.make(R, N.inverse) + Sh.inverse).inverse
D = ift.InversionEnabler(D, inverter, approximation=Sh)
m = D(j)
......
......@@ -46,7 +46,7 @@ class PoissonEnergy(Energy):
R1 = Instrument*Rho*ht
self._grad = (phipos + R1.adjoint_times((lam-d)/(lam+eps))).lock()
self._curv = Phi_h.inverse + SandwichOperator(R1, W)
self._curv = Phi_h.inverse + SandwichOperator.make(R1, W)
def at(self, position):
return self.__class__(position, self._d, self._Instrument,
......
......@@ -39,5 +39,5 @@ def WienerFilterCurvature(R, N, S, inverter):
inverter : Minimizer
The minimizer to use during numerical inversion
"""
op = SandwichOperator(R, N.inverse) + S.inverse
op = SandwichOperator.make(R, N.inverse) + S.inverse
return InversionEnabler(op, inverter, S.inverse)
......@@ -25,24 +25,37 @@ from .scaling_operator import ScalingOperator
class SandwichOperator(EndomorphicOperator):
"""Operator which is equivalent to the expression `bun.adjoint*cheese*bun`.
Parameters
----------
bun: LinearOperator
the bun part
cheese: EndomorphicOperator
the cheese part
"""
def __init__(self, bun, cheese=None):
def __init__(self, bun, cheese, op, _callingfrommake=False):
if not _callingfrommake:
raise NotImplementedError
super(SandwichOperator, self).__init__()
self._bun = bun
self._cheese = cheese
self._op = op
@staticmethod
def make(bun, cheese=None):
"""Build a SandwichOperator (or something simpler if possible)
Parameters
----------
bun: LinearOperator
the bun part
cheese: EndomorphicOperator
the cheese part
"""
if cheese is None:
self._cheese = ScalingOperator(1., bun.target)
self._op = bun.adjoint*bun
cheese = ScalingOperator(1., bun.target)
op = bun.adjoint*bun
else:
self._cheese = cheese
self._op = bun.adjoint*cheese*bun
op = bun.adjoint*cheese*bun
# if our sandwich is diagonal, we can return immediately
if isinstance(op, (ScalingOperator, DiagonalOperator)):
return op
return SandwichOperator(bun, cheese, op, _callingfrommake=True)
@property
def domain(self):
......@@ -56,10 +69,6 @@ class SandwichOperator(EndomorphicOperator):
return self._op.apply(x, mode)
def draw_sample(self, from_inverse=False, dtype=np.float64):
# Drawing samples from diagonal operators is easy (inverse is possible)
if isinstance(self._op, (ScalingOperator, DiagonalOperator)):
return self._op.draw_sample(from_inverse, dtype)
# Inverse samples from general sandwiches is not possible
if from_inverse:
raise NotImplementedError(
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment