Commit 1e157696 authored by Martin Reinecke's avatar Martin Reinecke

operator chaining

parent e288f5be
......@@ -61,7 +61,7 @@ if __name__ == '__main__':
# Generate mock data
d_space = R.target[0]
p = lambda inp: R(sky(inp))
p = R.chain(sky)
mock_position = ift.from_random('normal', harmonic_space)
pp = p(mock_position)
data = np.random.binomial(1, pp.to_global_data().astype(np.float64))
......
......@@ -79,7 +79,7 @@ if __name__ == '__main__':
# Generate mock data
d_space = R.target[0]
lamb = lambda inp: R(sky(inp))
lamb = R.chain(sky)
mock_position = ift.from_random('normal', domain)
#ift.extra.check_value_gradient_consistency2(lamb, mock_position)
#testl = GaussianEnergy2(None, M)
......
......@@ -53,7 +53,7 @@ if __name__ == '__main__':
R = ift.LOSResponse(position_space, starts=LOS_starts,
ends=LOS_ends)
# build signal response model and model likelihood
signal_response = lambda inp: R(signal(inp))
signal_response = R.chain(signal)
# specify noise
data_space = R.target
noise = .001
......@@ -65,7 +65,7 @@ if __name__ == '__main__':
data = signal_response(MOCK_POSITION) + N.draw_sample()
# set up model likelihood
likelihood = lambda inp: ift.GaussianEnergy(mean=data, covariance=N)(signal_response(inp))
likelihood = ift.GaussianEnergy(mean=data, covariance=N).chain(signal_response)
# set up minimization and inversion schemes
ic_cg = ift.GradientNormController(iteration_limit=10)
......
......@@ -97,7 +97,7 @@ d = ift.from_global_data(d_space, y)
N = ift.DiagonalOperator(ift.from_global_data(d_space, var))
IC = ift.GradientNormController(tol_abs_gradnorm=1e-8)
likelihood = lambda inp: ift.GaussianEnergy(d, N)(R(inp))
likelihood = ift.GaussianEnergy(d, N).chain(R)
H = ift.Hamiltonian(likelihood, IC)
H = ift.EnergyAdapter(params, H)
H = H.make_invertible(IC)
......
......@@ -16,8 +16,6 @@ from .domains.log_rg_space import LogRGSpace
from .domain_tuple import DomainTuple
from .field import Field
from .nonlinearities import Exponential, Linear, PositiveTanh, Tanh
from .operators.central_zero_padder import CentralZeroPadder
from .operators.diagonal_operator import DiagonalOperator
from .operators.dof_distributor import DOFDistributor
......
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2018 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes.
from __future__ import absolute_import, division, print_function
from .compat import *
from .sugar import exp, full, tanh
class Linear(object):
def __call__(self, x):
return x
def derivative(self, x):
return full(x.domain, 1.)
def hessian(self, x):
return full(x.domain, 0.)
class Exponential(object):
def __call__(self, x):
return exp(x)
def derivative(self, x):
return exp(x)
def hessian(self, x):
return exp(x)
class Tanh(object):
def __call__(self, x):
return tanh(x)
def derivative(self, x):
return (1. - tanh(x)**2)
def hessian(self, x):
return - 2. * tanh(x) * (1. - tanh(x)**2)
class PositiveTanh(object):
def __call__(self, x):
return 0.5 * tanh(x) + 0.5
def derivative(self, x):
return 0.5 * (1. - tanh(x)**2)
def hessian(self, x):
return - tanh(x) * (1. - tanh(x)**2)
......@@ -9,6 +9,13 @@ class Operator(NiftyMetaBase()):
domain, and can also provide the Jacobian.
"""
def chain(self, x):
if not callable(x):
raise TypeError("callable needed")
ops1 = self._ops if isinstance(self, OpChain) else (self,)
ops2 = x._ops if isinstance(x, OpChain) else (x,)
return OpChain(ops1+ops2)
def __call__(self, x):
"""Returns transformed x
......@@ -23,3 +30,13 @@ class Operator(NiftyMetaBase()):
output
"""
raise NotImplementedError
class OpChain(Operator):
def __init__(self, ops):
self._ops = tuple(ops)
def __call__(self, x):
for op in reversed(self._ops):
x = op(x)
return x
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment