Commit 6a602768 authored by Martin Reinecke's avatar Martin Reinecke

reordering

parent af6b35ee
......@@ -19,18 +19,6 @@
import nifty5 as ift
import numpy as np
def myexp(lin):
if isinstance(lin, ift.Linearization):
tmp = ift.exp(lin.val)
return ift.Linearization(tmp, ift.makeOp(tmp)*lin.jac)
return ift.exp(lin)
def mylog(lin):
if isinstance(lin, ift.Linearization):
tmp = ift.log(lin.val)
return ift.Linearization(tmp, ift.makeOp(1./lin.val)*lin.jac)
return ift.log(lin)
class GaussianEnergy2(ift.Operator):
def __init__(self, mean=None, covariance=None):
super(GaussianEnergy2, self).__init__()
......@@ -42,7 +30,7 @@ class GaussianEnergy2(ift.Operator):
icovres = residual if self._icov is None else self._icov(residual)
res = .5 * (residual*icovres).sum()
metric = ift.SandwichOperator.make(x.jac, self._icov)
return ift.Linearization(res.val, res.jac, metric)
return res.add_metric(metric)
class PoissonianEnergy2(ift.Operator):
def __init__(self, op, d):
......@@ -52,9 +40,9 @@ class PoissonianEnergy2(ift.Operator):
def __call__(self, x):
x = self._op(x)
res = (x - self._d*mylog(x)).sum()
res = (x - self._d*x.log()).sum()
metric = ift.SandwichOperator.make(x.jac, ift.makeOp(1./x.val))
return ift.Linearization(res.val, res.jac, metric)
return res.add_metric(metric)
class MyHamiltonian(ift.Operator):
def __init__(self, lh):
......@@ -138,7 +126,7 @@ if __name__ == '__main__':
A = pd(a)
# Set up a sky model
sky = lambda inp: myexp(HT(inp*A))
sky = lambda inp: (HT(A*inp)).exp()
M = ift.DiagonalOperator(exposure)
GR = ift.GeometryRemover(position_space)
......
......@@ -104,7 +104,8 @@ from .multi.block_diagonal_operator import BlockDiagonalOperator
from .energies.kl import SampledKullbachLeiblerDivergence
from .energies.hamiltonian import Hamiltonian
from.operator import Linearization, Operator
from .operator import Operator
from .linearization import Linearization
# We deliberately don't set __all__ here, because we don't want people to do a
# "from nifty5 import *"; that would swamp the global namespace.
......@@ -651,3 +651,11 @@ for op in ["__iadd__", "__isub__", "__imul__", "__idiv__",
"In-place operations are deliberately not supported")
return func2
setattr(Field, op, func(op))
for f in ["sqrt", "exp", "log", "tanh"]:
def func(f):
def func2(self):
fu = getattr(dobj, f)
return Field(domain=self._domain, val=fu(self.val))
return func2
setattr(Field, f, func(f))
from __future__ import absolute_import, division, print_function
import numpy as np
from .compat import *
from .field import Field
from .multi.multi_field import MultiField
from .sugar import makeOp
class Linearization(object):
def __init__(self, val, jac, metric=None):
self._val = val
self._jac = jac
self._metric = metric
@property
def domain(self):
return self._jac.domain
@property
def target(self):
return self._jac.target
@property
def val(self):
return self._val
@property
def jac(self):
return self._jac
@property
def metric(self):
return self._metric
def __neg__(self):
return Linearization(-self._val, self._jac*(-1),
None if self._metric is None else self._metric*(-1))
def __add__(self, other):
if isinstance(other, Linearization):
from .operators.relaxed_sum_operator import RelaxedSumOperator
met = None
if self._metric is not None and other._metric is not None:
met = RelaxedSumOperator((self._metric, other._metric))
return Linearization(
self._val.unite(other._val),
RelaxedSumOperator((self._jac, other._jac)), met)
if isinstance(other, (int, float, complex, Field, MultiField)):
return Linearization(self._val+other, self._jac, self._metric)
def __radd__(self, other):
return self.__add__(other)
def __sub__(self, other):
return self.__add__(-other)
def __rsub__(self, other):
return (-self).__add__(other)
def __mul__(self, other):
from .sugar import makeOp
if isinstance(other, Linearization):
d1 = makeOp(self._val)
d2 = makeOp(other._val)
return Linearization(self._val*other._val,
d2*self._jac + d1*other._jac)
if isinstance(other, (int, float, complex)):
# if other == 0:
# return ...
return Linearization(self._val*other, self._jac*other)
if isinstance(other, (Field, MultiField)):
d2 = makeOp(other)
return Linearization(self._val*other, self._jac*d2)
raise TypeError
def __rmul__(self, other):
from .sugar import makeOp
if isinstance(other, (int, float, complex)):
return Linearization(self._val*other, self._jac*other)
if isinstance(other, (Field, MultiField)):
d1 = makeOp(other)
return Linearization(self._val*other, d1*self._jac)
def sum(self):
from .sugar import full
from .operators.vdot_operator import VdotOperator
return Linearization(full((),self._val.sum()),
VdotOperator(full(self._jac.target,1))*self._jac)
def exp(self):
tmp = self._val.exp()
return Linearization(tmp, makeOp(tmp)*self._jac)
def log(self):
tmp = self._val.log()
return Linearization(tmp, makeOp(1./self._val)*self._jac)
def add_metric(self, metric):
return Linearization(self._val, self._jac, metric)
@staticmethod
def make_var(field):
from .operators.scaling_operator import ScalingOperator
return Linearization(field, ScalingOperator(1., field.domain))
@staticmethod
def make_const(field):
from .operators.null_operator import NullOperator
return Linearization(field, NullOperator({}, field.domain))
......@@ -263,3 +263,13 @@ for op in ["__iadd__", "__isub__", "__imul__", "__idiv__",
"In-place operations are deliberately not supported")
return func2
setattr(MultiField, op, func(op))
for f in ["sqrt", "exp", "log", "tanh"]:
def func(f):
def func2(self):
fu = getattr(dobj, f)
return MultiField(self.domain,
tuple(func2(val) for val in self.values()))
return func2
setattr(MultiField, f, func(f))
from __future__ import absolute_import, division, print_function
import abc
import numpy as np
from .compat import *
from .utilities import NiftyMetaBase
from .field import Field
from .multi.multi_field import MultiField
class Linearization(object):
def __init__(self, val, jac, metric=None):
self._val = val
self._jac = jac
self._metric = metric
@property
def domain(self):
return self._jac.domain
@property
def target(self):
return self._jac.target
@property
def val(self):
return self._val
@property
def jac(self):
return self._jac
@property
def metric(self):
return self._metric
def __neg__(self):
return Linearization(-self._val, self._jac*(-1),
None if self._metric is None else self._metric*(-1))
def __add__(self, other):
if isinstance(other, Linearization):
from .operators.relaxed_sum_operator import RelaxedSumOperator
met = None
if self._metric is not None and other._metric is not None:
met = RelaxedSumOperator((self._metric, other._metric))
return Linearization(
self._val.unite(other._val),
RelaxedSumOperator((self._jac, other._jac)), met)
if isinstance(other, (int, float, complex, Field, MultiField)):
return Linearization(self._val+other, self._jac, self._metric)
def __radd__(self, other):
return self.__add__(other)
def __sub__(self, other):
return self.__add__(-other)
def __rsub__(self, other):
return (-self).__add__(other)
def __mul__(self, other):
from .sugar import makeOp
if isinstance(other, Linearization):
d1 = makeOp(self._val)
d2 = makeOp(other._val)
return Linearization(self._val*other._val,
d2*self._jac + d1*other._jac)
if isinstance(other, (int, float, complex)):
# if other == 0:
# return ...
return Linearization(self._val*other, self._jac*other)
if isinstance(other, (Field, MultiField)):
d2 = makeOp(other)
return Linearization(self._val*other, self._jac*d2)
raise TypeError
def __rmul__(self, other):
from .sugar import makeOp
if isinstance(other, (int, float, complex)):
return Linearization(self._val*other, self._jac*other)
if isinstance(other, (Field, MultiField)):
d1 = makeOp(other)
return Linearization(self._val*other, d1*self._jac)
def sum(self):
from .sugar import full
from .operators.vdot_operator import VdotOperator
return Linearization(full((),self._val.sum()),
VdotOperator(full(self._jac.target,1))*self._jac)
@staticmethod
def make_var(field):
from .operators.scaling_operator import ScalingOperator
return Linearization(field, ScalingOperator(1., field.domain))
@staticmethod
def make_const(field):
from .operators.null_operator import NullOperator
return Linearization(field, NullOperator({}, field.domain))
class Operator(NiftyMetaBase()):
......
......@@ -23,7 +23,7 @@ import abc
import numpy as np
from ..compat import *
from ..operator import Operator, Linearization
from ..operator import Operator
class LinearOperator(Operator):
......@@ -205,6 +205,7 @@ class LinearOperator(Operator):
"""Same as :meth:`times`"""
from ..models.model import Model
from ..models.linear_model import LinearModel
from ..linearization import Linearization
if isinstance(x, Linearization):
return Linearization(self(x._val), self*x._jac)
if isinstance(x, Model):
......
......@@ -34,6 +34,7 @@ from .multi.multi_field import MultiField
from .operators.diagonal_operator import DiagonalOperator
from .operators.power_distributor import PowerDistributor
__all__ = ['PS_field', 'power_analyze', 'create_power_operator',
'create_harmonic_smoothing_operator', 'from_random',
'full', 'from_global_data', 'from_local_data',
......@@ -259,12 +260,9 @@ _current_module = sys.modules[__name__]
for f in ["sqrt", "exp", "log", "tanh", "conjugate"]:
def func(f):
def func2(x):
if isinstance(x, MultiField):
return MultiField(x.domain,
tuple(func2(val) for val in x.values()))
elif isinstance(x, Field):
fu = getattr(dobj, f)
return Field(domain=x._domain, val=fu(x.val))
from .linearization import Linearization
if isinstance(x, (Field, MultiField, Linearization)):
return getattr(x, f)()
else:
return getattr(np, f)(x)
return func2
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment