Commit 0ce790cb authored by Martin Reinecke's avatar Martin Reinecke

tmp

parent c3ed466f
......@@ -19,15 +19,14 @@
from __future__ import absolute_import, division, print_function
from ..compat import *
from ..operators.operator import Operator
from ..operators.operator import EnergyOperator
from ..operators.sandwich_operator import SandwichOperator
from ..sugar import makeOp
from ..linearization import Linearization
class BernoulliEnergy(Operator):
class BernoulliEnergy(EnergyOperator):
def __init__(self, p, d):
super(BernoulliEnergy, self).__init__()
self._p = p
self._d = d
......@@ -35,10 +34,6 @@ class BernoulliEnergy(Operator):
def domain(self):
return self._p.domain
@property
def target(self):
return DomainTuple.scalar_domain()
def apply(self, x):
x = self._p(x)
v = x.log().vdot(-self._d) - (1.-x).log().vdot(1.-self._d)
......
......@@ -19,13 +19,13 @@
from __future__ import absolute_import, division, print_function
from ..compat import *
from ..operators.operator import Operator
from ..operators.operator import EnergyOperator
from ..operators.sandwich_operator import SandwichOperator
from ..domain_tuple import DomainTuple
from ..linearization import Linearization
class GaussianEnergy(Operator):
class GaussianEnergy(EnergyOperator):
def __init__(self, mean=None, covariance=None, domain=None):
super(GaussianEnergy, self).__init__()
self._domain = None
......@@ -51,10 +51,6 @@ class GaussianEnergy(Operator):
def domain(self):
return self._domain
@property
def target(self):
return DomainTuple.scalar_domain()
def apply(self, x):
residual = x if self._mean is None else x-self._mean
icovres = residual if self._icov is None else self._icov(residual)
......
......@@ -21,26 +21,20 @@ from __future__ import absolute_import, division, print_function
from numpy import inf, isnan
from ..compat import *
from ..operators.operator import Operator
from ..operators.operator import EnergyOperator
from ..operators.sandwich_operator import SandwichOperator
from ..sugar import makeOp
from ..linearization import Linearization
class PoissonianEnergy(Operator):
class PoissonianEnergy(EnergyOperator):
def __init__(self, op, d):
super(PoissonianEnergy, self).__init__()
self._op = op
self._d = d
self._op, self._d = op, d
@property
def domain(self):
return self._op.domain
@property
def target(self):
return DomainTuple.scalar_domain()
def apply(self, x):
x = self._op(x)
res = x.sum() - x.log().vdot(self._d)
......
......@@ -6,6 +6,7 @@ from .compat import *
from .field import Field
from .multi.multi_field import MultiField
from .sugar import makeOp
from .domain_tuple import DomainTuple
class Linearization(object):
......@@ -109,7 +110,6 @@ class Linearization(object):
VdotOperator(other._val)(self._jac))
def sum(self):
from .domain_tuple import DomainTuple
from .operators.vdot_operator import SumReductionOperator
from .sugar import full
return Linearization(
......
from __future__ import absolute_import, division, print_function
from ..compat import *
from ..utilities import NiftyMetaBase
from ..utilities import NiftyMetaBase, my_product
from ..domain_tuple import DomainTuple
class Operator(NiftyMetaBase()):
......@@ -51,6 +52,13 @@ for f in ["sqrt", "exp", "log", "tanh", "positive_tanh"]:
setattr(Operator, f, func(f))
class EnergyOperator(Operator):
_target = DomainTuple.scalar_domain()
@property
def target(self):
return EnergyOperator._target
class _FunctionApplier(Operator):
def __init__(self, domain, funcname):
from ..sugar import makeDomain
......@@ -123,7 +131,6 @@ class _OpProd(_CombinedOperator):
return self._ops[0].target
def apply(self, x):
from ..utilities import my_product
return my_product(map(lambda op: op(x), self._ops))
......@@ -145,41 +152,31 @@ class _OpSum(_CombinedOperator):
raise NotImplementedError
class SquaredNormOperator(Operator):
class SquaredNormOperator(EnergyOperator):
def __init__(self, domain):
super(SquaredNormOperator, self).__init__()
self._domain = domain
self._target = DomainTuple.scalar_domain()
@property
def domain(self):
return self._domain
@property
def target(self):
return self._target
def __call__(self, x):
def apply(self, x):
return Field(self._target, x.vdot(x))
class QuadraticFormOperator(Operator):
class QuadraticFormOperator(EnergyOperator):
def __init__(self, op):
from .endomorphic_operator import EndomorphicOperator
super(QuadraticFormOperator, self).__init__()
if not isinstance(op, EndomorphicOperator):
raise TypeError("op must be an EndomorphicOperator")
self._op = op
self._target = DomainTuple.scalar_domain()
@property
def domain(self):
return self._op.domain
@property
def target(self):
return self._target
def apply(self, x):
if isinstance(x, Linearization):
jac = self._op(x)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment