Commit 725e5b55 authored by Martin Reinecke's avatar Martin Reinecke

cleanup

parent b7881dae
......@@ -47,8 +47,8 @@ class Linearization(object):
def __neg__(self):
return Linearization(
-self._val, self._jac*(-1),
None if self._metric is None else self._metric*(-1))
-self._val, -self._jac,
None if self._metric is None else -self._metric)
def __add__(self, other):
if isinstance(other, Linearization):
......@@ -74,28 +74,19 @@ class Linearization(object):
def __mul__(self, other):
from .sugar import makeOp
if isinstance(other, Linearization):
d1 = makeOp(self._val)
d2 = makeOp(other._val)
return Linearization(
self._val*other._val,
d2(self._jac) + d1(other._jac))
if isinstance(other, (int, float, complex)):
# if other == 0:
# return ...
met = None if self._metric is None else self._metric(other)
return Linearization(self._val*other, self._jac(other), met)
makeOp(other._val)(self._jac) + makeOp(self._val)(other._jac))
if np.isscalar(other):
if other == 1:
return self
met = None if self._metric is None else self._metric.scale(other)
return Linearization(self._val*other, self._jac.scale(other), met)
if isinstance(other, (Field, MultiField)):
d2 = makeOp(other)
return Linearization(self._val*other, d2(self._jac))
raise TypeError
return Linearization(self._val*other, makeOp(other)(self._jac))
def __rmul__(self, other):
from .sugar import makeOp
if isinstance(other, (int, float, complex)):
return Linearization(self._val*other, self._jac(other))
if isinstance(other, (Field, MultiField)):
d1 = makeOp(other)
return Linearization(self._val*other, d1(self._jac))
return self.__mul__(other)
def vdot(self, other):
from .domain_tuple import DomainTuple
......
......@@ -104,75 +104,38 @@ class LinearOperator(Operator):
the adjoint of this operator."""
return self._flip_modes(self.ADJOINT_BIT)
@staticmethod
def _toOperator(thing, dom):
from .scaling_operator import ScalingOperator
if isinstance(thing, LinearOperator):
return thing
if np.isscalar(thing):
return ScalingOperator(thing, dom)
return NotImplemented
def __mul__(self, other):
from .chain_operator import ChainOperator
if not np.isscalar(other):
return Operator.__mul__(self, other)
if other == 1.:
return self
from .scaling_operator import ScalingOperator
other = ScalingOperator(other, self.domain)
return ChainOperator.make([self, other])
def __rmul__(self, other):
from .chain_operator import ChainOperator
if not np.isscalar(other):
return Operator.__rmul__(self, other)
if other == 1.:
return self
from .scaling_operator import ScalingOperator
other = ScalingOperator(other, self.target)
return ChainOperator.make([other, self])
def __matmul__(self, other):
if np.isscalar(other) and other == 1.:
return self
other2 = self._toOperator(other, self.domain)
if other2 == NotImplemented:
return Operator.__matmul__(self, other)
from .chain_operator import ChainOperator
return ChainOperator.make([self, other2])
if isinstance(other, LinearOperator):
from .chain_operator import ChainOperator
return ChainOperator.make([self, other])
return Operator.__matmul__(self, other)
def __rmatmul__(self, other):
if np.isscalar(other) and other == 1.:
return self
other2 = self._toOperator(other, self.target)
if other2 == NotImplemented:
if isinstance(other, LinearOperator):
from .chain_operator import ChainOperator
return Operator.__rmatmul__(self, other)
from .chain_operator import ChainOperator
return ChainOperator.make([other2, self])
return ChainOperator.make([other, self])
return Operator.__rmatmul__(self, other)
def __add__(self, other):
from .sum_operator import SumOperator
if np.isscalar(other) and other == 0.:
return self
other = self._toOperator(other, self.domain)
return SumOperator.make([self, other], [False, False])
if isinstance(other, LinearOperator):
from .sum_operator import SumOperator
return SumOperator.make([self, other], [False, False])
return Operator.__add__(self, other)
def __radd__(self, other):
return self.__add__(other)
def __sub__(self, other):
from .sum_operator import SumOperator
if np.isscalar(other) and other == 0.:
return self
other = self._toOperator(other, self.domain)
return SumOperator.make([self, other], [False, True])
if isinstance(other, LinearOperator):
from .sum_operator import SumOperator
return SumOperator.make([self, other], [False, True])
return Operator.__sub__(self, other)
def __rsub__(self, other):
from .sum_operator import SumOperator
other = self._toOperator(other, self.domain)
return SumOperator.make([other, self], [False, True])
if isinstance(other, LinearOperator):
from .sum_operator import SumOperator
return SumOperator.make([other, self], [False, True])
return Operator.__rsub__(self, other)
@property
def capability(self):
......
......@@ -24,6 +24,19 @@ class Operator(NiftyMetaBase()):
The domain on which the Operator's output Field lives."""
raise NotImplementedError
def scale(self, factor):
if factor == 1:
return self
from .scaling_operator import ScalingOperator
return ScalingOperator(factor, self.target)(self)
def conjugate(self):
from .simple_linear_operators import ConjugationOperator
return ConjugationOperator(self.target)(self)
def __neg__(self):
return self.scale(-1)
def __matmul__(self, x):
if not isinstance(x, Operator):
return NotImplemented
......
......@@ -54,4 +54,4 @@ def SmoothnessOperator(domain, strength=1., logarithmic=True, space=None):
if strength == 0.:
return ScalingOperator(0., domain)
laplace = LaplaceOperator(domain, logarithmic=logarithmic, space=space)
return (strength**2)*laplace.adjoint(laplace)
return laplace.adjoint(laplace).scale(strength**2)
......@@ -68,7 +68,7 @@ class Model_Tests(unittest.TestCase):
model = ift.FieldAdapter(dom, "s1")+ift.FieldAdapter(dom, "s2")
pos = ift.from_random("normal", dom)
ift.extra.check_value_gradient_consistency(model, pos)
model = ift.FieldAdapter(dom, "s1")*3.
model = ift.FieldAdapter(dom, "s1").scale(3.)
pos = ift.from_random("normal", dom)
ift.extra.check_value_gradient_consistency(model, pos)
model = ift.ScalingOperator(2.456, space)(
......
......@@ -64,7 +64,7 @@ class ComposedOperator_Tests(unittest.TestCase):
@expand(product(spaces))
def test_sum(self, space):
op1 = ift.makeOp(ift.Field.full(space, 2.))
op2 = 3.
op2 = ift.ScalingOperator(3., space)
full_op = op1 + op2 - (op2 - op1) + op1 + op1 + op2
x = ift.Field.full(space, 1.)
res = full_op(x)
......@@ -74,7 +74,7 @@ class ComposedOperator_Tests(unittest.TestCase):
@expand(product(spaces))
def test_chain(self, space):
op1 = ift.makeOp(ift.Field.full(space, 2.))
op2 = 3.
op2 = ift.ScalingOperator(3., space)
full_op = op1(op2)(op2)(op1)(op1)(op1)(op2)
x = ift.Field.full(space, 1.)
res = full_op(x)
......@@ -84,7 +84,7 @@ class ComposedOperator_Tests(unittest.TestCase):
@expand(product(spaces))
def test_mix(self, space):
op1 = ift.makeOp(ift.Field.full(space, 2.))
op2 = 3.
op2 = ift.ScalingOperator(3., space)
full_op = op1(op2+op2)(op1)(op1) - op1(op2)
x = ift.Field.full(space, 1.)
res = full_op(x)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment