Commit 250375bc authored by Martin Reinecke's avatar Martin Reinecke

move curvatures and linearized responses into their respective energy classes

parent 0649322d
from .wiener_filter_energy import WienerFilterEnergy
from .wiener_filter_curvature import WienerFilterCurvature
from .noise_energy import NoiseEnergy
from .nonlinear_power_curvature import NonlinearPowerCurvature
from .nonlinear_power_energy import NonlinearPowerEnergy
from .nonlinear_wiener_filter_energy import NonlinearWienerFilterEnergy
from .nonlinearities import Exponential, Linear, Tanh, PositiveTanh
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2018 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes.
from ..operators.inversion_enabler import InversionEnabler
from .response_operators import LinearizedPowerResponse
def NonlinearPowerCurvature(tau, ht, Instrument, nonlinearity, Distributor, N,
T, xi_sample_list, inverter):
result = None
for xi_sample in xi_sample_list:
LinearizedResponse = LinearizedPowerResponse(
Instrument, nonlinearity, ht, Distributor, tau, xi_sample)
op = LinearizedResponse.adjoint*N.inverse*LinearizedResponse
result = op if result is None else result + op
result = result*(1./len(xi_sample_list)) + T
return InversionEnabler(result, inverter)
......@@ -19,9 +19,16 @@
from .. import exp
from ..minimization.energy import Energy
from ..operators.smoothness_operator import SmoothnessOperator
from ..operators.inversion_enabler import InversionEnabler
from ..utilities import memo
from .nonlinear_power_curvature import NonlinearPowerCurvature
from .response_operators import LinearizedPowerResponse
def _LinearizedPowerResponse(Instrument, nonlinearity, ht, Distributor, tau,
xi):
power = exp(0.5*tau)
position = ht(Distributor(power)*xi)
linearization = nonlinearity.derivative(position)
return 0.5*Instrument*linearization*ht*xi*Distributor*power
class NonlinearPowerEnergy(Energy):
......@@ -80,8 +87,8 @@ class NonlinearPowerEnergy(Energy):
self._gradient = None
for xi_sample in self.xi_sample_list:
map_s = ht(A*xi_sample)
LinR = LinearizedPowerResponse(Instrument, nonlinearity, ht,
Distributor, position, xi_sample)
LinR = _LinearizedPowerResponse(Instrument, nonlinearity, ht,
Distributor, position, xi_sample)
residual = d - Instrument(nonlinearity(map_s))
tmp = N.inverse_times(residual)
......@@ -121,7 +128,12 @@ class NonlinearPowerEnergy(Energy):
@property
@memo
def curvature(self):
return NonlinearPowerCurvature(
self.position, self.ht, self.Instrument, self.nonlinearity,
self.Distributor, self.N, self.T, self.xi_sample_list,
self.inverter)
result = None
for xi_sample in self.xi_sample_list:
LinearizedResponse = _LinearizedPowerResponse(
self.Instrument, self.nonlinearity, self.ht, self.Distributor,
self.position, xi_sample)
op = LinearizedResponse.adjoint*self.N.inverse*LinearizedResponse
result = op if result is None else result + op
result = result*(1./len(self.xi_sample_list)) + self.T
return InversionEnabler(result, self.inverter)
......@@ -19,7 +19,6 @@
from .wiener_filter_curvature import WienerFilterCurvature
from ..utilities import memo
from ..minimization.energy import Energy
from .response_operators import LinearizedSignalResponse
class NonlinearWienerFilterEnergy(Energy):
......@@ -40,8 +39,7 @@ class NonlinearWienerFilterEnergy(Energy):
t1 = S.inverse_times(position)
t2 = N.inverse_times(residual)
self._value = 0.5 * (position.vdot(t1) + residual.vdot(t2)).real
self.R = LinearizedSignalResponse(Instrument, nonlinearity, ht, power,
m)
self.R = Instrument * nonlinearity.derivative(m) * ht * power
self._gradient = (t1 - self.R.adjoint_times(t2)).lock()
def at(self, position):
......
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2018 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes.
from ..field import exp
def LinearizedSignalResponse(Instrument, nonlinearity, ht, power, m):
return Instrument * nonlinearity.derivative(m) * ht * power
def LinearizedPowerResponse(Instrument, nonlinearity, ht, Distributor, tau,
xi):
power = exp(0.5*tau)
position = ht(Distributor(power)*xi)
linearization = nonlinearity.derivative(position)
return 0.5*Instrument*linearization*ht*xi*Distributor*power
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment