Commit 14daede2 authored by Martin Reinecke's avatar Martin Reinecke
Browse files

fewer isinstance checks

parent 1b43db0e
Pipeline #72347 passed with stages
in 19 minutes and 30 seconds
......@@ -17,8 +17,6 @@
import numpy as np
from .field import Field
from .multi_field import MultiField
from .sugar import makeOp
from . import utilities
from .operators.operator import Operator
......@@ -65,7 +63,7 @@ class Linearization(Operator):
return Linearization(val, jac, metric, self._want_metric)
def trivial_jac(self):
return Linearization.make_var(self._val, self._want_metric)
return self.make_var(self._val, self._want_metric)
def prepend_jac(self, jac):
metric = None
......@@ -102,6 +100,7 @@ class Linearization(Operator):
-----
Only available if target is a scalar
"""
from .field import Field
return self._jac.adjoint_times(Field.scalar(1.))
@property
......@@ -136,18 +135,15 @@ class Linearization(Operator):
return self.new(self._val.real, self._jac.real)
def _myadd(self, other, neg):
if isinstance(other, Linearization):
met = None
if self._metric is not None and other._metric is not None:
met = self._metric._myadd(other._metric, neg)
return self.new(
self._val.flexible_addsub(other._val, neg),
self._jac._myadd(other._jac, neg), met)
if isinstance(other, (int, float, complex, Field, MultiField)):
if neg:
return self.new(self._val-other, self._jac, self._metric)
else:
return self.new(self._val+other, self._jac, self._metric)
if np.isscalar(other) or other.jac is None:
return self.new(self._val-other if neg else self._val+other,
self._jac, self._metric)
met = None
if self._metric is not None and other._metric is not None:
met = self._metric._myadd(other._metric, neg)
return self.new(
self.val.flexible_addsub(other.val, neg),
self.jac._myadd(other.jac, neg), met)
def __add__(self, other):
return self._myadd(other, False)
......@@ -162,36 +158,35 @@ class Linearization(Operator):
return (-self).__add__(other)
def __truediv__(self, other):
if isinstance(other, Linearization):
return self.__mul__(other.ptw("reciprocal"))
return self.__mul__(1./other)
if np.isscalar(other):
return self.__mul__(1/other)
return self.__mul__(other.ptw("reciprocal"))
def __rtruediv__(self, other):
return self.ptw("reciprocal").__mul__(other)
def __pow__(self, power):
if not (np.isscalar(power) or isinstance(power, (Field, MultiField))):
if not (np.isscalar(power) or power.jac is None):
return NotImplemented
return self.ptw("power", power)
def __mul__(self, other):
from .sugar import makeOp
if isinstance(other, Linearization):
if self.target != other.target:
raise ValueError("domain mismatch")
return self.new(
self._val*other._val,
(makeOp(other._val)(self._jac))._myadd(
makeOp(self._val)(other._jac), False))
if np.isscalar(other):
if other == 1:
return self
met = None if self._metric is None else self._metric.scale(other)
return self.new(self._val*other, self._jac.scale(other), met)
if isinstance(other, (Field, MultiField)):
from .sugar import makeOp
if other.jac is None:
if self.target != other.domain:
raise ValueError("domain mismatch")
return self.new(self._val*other, makeOp(other)(self._jac))
if self.target != other.target:
raise ValueError("domain mismatch")
return self.new(
self.val*other.val,
(makeOp(other.val)(self.jac))._myadd(
makeOp(self.val)(other.jac), False))
def __rmul__(self, other):
return self.__mul__(other)
......@@ -209,17 +204,16 @@ class Linearization(Operator):
Linearization
the outer product of self and other
"""
from .operators.outer_product_operator import OuterProduct
if isinstance(other, Linearization):
return self.new(
OuterProduct(self._val, other.target)(other._val),
OuterProduct(self._jac(self._val), other.target)._myadd(
OuterProduct(self._val, other.target)(other._jac), False))
if np.isscalar(other):
return self.__mul__(other)
if isinstance(other, (Field, MultiField)):
from .operators.outer_product_operator import OuterProduct
if other.jac is None:
return self.new(OuterProduct(self._val, other.domain)(other),
OuterProduct(self._jac(self._val), other.domain))
return self.new(
OuterProduct(self._val, other.target)(other._val),
OuterProduct(self._jac(self._val), other.target)._myadd(
OuterProduct(self._val, other.target)(other._jac), False))
def vdot(self, other):
"""Computes the inner product of this Linearization with a Field or
......@@ -235,7 +229,7 @@ class Linearization(Operator):
the inner product of self and other
"""
from .operators.simple_linear_operators import VdotOperator
if isinstance(other, (Field, MultiField)):
if other.jac is None:
return self.new(
self._val.vdot(other),
VdotOperator(other)(self._jac))
......@@ -290,9 +284,9 @@ class Linearization(Operator):
def clip(self, a_min=None, a_max=None):
if a_min is None and a_max is None:
return self
if not (a_min is None or np.isscalar(a_min) or isinstance(a_min, (Field, MultiField))):
if not (a_min is None or np.isscalar(a_min) or a_min.jac is None):
return NotImplemented
if not (a_max is None or np.isscalar(a_max) or isinstance(a_max, (Field, MultiField))):
if not (a_max is None or np.isscalar(a_max) or a_max.jac is None):
return NotImplemented
return self.ptw("clip", a_min, a_max)
......
......@@ -128,15 +128,13 @@ class Operator(metaclass=NiftyMeta):
return ContractionOperator(self.target, spaces)(self)
def vdot(self, other):
from ..field import Field
from ..multi_field import MultiField
from ..sugar import makeOp
if isinstance(other, Operator):
if not isinstance(other, Operator):
raise TypeError
if other.jac is None:
res = self.conjugate()*other
elif isinstance(other, (Field, MultiField)):
res = makeOp(other) @ self.conjugate()
else:
raise TypeError
res = makeOp(other) @ self.conjugate()
return res.sum()
@property
......@@ -207,20 +205,16 @@ class Operator(metaclass=NiftyMeta):
return _OpSum(self, -x)
def __pow__(self, power):
from ..field import Field
from ..multi_field import MultiField
if not (np.isscalar(power) or isinstance(power, (Field, MultiField))):
if not (np.isscalar(power) or power.jac is None):
return NotImplemented
return self.ptw("power", power)
def clip(self, a_min=None, a_max=None):
from ..field import Field
from ..multi_field import MultiField
if a_min is None and a_max is None:
return self
if not (a_min is None or np.isscalar(a_min) or isinstance(a_min, (Field, MultiField))):
if not (a_min is None or np.isscalar(a_min) or a_min.jac is None):
return NotImplemented
if not (a_max is None or np.isscalar(a_max) or isinstance(a_max, (Field, MultiField))):
if not (a_max is None or np.isscalar(a_max) or a_max.jac is None):
return NotImplemented
return self.ptw("clip", a_min, a_max)
......@@ -241,13 +235,10 @@ class Operator(metaclass=NiftyMeta):
return self.apply(x.extract(self.domain))
def _check_input(self, x):
from ..field import Field
from ..multi_field import MultiField
from ..linearization import Linearization
from .scaling_operator import ScalingOperator
if not isinstance(x, (Field, MultiField, Linearization)):
if not (isinstance(x, Operator) and x.val is not None):
raise TypeError
if isinstance(x, Linearization):
if x.jac is not None:
if not isinstance(x.jac, ScalingOperator):
raise ValueError
if x.jac._factor != 1:
......@@ -255,12 +246,11 @@ class Operator(metaclass=NiftyMeta):
self._check_domain_equality(self._domain, x.domain)
def __call__(self, x):
from ..linearization import Linearization
from ..field import Field
from ..multi_field import MultiField
if isinstance(x, Linearization):
if not isinstance(x, Operator):
raise TypeError
if x.jac is not None:
return self.apply(x.trivial_jac()).prepend_jac(x.jac)
elif isinstance(x, (Field, MultiField)):
elif x.val is not None:
return self.apply(x)
return self @ x
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment