Commit 5d68a885 authored by Martin Reinecke's avatar Martin Reinecke

Merge branch 'NIFTy_5' into new_los

parents f1caff21 aabb0370
......@@ -8,6 +8,7 @@ stages:
- build_docker
- test
- release
- demo_runs
build_docker_from_scratch:
only:
......@@ -58,3 +59,142 @@ pages:
- public
only:
- NIFTy_5
before_script:
- export MPLBACKEND="agg"
run_critical_filtering:
stage: demo_runs
script:
- ls
- python setup.py install --user -f
- python3 setup.py install --user -f
- python demos/critical_filtering.py
- python3 demos/critical_filtering.py
artifacts:
paths:
- '*.png'
run_nonlinear_critical_filter:
stage: demo_runs
script:
- python setup.py install --user -f
- python3 setup.py install --user -f
- python demos/nonlinear_critical_filter.py
- python3 demos/nonlinear_critical_filter.py
artifacts:
paths:
- '*.png'
run_nonlinear_wiener_filter:
stage: demo_runs
script:
- python setup.py install --user -f
- python3 setup.py install --user -f
- python demos/nonlinear_wiener_filter.py
- python3 demos/nonlinear_wiener_filter.py
only:
- run_demos
artifacts:
paths:
- '*.png'
run_poisson_demo:
stage: demo_runs
script:
- python setup.py install --user -f
- python3 setup.py install --user -f
- python demos/poisson_demo.py
- python3 demos/poisson_demo.py
artifacts:
paths:
- '*.png'
run_probing:
stage: demo_runs
script:
- python setup.py install --user -f
- python3 setup.py install --user -f
- python demos/probing.py
- python3 demos/probing.py
artifacts:
paths:
- '*.png'
run_sampling:
stage: demo_runs
script:
- python setup.py install --user -f
- python3 setup.py install --user -f
- python demos/sampling.py
- python3 demos/sampling.py
artifacts:
paths:
- '*.png'
run_tomography:
stage: demo_runs
script:
- python setup.py install --user -f
- python3 setup.py install --user -f
- python demos/tomography.py
- python3 demos/tomography.py
artifacts:
paths:
- '*.png'
run_wiener_filter_data_space_noiseless:
stage: demo_runs
script:
- python setup.py install --user -f
- python3 setup.py install --user -f
- python demos/wiener_filter_data_space_noiseless.py
- python3 demos/wiener_filter_data_space_noiseless.py
artifacts:
paths:
- '*.png'
run_wiener_filter_easy.py:
stage: demo_runs
script:
- python setup.py install --user -f
- python3 setup.py install --user -f
- python demos/wiener_filter_easy.py
- python3 demos/wiener_filter_easy.py
artifacts:
paths:
- '*.png'
run_wiener_filter_via_curvature.py:
stage: demo_runs
script:
- pip install --user numericalunits
- pip3 install --user numericalunits
- python setup.py install --user -f
- python3 setup.py install --user -f
- python demos/wiener_filter_via_curvature.py
- python3 demos/wiener_filter_via_curvature.py
artifacts:
paths:
- '*.png'
run_wiener_filter_via_hamiltonian.py:
stage: demo_runs
script:
- python setup.py install --user -f
- python3 setup.py install --user -f
- python demos/wiener_filter_via_hamiltonian.py
- python3 demos/wiener_filter_via_hamiltonian.py
artifacts:
paths:
- '*.png'
run_ipynb:
stage: demo_runs
script:
- python setup.py install --user -f
- python3 setup.py install --user -f
- jupyter nbconvert --execute --ExecutePreprocessor.timeout=None demos/Wiener_Filter.ipynb
artifacts:
paths:
- '*.png'
......@@ -24,6 +24,11 @@ RUN apt-get update && apt-get install -y \
&& pip install coverage \
&& rm -rf /var/lib/apt/lists/*
# Needed for demos to be running
RUN apt-get update && apt-get install -y python-matplotlib python3-matplotlib \
&& python3 -m pip install --upgrade pip && python3 -m pip install jupyter && python -m pip install --upgrade pip && python -m pip install jupyter \
&& rm -rf /var/lib/apt/lists/*
# Create user (openmpi does not like to be run as root)
RUN useradd -ms /bin/bash testinguser
USER testinguser
......
NIFTy - Numerical Information Field Theory
==========================================
[![build status](https://gitlab.mpcdf.mpg.de/ift/NIFTy/badges/NIFTy_4/build.svg)](https://gitlab.mpcdf.mpg.de/ift/NIFTy/commits/NIFTy_4)
[![coverage report](https://gitlab.mpcdf.mpg.de/ift/NIFTy/badges/NIFTy_4/coverage.svg)](https://gitlab.mpcdf.mpg.de/ift/NIFTy/commits/NIFTy_4)
[![build status](https://gitlab.mpcdf.mpg.de/ift/NIFTy/badges/NIFTy_5/build.svg)](https://gitlab.mpcdf.mpg.de/ift/NIFTy/commits/NIFTy_5)
[![coverage report](https://gitlab.mpcdf.mpg.de/ift/NIFTy/badges/NIFTy_5/coverage.svg)](https://gitlab.mpcdf.mpg.de/ift/NIFTy/commits/NIFTy_5)
**NIFTy** project homepage:
[http://ift.pages.mpcdf.de/NIFTy](http://ift.pages.mpcdf.de/NIFTy)
......@@ -62,7 +62,7 @@ distributions, the "apt" lines will need slight changes.
NIFTy5 and its mandatory dependencies can be installed via:
sudo apt-get install git libfftw3-dev python python-pip python-dev
pip install --user git+https://gitlab.mpcdf.mpg.de/ift/NIFTy.git@NIFTy_4
pip install --user git+https://gitlab.mpcdf.mpg.de/ift/NIFTy.git@NIFTy_5
(Note: If you encounter problems related to `pyFFTW`, make sure that you are
using a pip-installed `pyFFTW` package. Unfortunately, some distributions are
......
......@@ -132,6 +132,7 @@
"cell_type": "code",
"execution_count": null,
"metadata": {
"scrolled": true,
"slideshow": {
"slide_type": "-"
}
......@@ -169,10 +170,9 @@
"def Curvature(R, N, Sh):\n",
" IC = ift.GradientNormController(iteration_limit=50000,\n",
" tol_abs_gradnorm=0.1)\n",
" inverter = ift.ConjugateGradient(controller=IC)\n",
" # WienerFilterCurvature is (R.adjoint*N.inverse*R + Sh.inverse) plus some handy\n",
" # helper methods.\n",
" return ift.library.WienerFilterCurvature(R,N,Sh,inverter)"
" return ift.library.WienerFilterCurvature(R,N,Sh,iteration_controller=IC,iteration_controller_sampling=IC)"
]
},
{
......@@ -223,7 +223,9 @@
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"metadata": {
"scrolled": true
},
"outputs": [],
"source": [
"s_space = ift.RGSpace(N_pixels)\n",
......@@ -708,7 +710,7 @@
"\n",
"https://gitlab.mpcdf.mpg.de/ift/NIFTy\n",
"\n",
"NIFTy v4 **more or less stable!**"
"NIFTy v5 **more or less stable!**"
]
}
],
......@@ -729,7 +731,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython2",
"version": "2.7.12"
"version": "2.7.15"
}
},
"nbformat": 4,
......
......@@ -85,15 +85,13 @@ if __name__ == "__main__":
LS = ift.LineSearchStrongWolfe(c2=0.02)
minimizer = ift.RelaxedNewton(IC1, line_searcher=LS)
ICI = ift.GradientNormController(iteration_limit=500,
tol_abs_gradnorm=1e-3)
inverter = ift.ConjugateGradient(controller=ICI)
IC = ift.GradientNormController(iteration_limit=500,
tol_abs_gradnorm=1e-3)
for i in range(20):
power0 = Distributor(ift.exp(0.5*t0))
map0_energy = ift.library.NonlinearWienerFilterEnergy(
m0, d, MeasurementOperator, nonlinearity, HT, power0, N, S,
inverter=inverter)
m0, d, MeasurementOperator, nonlinearity, HT, power0, N, S, IC)
# Minimization with chosen minimizer
map0_energy, convergence = minimizer(map0_energy)
......@@ -106,7 +104,8 @@ if __name__ == "__main__":
power0_energy = ift.library.NonlinearPowerEnergy(
position=t0, d=d, N=N, xi=m0, D=D0, ht=HT,
Instrument=MeasurementOperator, nonlinearity=nonlinearity,
Distributor=Distributor, sigma=1., samples=2, inverter=inverter)
Distributor=Distributor, sigma=1., samples=2,
iteration_controller=IC)
power0_energy = minimizer(power0_energy)[0]
......
......@@ -78,15 +78,13 @@ if __name__ == "__main__":
LS = ift.LineSearchStrongWolfe(c2=0.02)
minimizer = ift.RelaxedNewton(IC1, line_searcher=LS)
ICI = ift.GradientNormController(iteration_limit=500,
tol_abs_gradnorm=1e-3)
inverter = ift.ConjugateGradient(controller=ICI)
IC = ift.GradientNormController(iteration_limit=500,
tol_abs_gradnorm=1e-3)
for i in range(20):
power0 = Distributor(ift.exp(0.5*t0))
map0_energy = ift.library.NonlinearWienerFilterEnergy(
m0, d, MeasurementOperator, nonlinearity, HT, power0, N, S,
inverter=inverter)
m0, d, MeasurementOperator, nonlinearity, HT, power0, N, S, IC)
# Minimization with chosen minimizer
map0_energy, convergence = minimizer(map0_energy)
......@@ -99,7 +97,7 @@ if __name__ == "__main__":
power0_energy = ift.library.NonlinearPowerEnergy(
position=t0, d=d, N=N, xi=m0, D=D0, ht=HT,
Instrument=MeasurementOperator, nonlinearity=nonlinearity,
Distributor=Distributor, sigma=1., samples=2, inverter=inverter)
Distributor=Distributor, sigma=1., samples=2, iteration_controller=IC)
power0_energy = minimizer(power0_energy)[0]
......
......@@ -52,14 +52,13 @@ if __name__ == "__main__":
LS = ift.LineSearchStrongWolfe(c2=0.02)
minimizer = ift.RelaxedNewton(IC1, line_searcher=LS)
ICI = ift.GradientNormController(iteration_limit=2000,
tol_abs_gradnorm=1e-3)
inverter = ift.ConjugateGradient(controller=ICI)
IC = ift.GradientNormController(iteration_limit=2000,
tol_abs_gradnorm=1e-3)
# initial guess
m = ift.full(h_space, 1e-7)
map_energy = ift.library.NonlinearWienerFilterEnergy(
m, d, R, nonlinearity, HT, power, N, S, inverter=inverter)
m, d, R, nonlinearity, HT, power, N, S, IC)
# Minimization with chosen minimizer
map_energy, convergence = minimizer(map_energy)
......
......@@ -76,10 +76,9 @@ if __name__ == "__main__":
ctrl = ift.GradientNormController(name="inverter", tol_abs_gradnorm=0.1)
sampling_ctrl = ift.GradientNormController(name="sampling",
tol_abs_gradnorm=1e2)
inverter = ift.ConjugateGradient(controller=ctrl)
sampling_inverter = ift.ConjugateGradient(controller=sampling_ctrl)
wiener_curvature = ift.library.WienerFilterCurvature(
S=S, N=N, R=R, inverter=inverter, sampling_inverter=sampling_inverter)
S=S, N=N, R=R, iteration_controller=ctrl,
iteration_controller_sampling=sampling_ctrl)
m_k = wiener_curvature.inverse_times(j)
m = ht(m_k)
......
......@@ -50,10 +50,9 @@ if __name__ == "__main__":
ctrl = ift.GradientNormController(name="inverter", tol_abs_gradnorm=1e-2)
sampling_ctrl = ift.GradientNormController(name="sampling",
tol_abs_gradnorm=2e1)
inverter = ift.ConjugateGradient(controller=ctrl)
sampling_inverter = ift.ConjugateGradient(controller=sampling_ctrl)
wiener_curvature = ift.library.WienerFilterCurvature(
S=S, N=N, R=R, inverter=inverter, sampling_inverter=sampling_inverter)
S=S, N=N, R=R, iteration_controller=ctrl,
iteration_controller_sampling=sampling_ctrl)
m_k = wiener_curvature.inverse_times(j)
m = ht(m_k)
......
......@@ -81,9 +81,8 @@ if __name__ == "__main__":
j = R.adjoint_times(N.inverse_times(d))
IC = ift.GradientNormController(name="inverter", iteration_limit=500,
tol_abs_gradnorm=1e-3)
inverter = ift.ConjugateGradient(controller=IC)
D = (ift.SandwichOperator.make(R, N.inverse) + Phi_h.inverse).inverse
D = ift.InversionEnabler(D, inverter, approximation=Phi_h)
D = ift.InversionEnabler(D, IC, approximation=Phi_h)
m = HT(D(j))
# Uncertainty
......@@ -116,8 +115,7 @@ if __name__ == "__main__":
# initial guess
psi0 = ift.full(h_domain, 1e-7)
energy = ift.library.PoissonEnergy(psi0, data, R0, nonlin, HT, Phi_h,
inverter)
energy = ift.library.PoissonEnergy(psi0, data, R0, nonlin, HT, Phi_h, IC)
IC1 = ift.GradientNormController(name="IC1", iteration_limit=200,
tol_abs_gradnorm=1e-4)
minimizer = ift.RelaxedNewton(IC1)
......
......@@ -39,17 +39,17 @@ N_iter = 100
tol = 1e-3
IC = ift.GradientNormController(tol_abs_gradnorm=tol, iteration_limit=N_iter)
inverter = ift.ConjugateGradient(IC)
curv = ift.library.WienerFilterCurvature(S=S, N=N, R=R_p, inverter=inverter,
sampling_inverter=inverter)
curv = ift.library.WienerFilterCurvature(S=S, N=N, R=R_p,
iteration_controller=IC,
iteration_controller_sampling=IC)
m_xi = curv.inverse_times(j)
samps_long = [curv.draw_sample(from_inverse=True) for i in range(N_samps)]
tol = 1e2
IC = ift.GradientNormController(tol_abs_gradnorm=tol, iteration_limit=N_iter)
inverter = ift.ConjugateGradient(IC)
curv = ift.library.WienerFilterCurvature(S=S, N=N, R=R_p, inverter=inverter,
sampling_inverter=inverter)
curv = ift.library.WienerFilterCurvature(S=S, N=N, R=R_p,
iteration_controller=IC,
iteration_controller_sampling=IC)
samps_short = [curv.draw_sample(from_inverse=True) for i in range(N_samps)]
# Compute mean
......
......@@ -36,8 +36,8 @@ if __name__ == "__main__":
j = Rh.adjoint_times(N.inverse_times(d))
ctrl = ift.GradientNormController(name="Iter", tol_abs_gradnorm=1e-10,
iteration_limit=300)
inverter = ift.ConjugateGradient(controller=ctrl)
Di = ift.library.WienerFilterCurvature(S=S, R=Rh, N=N, inverter=inverter)
Di = ift.library.WienerFilterCurvature(S=S, R=Rh, N=N,
iteration_controller=ctrl)
mh = Di.inverse_times(j)
m = ht(mh)
......
......@@ -98,10 +98,9 @@ if __name__ == "__main__":
IC = ift.GradientNormController(name="inverter", iteration_limit=1000,
tol_abs_gradnorm=0.0001)
inverter = ift.ConjugateGradient(controller=IC)
# setting up measurement precision matrix M
M = (ift.SandwichOperator.make(R.adjoint, Sh) + N)
M = ift.InversionEnabler(M, inverter)
M = ift.InversionEnabler(M, IC)
m = Sh(R.adjoint(M.inverse_times(d)))
# Plotting
......
......@@ -52,9 +52,8 @@ if __name__ == "__main__":
j = R.adjoint_times(N.inverse_times(d))
IC = ift.GradientNormController(name="inverter", iteration_limit=500,
tol_abs_gradnorm=0.1)
inverter = ift.ConjugateGradient(controller=IC)
D = (ift.SandwichOperator.make(R, N.inverse) + Sh.inverse).inverse
D = ift.InversionEnabler(D, inverter, approximation=Sh)
D = ift.InversionEnabler(D, IC, approximation=Sh)
m = D(j)
# Plotting
......
......@@ -78,9 +78,8 @@ if __name__ == "__main__":
j = R.adjoint_times(N.inverse_times(data))
ctrl = ift.GradientNormController(
name="inverter", tol_abs_gradnorm=1e-5/(nu.K*(nu.m**dimensionality)))
inverter = ift.ConjugateGradient(controller=ctrl)
wiener_curvature = ift.library.WienerFilterCurvature(
S=S, N=N, R=R, inverter=inverter)
wiener_curvature = ift.library.WienerFilterCurvature(S=S, N=N, R=R,
iteration_controller=ctrl)
m = wiener_curvature.inverse_times(j)
m_s = HT(m)
......
......@@ -47,15 +47,14 @@ if __name__ == "__main__":
# Choose minimization strategy
ctrl = ift.GradientNormController(name="inverter", tol_abs_gradnorm=0.1)
inverter = ift.ConjugateGradient(controller=ctrl)
controller = ift.GradientNormController(name="min", tol_abs_gradnorm=0.1)
minimizer = ift.RelaxedNewton(controller=controller)
m0 = ift.full(h_space, 0.)
# Initialize Wiener filter energy
energy = ift.library.WienerFilterEnergy(position=m0, d=d, R=R, N=N, S=S,
inverter=inverter,
sampling_inverter=inverter)
iteration_controller=ctrl,
iteration_controller_sampling=ctrl)
energy, convergence = minimizer(energy)
m = energy.position
......
......@@ -68,8 +68,6 @@ class DomainTuple(object):
"""
if isinstance(domain, DomainTuple):
return domain
if isinstance(domain, dict):
return domain
domain = DomainTuple._parse_domain(domain)
obj = DomainTuple._tupleCache.get(domain)
if obj is not None:
......@@ -126,15 +124,28 @@ class DomainTuple(object):
return self._dom.__hash__()
def __eq__(self, x):
if not isinstance(x, DomainTuple):
x = DomainTuple.make(x)
if self is x:
return True
return self._dom == x._dom
x = DomainTuple.make(x)
return self is x
def __ne__(self, x):
return not self.__eq__(x)
def compatibleTo(self, x):
return self.__eq__(x)
def subsetOf(self, x):
return self.__eq__(x)
def unitedWith(self, x):
if self is x:
return self
x = DomainTuple.make(x)
if self is not x:
raise ValueError("domain mismatch")
return self
def __str__(self):
res = "DomainTuple, len: " + str(len(self))
for i in self:
......
......@@ -44,7 +44,7 @@ def _get_acceptable_energy(E):
return E2
def check_value_gradient_consistency(E, tol=1e-6, ntries=100):
def check_value_gradient_consistency(E, tol=1e-8, ntries=100):
for _ in range(ntries):
E2 = _get_acceptable_energy(E)
val = E.value
......@@ -54,7 +54,8 @@ def check_value_gradient_consistency(E, tol=1e-6, ntries=100):
for i in range(50):
Emid = E.at(E.position + 0.5*dir)
dirder = Emid.gradient.vdot(dir)/dirnorm
if abs((E2.value-val)/dirnorm-dirder) < tol:
xtol = tol*Emid.gradient_norm
if abs((E2.value-val)/dirnorm - dirder) < xtol:
break
dir *= 0.5
dirnorm *= 0.5
......@@ -64,7 +65,7 @@ def check_value_gradient_consistency(E, tol=1e-6, ntries=100):
# E = Enext
def check_value_gradient_curvature_consistency(E, tol=1e-6, ntries=100):
def check_value_gradient_curvature_consistency(E, tol=1e-8, ntries=100):
for _ in range(ntries):
E2 = _get_acceptable_energy(E)
val = E.value
......@@ -75,8 +76,9 @@ def check_value_gradient_curvature_consistency(E, tol=1e-6, ntries=100):
Emid = E.at(E.position + 0.5*dir)
dirder = Emid.gradient.vdot(dir)/dirnorm
dgrad = Emid.curvature(dir)/dirnorm
if abs((E2.value-val)/dirnorm-dirder) < tol and \
(abs((E2.gradient-E.gradient)/dirnorm-dgrad) < tol).all():
xtol = tol*Emid.gradient_norm
if abs((E2.value-val)/dirnorm - dirder) < xtol and \
(abs((E2.gradient-E.gradient)/dirnorm-dgrad) < xtol).all():
break
dir *= 0.5
dirnorm *= 0.5
......
......@@ -720,7 +720,9 @@ class Field(object):
self._domain.__str__() + \
"\n- val = " + repr(self.val)
def equivalent(self, other):
def isEquivalentTo(self, other):
"""Determines (as quickly as possible) whether `self`'s content is
identical to `other`'s content."""
if self is other:
return True
if not isinstance(other, Field):
......@@ -729,6 +731,11 @@ class Field(object):
return False
return (self._val == other._val).all()
def isSubsetOf(self, other):
"""Identical to `Field.isEquivalentTo()`. This method is provided for
easier interoperability with `MultiField`."""
return self.isEquivalentTo(other)
for op in ["__add__", "__radd__", "__iadd__",
"__sub__", "__rsub__", "__isub__",
......@@ -740,6 +747,7 @@ for op in ["__add__", "__radd__", "__iadd__",
"__lt__", "__le__", "__gt__", "__ge__", "__eq__", "__ne__"]:
def func(op):
def func2(self, other):
global COUNTER
# if other is a field, make sure that the domains match
if isinstance(other, Field):
if other._domain != self._domain:
......
......@@ -63,7 +63,7 @@ class NonlinearPowerEnergy(Energy):
# MR FIXME: docstring incomplete and outdated
def __init__(self, position, d, N, xi, D, ht, Instrument, nonlinearity,
Distributor, sigma=0., samples=3, xi_sample_list=None,
inverter=None):
iteration_controller=None):
super(NonlinearPowerEnergy, self).__init__(position)
self.xi = xi
self.D = D
......@@ -83,7 +83,7 @@ class NonlinearPowerEnergy(Energy):
xi_sample_list = [D.draw_sample(from_inverse=True) + xi
for _ in range(samples)]
self.xi_sample_list = xi_sample_list
self.inverter = inverter
self._ic = iteration_controller
A = Distributor(exp(.5 * position))
......@@ -118,7 +118,7 @@ class NonlinearPowerEnergy(Energy):
self.Distributor, sigma=self.sigma,
samples=len(self.xi_sample_list),
xi_sample_list=self.xi_sample_list,
inverter=self.inverter)
iteration_controller=self._ic)
@property
def value(self):
......@@ -139,4 +139,4 @@ class NonlinearPowerEnergy(Energy):
op = LinearizedResponse.adjoint*self.N.inverse*LinearizedResponse
result = op if result is None else result + op
result = result*(1./len(self.xi_sample_list)) + self.T
return InversionEnabler(result, self.inverter)
return InversionEnabler(result, self._ic)
......@@ -24,8 +24,8 @@ from ..sugar import makeOp
class NonlinearWienerFilterEnergy(Energy):
def __init__(self, position, d, Instrument, nonlinearity, ht, power, N, S,
inverter=None,
sampling_inverter=None):
iteration_controller=None,
iteration_controller_sampling=None):
super(NonlinearWienerFilterEnergy, self).__init__(position=position)
self.d = d.lock()
self.Instrument = Instrument
......@@ -37,10 +37,10 @@ class NonlinearWienerFilterEnergy(Energy):
residual = d - Instrument(nonlinearity(m))
self.N = N
self.S = S
self.inverter = inverter
if sampling_inverter is None:
sampling_inverter = inverter
self.sampling_inverter = sampling_inverter
self._ic = iteration_controller
if iteration_controller_sampling is None:
iteration_controller_sampling = self._ic
self._ic_samp = iteration_controller_sampling
t1 = S.inverse_times(position)
t2 = N.inverse_times(residual)
self._value = 0.5 * (position.vdot(t1) + residual.vdot(t2)).real
......@@ -51,7 +51,7 @@ class NonlinearWienerFilterEnergy(Energy):
def at(self, position):
return self.__class__(position, self.d, self.Instrument,
self.nonlinearity, self.ht, self.power, self.N,
self.S, self.inverter)
self.S, self._ic, self._ic_samp)
@property
def value(self):
......@@ -64,5 +64,5 @@ class NonlinearWienerFilterEnergy(Energy):
@property
@memo
def curvature(self):
return WienerFilterCurvature(self.R, self.N, self.S, self.inverter,
self.sampling_inverter)
return WienerFilterCurvature(self.R, self.N, self.S, self._ic,
self._ic_samp)
......@@ -25,9 +25,9 @@ from ..sugar import log
class PoissonEnergy(Energy):
def __init__(self, position, d, Instrument, nonlinearity, ht, Phi_h,
inverter=None):
iteration_controller=None):
super(PoissonEnergy, self).__init__(position=position)
self._inverter = inverter
self._ic = iteration_controller
self._d = d
self._Instrument = Instrument
self._nonlinearity = nonlinearity
......@@ -51,7 +51,7 @@ class PoissonEnergy(Energy):
def at(self, position):
return self.__class__(position, self._d, self._Instrument,
self._nonlinearity, self._ht, self._Phi_h,
self._inverter)
self._ic)
@property
def value(self):
......@@ -63,5 +63,4 @@ class PoissonEnergy(Energy):
@property
def curvature(self):
return InversionEnabler(self._curv, self._inverter,
approximation=self._Phi_h.inverse)
return InversionEnabler(self._curv, self._ic, self._Phi_h.inverse)
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2018 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes.
from numpy import inf, isnan
from ..minimization import Energy
......@@ -6,15 +24,15 @@ from ..sugar import log, makeOp
class PoissonLogLikelihood(Energy):
def __init__(self, position, lamb, d):
def __init__(self, lamb, d):
"""
s: Sky model object
value = 0.5 * s.vdot(s), i.e. a log-Gauss distribution with unit
covariance
"""
super(PoissonLogLikelihood, self).__init__(position)
self._lamb = lamb.at(position)
super(PoissonLogLikelihood, self).__init__(lamb.position)
self._lamb = lamb
self._d = d
lamb_val = self._lamb.value
......@@ -29,7 +47,7 @@ class PoissonLogLikelihood(Energy):
self._curvature = SandwichOperator.make(self._lamb.gradient, metric)