Commit 6a602768 authored by Martin Reinecke's avatar Martin Reinecke
Browse files

reordering

parent af6b35ee
...@@ -19,18 +19,6 @@ ...@@ -19,18 +19,6 @@
import nifty5 as ift import nifty5 as ift
import numpy as np import numpy as np
def myexp(lin):
if isinstance(lin, ift.Linearization):
tmp = ift.exp(lin.val)
return ift.Linearization(tmp, ift.makeOp(tmp)*lin.jac)
return ift.exp(lin)
def mylog(lin):
if isinstance(lin, ift.Linearization):
tmp = ift.log(lin.val)
return ift.Linearization(tmp, ift.makeOp(1./lin.val)*lin.jac)
return ift.log(lin)
class GaussianEnergy2(ift.Operator): class GaussianEnergy2(ift.Operator):
def __init__(self, mean=None, covariance=None): def __init__(self, mean=None, covariance=None):
super(GaussianEnergy2, self).__init__() super(GaussianEnergy2, self).__init__()
...@@ -42,7 +30,7 @@ class GaussianEnergy2(ift.Operator): ...@@ -42,7 +30,7 @@ class GaussianEnergy2(ift.Operator):
icovres = residual if self._icov is None else self._icov(residual) icovres = residual if self._icov is None else self._icov(residual)
res = .5 * (residual*icovres).sum() res = .5 * (residual*icovres).sum()
metric = ift.SandwichOperator.make(x.jac, self._icov) metric = ift.SandwichOperator.make(x.jac, self._icov)
return ift.Linearization(res.val, res.jac, metric) return res.add_metric(metric)
class PoissonianEnergy2(ift.Operator): class PoissonianEnergy2(ift.Operator):
def __init__(self, op, d): def __init__(self, op, d):
...@@ -52,9 +40,9 @@ class PoissonianEnergy2(ift.Operator): ...@@ -52,9 +40,9 @@ class PoissonianEnergy2(ift.Operator):
def __call__(self, x): def __call__(self, x):
x = self._op(x) x = self._op(x)
res = (x - self._d*mylog(x)).sum() res = (x - self._d*x.log()).sum()
metric = ift.SandwichOperator.make(x.jac, ift.makeOp(1./x.val)) metric = ift.SandwichOperator.make(x.jac, ift.makeOp(1./x.val))
return ift.Linearization(res.val, res.jac, metric) return res.add_metric(metric)
class MyHamiltonian(ift.Operator): class MyHamiltonian(ift.Operator):
def __init__(self, lh): def __init__(self, lh):
...@@ -138,7 +126,7 @@ if __name__ == '__main__': ...@@ -138,7 +126,7 @@ if __name__ == '__main__':
A = pd(a) A = pd(a)
# Set up a sky model # Set up a sky model
sky = lambda inp: myexp(HT(inp*A)) sky = lambda inp: (HT(A*inp)).exp()
M = ift.DiagonalOperator(exposure) M = ift.DiagonalOperator(exposure)
GR = ift.GeometryRemover(position_space) GR = ift.GeometryRemover(position_space)
......
...@@ -104,7 +104,8 @@ from .multi.block_diagonal_operator import BlockDiagonalOperator ...@@ -104,7 +104,8 @@ from .multi.block_diagonal_operator import BlockDiagonalOperator
from .energies.kl import SampledKullbachLeiblerDivergence from .energies.kl import SampledKullbachLeiblerDivergence
from .energies.hamiltonian import Hamiltonian from .energies.hamiltonian import Hamiltonian
from.operator import Linearization, Operator from .operator import Operator
from .linearization import Linearization
# We deliberately don't set __all__ here, because we don't want people to do a # We deliberately don't set __all__ here, because we don't want people to do a
# "from nifty5 import *"; that would swamp the global namespace. # "from nifty5 import *"; that would swamp the global namespace.
...@@ -651,3 +651,11 @@ for op in ["__iadd__", "__isub__", "__imul__", "__idiv__", ...@@ -651,3 +651,11 @@ for op in ["__iadd__", "__isub__", "__imul__", "__idiv__",
"In-place operations are deliberately not supported") "In-place operations are deliberately not supported")
return func2 return func2
setattr(Field, op, func(op)) setattr(Field, op, func(op))
for f in ["sqrt", "exp", "log", "tanh"]:
def func(f):
def func2(self):
fu = getattr(dobj, f)
return Field(domain=self._domain, val=fu(self.val))
return func2
setattr(Field, f, func(f))
from __future__ import absolute_import, division, print_function
import numpy as np
from .compat import *
from .field import Field
from .multi.multi_field import MultiField
from .sugar import makeOp
class Linearization(object):
def __init__(self, val, jac, metric=None):
self._val = val
self._jac = jac
self._metric = metric
@property
def domain(self):
return self._jac.domain
@property
def target(self):
return self._jac.target
@property
def val(self):
return self._val
@property
def jac(self):
return self._jac
@property
def metric(self):
return self._metric
def __neg__(self):
return Linearization(-self._val, self._jac*(-1),
None if self._metric is None else self._metric*(-1))
def __add__(self, other):
if isinstance(other, Linearization):
from .operators.relaxed_sum_operator import RelaxedSumOperator
met = None
if self._metric is not None and other._metric is not None:
met = RelaxedSumOperator((self._metric, other._metric))
return Linearization(
self._val.unite(other._val),
RelaxedSumOperator((self._jac, other._jac)), met)
if isinstance(other, (int, float, complex, Field, MultiField)):
return Linearization(self._val+other, self._jac, self._metric)
def __radd__(self, other):
return self.__add__(other)
def __sub__(self, other):
return self.__add__(-other)
def __rsub__(self, other):
return (-self).__add__(other)
def __mul__(self, other):
from .sugar import makeOp
if isinstance(other, Linearization):
d1 = makeOp(self._val)
d2 = makeOp(other._val)
return Linearization(self._val*other._val,
d2*self._jac + d1*other._jac)
if isinstance(other, (int, float, complex)):
# if other == 0:
# return ...
return Linearization(self._val*other, self._jac*other)
if isinstance(other, (Field, MultiField)):
d2 = makeOp(other)
return Linearization(self._val*other, self._jac*d2)
raise TypeError
def __rmul__(self, other):
from .sugar import makeOp
if isinstance(other, (int, float, complex)):
return Linearization(self._val*other, self._jac*other)
if isinstance(other, (Field, MultiField)):
d1 = makeOp(other)
return Linearization(self._val*other, d1*self._jac)
def sum(self):
from .sugar import full
from .operators.vdot_operator import VdotOperator
return Linearization(full((),self._val.sum()),
VdotOperator(full(self._jac.target,1))*self._jac)
def exp(self):
tmp = self._val.exp()
return Linearization(tmp, makeOp(tmp)*self._jac)
def log(self):
tmp = self._val.log()
return Linearization(tmp, makeOp(1./self._val)*self._jac)
def add_metric(self, metric):
return Linearization(self._val, self._jac, metric)
@staticmethod
def make_var(field):
from .operators.scaling_operator import ScalingOperator
return Linearization(field, ScalingOperator(1., field.domain))
@staticmethod
def make_const(field):
from .operators.null_operator import NullOperator
return Linearization(field, NullOperator({}, field.domain))
...@@ -263,3 +263,13 @@ for op in ["__iadd__", "__isub__", "__imul__", "__idiv__", ...@@ -263,3 +263,13 @@ for op in ["__iadd__", "__isub__", "__imul__", "__idiv__",
"In-place operations are deliberately not supported") "In-place operations are deliberately not supported")
return func2 return func2
setattr(MultiField, op, func(op)) setattr(MultiField, op, func(op))
for f in ["sqrt", "exp", "log", "tanh"]:
def func(f):
def func2(self):
fu = getattr(dobj, f)
return MultiField(self.domain,
tuple(func2(val) for val in self.values()))
return func2
setattr(MultiField, f, func(f))
from __future__ import absolute_import, division, print_function from __future__ import absolute_import, division, print_function
import abc
import numpy as np
from .compat import * from .compat import *
from .utilities import NiftyMetaBase from .utilities import NiftyMetaBase
from .field import Field
from .multi.multi_field import MultiField
class Linearization(object):
def __init__(self, val, jac, metric=None):
self._val = val
self._jac = jac
self._metric = metric
@property
def domain(self):
return self._jac.domain
@property
def target(self):
return self._jac.target
@property
def val(self):
return self._val
@property
def jac(self):
return self._jac
@property
def metric(self):
return self._metric
def __neg__(self):
return Linearization(-self._val, self._jac*(-1),
None if self._metric is None else self._metric*(-1))
def __add__(self, other):
if isinstance(other, Linearization):
from .operators.relaxed_sum_operator import RelaxedSumOperator
met = None
if self._metric is not None and other._metric is not None:
met = RelaxedSumOperator((self._metric, other._metric))
return Linearization(
self._val.unite(other._val),
RelaxedSumOperator((self._jac, other._jac)), met)
if isinstance(other, (int, float, complex, Field, MultiField)):
return Linearization(self._val+other, self._jac, self._metric)
def __radd__(self, other):
return self.__add__(other)
def __sub__(self, other):
return self.__add__(-other)
def __rsub__(self, other):
return (-self).__add__(other)
def __mul__(self, other):
from .sugar import makeOp
if isinstance(other, Linearization):
d1 = makeOp(self._val)
d2 = makeOp(other._val)
return Linearization(self._val*other._val,
d2*self._jac + d1*other._jac)
if isinstance(other, (int, float, complex)):
# if other == 0:
# return ...
return Linearization(self._val*other, self._jac*other)
if isinstance(other, (Field, MultiField)):
d2 = makeOp(other)
return Linearization(self._val*other, self._jac*d2)
raise TypeError
def __rmul__(self, other):
from .sugar import makeOp
if isinstance(other, (int, float, complex)):
return Linearization(self._val*other, self._jac*other)
if isinstance(other, (Field, MultiField)):
d1 = makeOp(other)
return Linearization(self._val*other, d1*self._jac)
def sum(self):
from .sugar import full
from .operators.vdot_operator import VdotOperator
return Linearization(full((),self._val.sum()),
VdotOperator(full(self._jac.target,1))*self._jac)
@staticmethod
def make_var(field):
from .operators.scaling_operator import ScalingOperator
return Linearization(field, ScalingOperator(1., field.domain))
@staticmethod
def make_const(field):
from .operators.null_operator import NullOperator
return Linearization(field, NullOperator({}, field.domain))
class Operator(NiftyMetaBase()): class Operator(NiftyMetaBase()):
......
...@@ -23,7 +23,7 @@ import abc ...@@ -23,7 +23,7 @@ import abc
import numpy as np import numpy as np
from ..compat import * from ..compat import *
from ..operator import Operator, Linearization from ..operator import Operator
class LinearOperator(Operator): class LinearOperator(Operator):
...@@ -205,6 +205,7 @@ class LinearOperator(Operator): ...@@ -205,6 +205,7 @@ class LinearOperator(Operator):
"""Same as :meth:`times`""" """Same as :meth:`times`"""
from ..models.model import Model from ..models.model import Model
from ..models.linear_model import LinearModel from ..models.linear_model import LinearModel
from ..linearization import Linearization
if isinstance(x, Linearization): if isinstance(x, Linearization):
return Linearization(self(x._val), self*x._jac) return Linearization(self(x._val), self*x._jac)
if isinstance(x, Model): if isinstance(x, Model):
......
...@@ -34,6 +34,7 @@ from .multi.multi_field import MultiField ...@@ -34,6 +34,7 @@ from .multi.multi_field import MultiField
from .operators.diagonal_operator import DiagonalOperator from .operators.diagonal_operator import DiagonalOperator
from .operators.power_distributor import PowerDistributor from .operators.power_distributor import PowerDistributor
__all__ = ['PS_field', 'power_analyze', 'create_power_operator', __all__ = ['PS_field', 'power_analyze', 'create_power_operator',
'create_harmonic_smoothing_operator', 'from_random', 'create_harmonic_smoothing_operator', 'from_random',
'full', 'from_global_data', 'from_local_data', 'full', 'from_global_data', 'from_local_data',
...@@ -259,12 +260,9 @@ _current_module = sys.modules[__name__] ...@@ -259,12 +260,9 @@ _current_module = sys.modules[__name__]
for f in ["sqrt", "exp", "log", "tanh", "conjugate"]: for f in ["sqrt", "exp", "log", "tanh", "conjugate"]:
def func(f): def func(f):
def func2(x): def func2(x):
if isinstance(x, MultiField): from .linearization import Linearization
return MultiField(x.domain, if isinstance(x, (Field, MultiField, Linearization)):
tuple(func2(val) for val in x.values())) return getattr(x, f)()
elif isinstance(x, Field):
fu = getattr(dobj, f)
return Field(domain=x._domain, val=fu(x.val))
else: else:
return getattr(np, f)(x) return getattr(np, f)(x)
return func2 return func2
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment