Commit 6a602768 by Martin Reinecke

### reordering

parent af6b35ee
 ... @@ -19,18 +19,6 @@ ... @@ -19,18 +19,6 @@ import nifty5 as ift import nifty5 as ift import numpy as np import numpy as np def myexp(lin): if isinstance(lin, ift.Linearization): tmp = ift.exp(lin.val) return ift.Linearization(tmp, ift.makeOp(tmp)*lin.jac) return ift.exp(lin) def mylog(lin): if isinstance(lin, ift.Linearization): tmp = ift.log(lin.val) return ift.Linearization(tmp, ift.makeOp(1./lin.val)*lin.jac) return ift.log(lin) class GaussianEnergy2(ift.Operator): class GaussianEnergy2(ift.Operator): def __init__(self, mean=None, covariance=None): def __init__(self, mean=None, covariance=None): super(GaussianEnergy2, self).__init__() super(GaussianEnergy2, self).__init__() ... @@ -42,7 +30,7 @@ class GaussianEnergy2(ift.Operator): ... @@ -42,7 +30,7 @@ class GaussianEnergy2(ift.Operator): icovres = residual if self._icov is None else self._icov(residual) icovres = residual if self._icov is None else self._icov(residual) res = .5 * (residual*icovres).sum() res = .5 * (residual*icovres).sum() metric = ift.SandwichOperator.make(x.jac, self._icov) metric = ift.SandwichOperator.make(x.jac, self._icov) return ift.Linearization(res.val, res.jac, metric) return res.add_metric(metric) class PoissonianEnergy2(ift.Operator): class PoissonianEnergy2(ift.Operator): def __init__(self, op, d): def __init__(self, op, d): ... @@ -52,9 +40,9 @@ class PoissonianEnergy2(ift.Operator): ... @@ -52,9 +40,9 @@ class PoissonianEnergy2(ift.Operator): def __call__(self, x): def __call__(self, x): x = self._op(x) x = self._op(x) res = (x - self._d*mylog(x)).sum() res = (x - self._d*x.log()).sum() metric = ift.SandwichOperator.make(x.jac, ift.makeOp(1./x.val)) metric = ift.SandwichOperator.make(x.jac, ift.makeOp(1./x.val)) return ift.Linearization(res.val, res.jac, metric) return res.add_metric(metric) class MyHamiltonian(ift.Operator): class MyHamiltonian(ift.Operator): def __init__(self, lh): def __init__(self, lh): ... @@ -138,7 +126,7 @@ if __name__ == '__main__': ... @@ -138,7 +126,7 @@ if __name__ == '__main__': A = pd(a) A = pd(a) # Set up a sky model # Set up a sky model sky = lambda inp: myexp(HT(inp*A)) sky = lambda inp: (HT(A*inp)).exp() M = ift.DiagonalOperator(exposure) M = ift.DiagonalOperator(exposure) GR = ift.GeometryRemover(position_space) GR = ift.GeometryRemover(position_space) ... ...
 ... @@ -104,7 +104,8 @@ from .multi.block_diagonal_operator import BlockDiagonalOperator ... @@ -104,7 +104,8 @@ from .multi.block_diagonal_operator import BlockDiagonalOperator from .energies.kl import SampledKullbachLeiblerDivergence from .energies.kl import SampledKullbachLeiblerDivergence from .energies.hamiltonian import Hamiltonian from .energies.hamiltonian import Hamiltonian from.operator import Linearization, Operator from .operator import Operator from .linearization import Linearization # We deliberately don't set __all__ here, because we don't want people to do a # We deliberately don't set __all__ here, because we don't want people to do a # "from nifty5 import *"; that would swamp the global namespace. # "from nifty5 import *"; that would swamp the global namespace.
 ... @@ -651,3 +651,11 @@ for op in ["__iadd__", "__isub__", "__imul__", "__idiv__", ... @@ -651,3 +651,11 @@ for op in ["__iadd__", "__isub__", "__imul__", "__idiv__", "In-place operations are deliberately not supported") "In-place operations are deliberately not supported") return func2 return func2 setattr(Field, op, func(op)) setattr(Field, op, func(op)) for f in ["sqrt", "exp", "log", "tanh"]: def func(f): def func2(self): fu = getattr(dobj, f) return Field(domain=self._domain, val=fu(self.val)) return func2 setattr(Field, f, func(f))
 from __future__ import absolute_import, division, print_function import numpy as np from .compat import * from .field import Field from .multi.multi_field import MultiField from .sugar import makeOp class Linearization(object): def __init__(self, val, jac, metric=None): self._val = val self._jac = jac self._metric = metric @property def domain(self): return self._jac.domain @property def target(self): return self._jac.target @property def val(self): return self._val @property def jac(self): return self._jac @property def metric(self): return self._metric def __neg__(self): return Linearization(-self._val, self._jac*(-1), None if self._metric is None else self._metric*(-1)) def __add__(self, other): if isinstance(other, Linearization): from .operators.relaxed_sum_operator import RelaxedSumOperator met = None if self._metric is not None and other._metric is not None: met = RelaxedSumOperator((self._metric, other._metric)) return Linearization( self._val.unite(other._val), RelaxedSumOperator((self._jac, other._jac)), met) if isinstance(other, (int, float, complex, Field, MultiField)): return Linearization(self._val+other, self._jac, self._metric) def __radd__(self, other): return self.__add__(other) def __sub__(self, other): return self.__add__(-other) def __rsub__(self, other): return (-self).__add__(other) def __mul__(self, other): from .sugar import makeOp if isinstance(other, Linearization): d1 = makeOp(self._val) d2 = makeOp(other._val) return Linearization(self._val*other._val, d2*self._jac + d1*other._jac) if isinstance(other, (int, float, complex)): # if other == 0: # return ... return Linearization(self._val*other, self._jac*other) if isinstance(other, (Field, MultiField)): d2 = makeOp(other) return Linearization(self._val*other, self._jac*d2) raise TypeError def __rmul__(self, other): from .sugar import makeOp if isinstance(other, (int, float, complex)): return Linearization(self._val*other, self._jac*other) if isinstance(other, (Field, MultiField)): d1 = makeOp(other) return Linearization(self._val*other, d1*self._jac) def sum(self): from .sugar import full from .operators.vdot_operator import VdotOperator return Linearization(full((),self._val.sum()), VdotOperator(full(self._jac.target,1))*self._jac) def exp(self): tmp = self._val.exp() return Linearization(tmp, makeOp(tmp)*self._jac) def log(self): tmp = self._val.log() return Linearization(tmp, makeOp(1./self._val)*self._jac) def add_metric(self, metric): return Linearization(self._val, self._jac, metric) @staticmethod def make_var(field): from .operators.scaling_operator import ScalingOperator return Linearization(field, ScalingOperator(1., field.domain)) @staticmethod def make_const(field): from .operators.null_operator import NullOperator return Linearization(field, NullOperator({}, field.domain))
 ... @@ -263,3 +263,13 @@ for op in ["__iadd__", "__isub__", "__imul__", "__idiv__", ... @@ -263,3 +263,13 @@ for op in ["__iadd__", "__isub__", "__imul__", "__idiv__", "In-place operations are deliberately not supported") "In-place operations are deliberately not supported") return func2 return func2 setattr(MultiField, op, func(op)) setattr(MultiField, op, func(op)) for f in ["sqrt", "exp", "log", "tanh"]: def func(f): def func2(self): fu = getattr(dobj, f) return MultiField(self.domain, tuple(func2(val) for val in self.values())) return func2 setattr(MultiField, f, func(f))
 from __future__ import absolute_import, division, print_function from __future__ import absolute_import, division, print_function import abc import numpy as np from .compat import * from .compat import * from .utilities import NiftyMetaBase from .utilities import NiftyMetaBase from .field import Field from .multi.multi_field import MultiField class Linearization(object): def __init__(self, val, jac, metric=None): self._val = val self._jac = jac self._metric = metric @property def domain(self): return self._jac.domain @property def target(self): return self._jac.target @property def val(self): return self._val @property def jac(self): return self._jac @property def metric(self): return self._metric def __neg__(self): return Linearization(-self._val, self._jac*(-1), None if self._metric is None else self._metric*(-1)) def __add__(self, other): if isinstance(other, Linearization): from .operators.relaxed_sum_operator import RelaxedSumOperator met = None if self._metric is not None and other._metric is not None: met = RelaxedSumOperator((self._metric, other._metric)) return Linearization( self._val.unite(other._val), RelaxedSumOperator((self._jac, other._jac)), met) if isinstance(other, (int, float, complex, Field, MultiField)): return Linearization(self._val+other, self._jac, self._metric) def __radd__(self, other): return self.__add__(other) def __sub__(self, other): return self.__add__(-other) def __rsub__(self, other): return (-self).__add__(other) def __mul__(self, other): from .sugar import makeOp if isinstance(other, Linearization): d1 = makeOp(self._val) d2 = makeOp(other._val) return Linearization(self._val*other._val, d2*self._jac + d1*other._jac) if isinstance(other, (int, float, complex)): # if other == 0: # return ... return Linearization(self._val*other, self._jac*other) if isinstance(other, (Field, MultiField)): d2 = makeOp(other) return Linearization(self._val*other, self._jac*d2) raise TypeError def __rmul__(self, other): from .sugar import makeOp if isinstance(other, (int, float, complex)): return Linearization(self._val*other, self._jac*other) if isinstance(other, (Field, MultiField)): d1 = makeOp(other) return Linearization(self._val*other, d1*self._jac) def sum(self): from .sugar import full from .operators.vdot_operator import VdotOperator return Linearization(full((),self._val.sum()), VdotOperator(full(self._jac.target,1))*self._jac) @staticmethod def make_var(field): from .operators.scaling_operator import ScalingOperator return Linearization(field, ScalingOperator(1., field.domain)) @staticmethod def make_const(field): from .operators.null_operator import NullOperator return Linearization(field, NullOperator({}, field.domain)) class Operator(NiftyMetaBase()): class Operator(NiftyMetaBase()): ... ...
 ... @@ -23,7 +23,7 @@ import abc ... @@ -23,7 +23,7 @@ import abc import numpy as np import numpy as np from ..compat import * from ..compat import * from ..operator import Operator, Linearization from ..operator import Operator class LinearOperator(Operator): class LinearOperator(Operator): ... @@ -205,6 +205,7 @@ class LinearOperator(Operator): ... @@ -205,6 +205,7 @@ class LinearOperator(Operator): """Same as :meth:`times`""" """Same as :meth:`times`""" from ..models.model import Model from ..models.model import Model from ..models.linear_model import LinearModel from ..models.linear_model import LinearModel from ..linearization import Linearization if isinstance(x, Linearization): if isinstance(x, Linearization): return Linearization(self(x._val), self*x._jac) return Linearization(self(x._val), self*x._jac) if isinstance(x, Model): if isinstance(x, Model): ... ...
 ... @@ -34,6 +34,7 @@ from .multi.multi_field import MultiField ... @@ -34,6 +34,7 @@ from .multi.multi_field import MultiField from .operators.diagonal_operator import DiagonalOperator from .operators.diagonal_operator import DiagonalOperator from .operators.power_distributor import PowerDistributor from .operators.power_distributor import PowerDistributor __all__ = ['PS_field', 'power_analyze', 'create_power_operator', __all__ = ['PS_field', 'power_analyze', 'create_power_operator', 'create_harmonic_smoothing_operator', 'from_random', 'create_harmonic_smoothing_operator', 'from_random', 'full', 'from_global_data', 'from_local_data', 'full', 'from_global_data', 'from_local_data', ... @@ -259,12 +260,9 @@ _current_module = sys.modules[__name__] ... @@ -259,12 +260,9 @@ _current_module = sys.modules[__name__] for f in ["sqrt", "exp", "log", "tanh", "conjugate"]: for f in ["sqrt", "exp", "log", "tanh", "conjugate"]: def func(f): def func(f): def func2(x): def func2(x): if isinstance(x, MultiField): from .linearization import Linearization return MultiField(x.domain, if isinstance(x, (Field, MultiField, Linearization)): tuple(func2(val) for val in x.values())) return getattr(x, f)() elif isinstance(x, Field): fu = getattr(dobj, f) return Field(domain=x._domain, val=fu(x.val)) else: else: return getattr(np, f)(x) return getattr(np, f)(x) return func2 return func2 ... ...
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!