Commit 12c44bb4 authored by Martin Reinecke's avatar Martin Reinecke
Browse files

... and more cleanups

parent 86d7c515
......@@ -54,8 +54,6 @@ class DOFDistributor(LinearOperator):
"""
def __init__(self, dofdex, target=None, space=None):
super(DOFDistributor, self).__init__()
if target is None:
target = dofdex.domain
self._target = DomainTuple.make(target)
......@@ -98,6 +96,7 @@ class DOFDistributor(LinearOperator):
dom = list(self._target)
dom[self._space] = other_space
self._domain = DomainTuple.make(dom)
self._capability = self.TIMES | self.ADJOINT_TIMES
if dobj.default_distaxis() in self._domain.axes[self._space]:
dofdex = dobj.local_data(dofdex)
......@@ -142,7 +141,3 @@ class DOFDistributor(LinearOperator):
def apply(self, x, mode):
self._check_input(x, mode)
return self._times(x) if mode == self.TIMES else self._adjoint_times(x)
@property
def capability(self):
return self.TIMES | self.ADJOINT_TIMES
......@@ -35,6 +35,7 @@ class DomainDistributor(LinearOperator):
self._domain = [tgt for i, tgt in enumerate(self._target)
if i in self._spaces]
self._domain = DomainTuple.make(self._domain)
self._capability = self.TIMES | self.ADJOINT_TIMES
def apply(self, x, mode):
self._check_input(x, mode)
......@@ -49,7 +50,3 @@ class DomainDistributor(LinearOperator):
else:
return x.sum([s for s in range(len(x.domain))
if s not in self._spaces])
@property
def capability(self):
return self.TIMES | self.ADJOINT_TIMES
......@@ -36,7 +36,6 @@ class EnergyOperator(Operator):
class SquaredNormOperator(EnergyOperator):
def __init__(self, domain):
super(SquaredNormOperator, self).__init__()
self._domain = domain
def apply(self, x):
......@@ -50,7 +49,6 @@ class SquaredNormOperator(EnergyOperator):
class QuadraticFormOperator(EnergyOperator):
def __init__(self, op):
from .endomorphic_operator import EndomorphicOperator
super(QuadraticFormOperator, self).__init__()
if not isinstance(op, EndomorphicOperator):
raise TypeError("op must be an EndomorphicOperator")
self._op = op
......@@ -67,7 +65,6 @@ class QuadraticFormOperator(EnergyOperator):
class GaussianEnergy(EnergyOperator):
def __init__(self, mean=None, covariance=None, domain=None):
super(GaussianEnergy, self).__init__()
self._domain = None
if mean is not None:
self._checkEquivalence(mean.domain)
......@@ -132,7 +129,6 @@ class BernoulliEnergy(EnergyOperator):
class Hamiltonian(EnergyOperator):
def __init__(self, lh, ic_samp=None):
super(Hamiltonian, self).__init__()
self._lh = lh
self._prior = GaussianEnergy(domain=lh.domain)
self._ic_samp = ic_samp
......@@ -156,7 +152,6 @@ class SampledKullbachLeiblerDivergence(EnergyOperator):
h: Hamiltonian
N: Number of samples to be used
"""
super(SampledKullbachLeiblerDivergence, self).__init__()
self._h = h
self._domain = h.domain
self._res_samples = tuple(res_samples)
......
......@@ -33,6 +33,7 @@ from ..utilities import infer_space, special_add_at
class ExpTransform(LinearOperator):
def __init__(self, target, dof, space=0):
self._target = DomainTuple.make(target)
self._capability = self.TIMES | self.ADJOINT_TIMES
self._space = infer_space(self._target, space)
tgt = self._target[self._space]
if not ((isinstance(tgt, RGSpace) and tgt.harmonic) or
......@@ -115,7 +116,3 @@ class ExpTransform(LinearOperator):
if d == ax:
x = dobj.redistribute(x, dist=ax)
return Field(self._tgt(mode), val=x)
@property
def capability(self):
return self.TIMES | self.ADJOINT_TIMES
......@@ -13,7 +13,6 @@ from .. import utilities
class FieldZeroPadder(LinearOperator):
def __init__(self, domain, new_shape, space=0):
super(FieldZeroPadder, self).__init__()
self._domain = DomainTuple.make(domain)
self._space = utilities.infer_space(self._domain, space)
dom = self._domain[self._space]
......@@ -30,10 +29,7 @@ class FieldZeroPadder(LinearOperator):
self._target = list(self._domain)
self._target[self._space] = tgt
self._target = DomainTuple.make(self._target)
@property
def capability(self):
return self.TIMES | self.ADJOINT_TIMES
self._capability = self.TIMES | self.ADJOINT_TIMES
def apply(self, x, mode):
self._check_input(x, mode)
......
......@@ -57,10 +57,9 @@ class FFTOperator(LinearOperator):
"""
def __init__(self, domain, target=None, space=None):
super(FFTOperator, self).__init__()
# Initialize domain and target
self._domain = DomainTuple.make(domain)
self._capability = self._all_ops
self._space = utilities.infer_space(self._domain, space)
adom = self._domain[self._space]
......@@ -127,10 +126,6 @@ class FFTOperator(LinearOperator):
fct *= self._target[self._space].scalar_dvol
return Tval if fct == 1 else Tval*fct
@property
def capability(self):
return self._all_ops
class HartleyOperator(LinearOperator):
"""Transforms between a pair of position and harmonic RGSpaces.
......@@ -162,10 +157,9 @@ class HartleyOperator(LinearOperator):
"""
def __init__(self, domain, target=None, space=None):
super(HartleyOperator, self).__init__()
# Initialize domain and target
self._domain = DomainTuple.make(domain)
self._capability = self._all_ops
self._space = utilities.infer_space(self._domain, space)
adom = self._domain[self._space]
......@@ -239,10 +233,6 @@ class HartleyOperator(LinearOperator):
fct = self._target[self._space].scalar_dvol
return Tval if fct == 1 else Tval*fct
@property
def capability(self):
return self._all_ops
class SHTOperator(LinearOperator):
"""Transforms between a harmonic domain on the sphere and a position
......@@ -272,10 +262,9 @@ class SHTOperator(LinearOperator):
"""
def __init__(self, domain, target=None, space=None):
super(SHTOperator, self).__init__()
# Initialize domain and target
self._domain = DomainTuple.make(domain)
self._capability = self.TIMES | self.ADJOINT_TIMES
self._space = utilities.infer_space(self._domain, space)
hspc = self._domain[self._space]
......@@ -351,10 +340,6 @@ class SHTOperator(LinearOperator):
odat = dobj.redistribute(odat, dist=dobj.distaxis(x.val))
return Field(tdom, odat)
@property
def capability(self):
return self.TIMES | self.ADJOINT_TIMES
class HarmonicTransformOperator(LinearOperator):
"""Transforms between a harmonic domain and a position domain counterpart.
......@@ -384,8 +369,6 @@ class HarmonicTransformOperator(LinearOperator):
"""
def __init__(self, domain, target=None, space=None):
super(HarmonicTransformOperator, self).__init__()
domain = DomainTuple.make(domain)
space = utilities.infer_space(domain, space)
......@@ -399,15 +382,12 @@ class HarmonicTransformOperator(LinearOperator):
self._op = SHTOperator(domain, target, space)
self._domain = self._op.domain
self._target = self._op.target
self._capability = self.TIMES | self.ADJOINT_TIMES
def apply(self, x, mode):
self._check_input(x, mode)
return self._op.apply(x, mode)
@property
def capability(self):
return self.TIMES | self.ADJOINT_TIMES
def HarmonicSmoothingOperator(domain, sigma, space=None):
""" This function returns an operator that carries out a smoothing with
......
......@@ -51,15 +51,11 @@ class InversionEnabler(EndomorphicOperator):
"""
def __init__(self, op, iteration_controller, approximation=None):
super(InversionEnabler, self).__init__()
self._op = op
self._ic = iteration_controller
self._approximation = approximation
self._domain = op.domain
@property
def capability(self):
return self._addInverse[self._op.capability]
self._capability = self._addInverse[self._op.capability]
def apply(self, x, mode):
self._check_mode(mode)
......
......@@ -47,8 +47,8 @@ class LaplaceOperator(EndomorphicOperator):
"""
def __init__(self, domain, space=None, logarithmic=True):
super(LaplaceOperator, self).__init__()
self._domain = DomainTuple.make(domain)
self._capability = self.TIMES | self.ADJOINT_TIMES
self._space = infer_space(self._domain, space)
if not isinstance(self._domain[self._space], PowerSpace):
......@@ -68,10 +68,6 @@ class LaplaceOperator(EndomorphicOperator):
self._dposc[1:] += self._dpos
self._dposc *= 0.5
@property
def capability(self):
return self.TIMES | self.ADJOINT_TIMES
def _times(self, x):
axes = x.domain.axes[self._space]
axis = axes[0]
......
......@@ -145,7 +145,7 @@ class LinearOperator(Operator):
:attr:`INVERSE_TIMES`, and :attr:`ADJOINT_INVERSE_TIMES`,
joined together by the "|" operator.
"""
raise NotImplementedError
return self._capability
def apply(self, x, mode):
""" Applies the Operator to a given `x`, in a specified `mode`.
......
......@@ -38,6 +38,7 @@ class MaskOperator(LinearOperator):
self._domain = DomainTuple.make(mask.domain)
self._mask = np.logical_not(mask.to_global_data())
self._target = DomainTuple.make(UnstructuredDomain(self._mask.sum()))
self._capability = self.TIMES | self.ADJOINT_TIMES
def data_indices(self):
if len(self.domain.shape) == 1:
......@@ -55,7 +56,3 @@ class MaskOperator(LinearOperator):
res[self._mask] = x
res[~self._mask] = 0
return Field.from_global_data(self.domain, res)
@property
def capability(self):
return self.TIMES | self.ADJOINT_TIMES
......@@ -45,17 +45,13 @@ class OperatorAdapter(LinearOperator):
"""
def __init__(self, op, op_transform):
super(OperatorAdapter, self).__init__()
self._op = op
self._trafo = int(op_transform)
if self._trafo < 1 or self._trafo > 3:
raise ValueError("invalid operator transformation")
self._domain = self._op._dom(1 << self._trafo)
self._target = self._op._tgt(1 << self._trafo)
@property
def capability(self):
return self._capTable[self._trafo][self._op.capability]
self._capability = self._capTable[self._trafo][self._op.capability]
def _flip_modes(self, trafo):
newtrafo = trafo ^ self._trafo
......
......@@ -59,6 +59,7 @@ class QHTOperator(LinearOperator):
self._domain[self._space] = \
self._target[self._space].get_default_codomain()
self._domain = DomainTuple.make(self._domain)
self._capability = self.TIMES | self.ADJOINT_TIMES
def apply(self, x, mode):
self._check_input(x, mode)
......@@ -76,7 +77,3 @@ class QHTOperator(LinearOperator):
if i == ax:
x = dobj.redistribute(x, dist=ax)
return Field(self._tgt(mode), val=x)
@property
def capability(self):
return self.TIMES | self.ADJOINT_TIMES
......@@ -32,7 +32,6 @@ class RelaxedSumOperator(LinearOperator):
"""Class representing sums of operators with compatible domains."""
def __init__(self, ops):
super(RelaxedSumOperator, self).__init__()
self._ops = ops
self._domain = domain_union([op.domain for op in ops])
self._target = domain_union([op.target for op in ops])
......@@ -50,10 +49,6 @@ class RelaxedSumOperator(LinearOperator):
def adjoint(self):
return RelaxedSumOperator([op.adjoint for op in self._ops])
@property
def capability(self):
return self._capability
def apply(self, x, mode):
self._check_input(x, mode)
res = None
......
......@@ -50,13 +50,13 @@ class SamplingEnabler(EndomorphicOperator):
def __init__(self, likelihood, prior, iteration_controller,
approximation=None):
super(SamplingEnabler, self).__init__()
self._op = likelihood + prior
self._likelihood = likelihood
self._prior = prior
self._ic = iteration_controller
self._approximation = approximation
self._domain = self._op.domain
self._capability = self._op.capability
def draw_sample(self, from_inverse=False, dtype=np.float64):
try:
......@@ -77,9 +77,5 @@ class SamplingEnabler(EndomorphicOperator):
energy, convergence = inverter(energy)
return energy.position
@property
def capability(self):
return self._op.capability
def apply(self, x, mode):
return self._op.apply(x, mode)
......@@ -34,11 +34,11 @@ class SandwichOperator(EndomorphicOperator):
def __init__(self, bun, cheese, op, _callingfrommake=False):
if not _callingfrommake:
raise NotImplementedError
super(SandwichOperator, self).__init__()
self._bun = bun
self._cheese = cheese
self._op = op
self._domain = op.domain
self._capability = op._capability
@staticmethod
def make(bun, cheese=None):
......@@ -66,10 +66,6 @@ class SandwichOperator(EndomorphicOperator):
return op
return SandwichOperator(bun, cheese, op, _callingfrommake=True)
@property
def capability(self):
return self._op.capability
def apply(self, x, mode):
return self._op.apply(x, mode)
......
......@@ -54,12 +54,12 @@ class ScalingOperator(EndomorphicOperator):
def __init__(self, factor, domain):
from ..sugar import makeDomain
super(ScalingOperator, self).__init__()
if not np.isscalar(factor):
raise TypeError("Scalar required")
self._factor = factor
self._domain = makeDomain(domain)
self._capability = self._all_ops
def apply(self, x, mode):
self._check_input(x, mode)
......@@ -82,10 +82,6 @@ class ScalingOperator(EndomorphicOperator):
fct = 1./fct
return ScalingOperator(fct, self._domain)
@property
def capability(self):
return self._all_ops
def draw_sample(self, from_inverse=False, dtype=np.float64):
fct = self._factor
if fct.imag != 0. or fct.real < 0.:
......
......@@ -33,14 +33,10 @@ from ..multi_field import MultiField
class VdotOperator(LinearOperator):
def __init__(self, field):
super(VdotOperator, self).__init__()
self._field = field
self._domain = field.domain
self._target = DomainTuple.scalar_domain()
@property
def capability(self):
return self.TIMES | self.ADJOINT_TIMES
self._capability = self.TIMES | self.ADJOINT_TIMES
def apply(self, x, mode):
self._check_input(x, mode)
......@@ -51,13 +47,9 @@ class VdotOperator(LinearOperator):
class SumReductionOperator(LinearOperator):
def __init__(self, domain):
super(SumReductionOperator, self).__init__()
self._domain = domain
self._target = DomainTuple.scalar_domain()
@property
def capability(self):
return self.TIMES | self.ADJOINT_TIMES
self._capability = self.TIMES | self.ADJOINT_TIMES
def apply(self, x, mode):
self._check_input(x, mode)
......@@ -68,12 +60,8 @@ class SumReductionOperator(LinearOperator):
class ConjugationOperator(EndomorphicOperator):
def __init__(self, domain):
super(ConjugationOperator, self).__init__()
self._domain = domain
@property
def capability(self):
return self._all_ops
self._capability = self._all_ops
def apply(self, x, mode):
self._check_input(x, mode)
......@@ -82,12 +70,8 @@ class ConjugationOperator(EndomorphicOperator):
class Realizer(EndomorphicOperator):
def __init__(self, domain):
super(Realizer, self).__init__()
self._domain = domain
@property
def capability(self):
return self.TIMES | self.ADJOINT_TIMES
self._capability = self.TIMES | self.ADJOINT_TIMES
def apply(self, x, mode):
self._check_input(x, mode)
......@@ -99,10 +83,7 @@ class FieldAdapter(LinearOperator):
self._domain = MultiDomain.make(dom)
self._name = name_dom
self._target = dom[name_dom]
@property
def capability(self):
return self.TIMES | self.ADJOINT_TIMES
self._capability = self.TIMES | self.ADJOINT_TIMES
def apply(self, x, mode):
self._check_input(x, mode)
......@@ -131,14 +112,10 @@ class GeometryRemover(LinearOperator):
"""
def __init__(self, domain):
super(GeometryRemover, self).__init__()
self._domain = DomainTuple.make(domain)
target_list = [UnstructuredDomain(dom.shape) for dom in self._domain]
self._target = DomainTuple.make(target_list)
@property
def capability(self):
return self.TIMES | self.ADJOINT_TIMES
self._capability = self.TIMES | self.ADJOINT_TIMES
def apply(self, x, mode):
self._check_input(x, mode)
......@@ -162,6 +139,7 @@ class NullOperator(LinearOperator):
from ..sugar import makeDomain
self._domain = makeDomain(domain)
self._target = makeDomain(target)
self._capability = self.TIMES | self.ADJOINT_TIMES
@staticmethod
def _nullfield(dom):
......@@ -176,7 +154,3 @@ class NullOperator(LinearOperator):
if mode == self.TIMES:
return self._nullfield(self._target)
return self._nullfield(self._domain)
@property
def capability(self):
return self.TIMES | self.ADJOINT_TIMES
......@@ -37,6 +37,7 @@ class SlopeOperator(LinearOperator):
self._domain = DomainTuple.make(domain)
self._target = DomainTuple.make(target)
self._capability = self.TIMES | self.ADJOINT_TIMES
if self.domain[0].shape != (len(self.target[0].shape) + 1,):
raise AssertionError("Shape mismatch!")
......@@ -74,7 +75,3 @@ class SlopeOperator(LinearOperator):
for i in range(self.ndim):
res[i] = np.sum(self.pos[i] * xglob) * self._sigmas[i]
return Field.from_global_data(self.domain, res)
@property
def capability(self):
return self.TIMES | self.ADJOINT_TIMES
......@@ -31,7 +31,6 @@ class SumOperator(LinearOperator):
def __init__(self, ops, neg, dom, tgt, _callingfrommake=False):
if not _callingfrommake:
raise NotImplementedError
super(SumOperator, self).__init__()
self._ops = ops
self._neg = neg
self._domain = dom
......@@ -163,10 +162,6 @@ class SumOperator(LinearOperator):
def adjoint(self):
return self.make([op.adjoint for op in self._ops], self._neg)
@property
def capability(self):
return self._capability
def apply(self, x, mode):
self._check_mode(mode)
res = None
......
......@@ -30,6 +30,7 @@ from .. import utilities
class SymmetrizingOperator(EndomorphicOperator):
def __init__(self, domain, space=0):
self._domain = DomainTuple.make(domain)
self._capability = self.TIMES | self.ADJOINT_TIMES
self._space = utilities.infer_space(self._domain, space)
dom = self._domain[self._space]
if not (isinstance(dom, LogRGSpace) and not dom.harmonic):
......@@ -48,7 +49,3 @@ class SymmetrizingOperator(EndomorphicOperator):
if i == ax:
tmp = dobj.redistribute(tmp, dist=ax)
return Field(self.target, val=tmp)
@property
def capability(self):
return self.TIMES | self.ADJOINT_TIMES
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment