Commit 12c44bb4 authored by Martin Reinecke's avatar Martin Reinecke
Browse files

... and more cleanups

parent 86d7c515
......@@ -38,13 +38,13 @@ class PolynomialResponse(ift.LinearOperator):
"""
def __init__(self, domain, sampling_points):
super(PolynomialResponse, self).__init__()
if not (isinstance(domain, ift.UnstructuredDomain)
and isinstance(x, np.ndarray)):
raise TypeError
self._domain = ift.DomainTuple.make(domain)
tgt = ift.UnstructuredDomain(sampling_points.shape)
self._target = ift.DomainTuple.make(tgt)
self._capability = self.TIMES | self.ADJOINT_TIMES
sh = (self.target.size, domain.size)
self._mat = np.empty(sh)
......@@ -62,10 +62,6 @@ class PolynomialResponse(ift.LinearOperator):
out = self._mat.conj().T.dot(val)
return ift.from_global_data(self._tgt(mode), out)
@property
def capability(self):
return self.TIMES | self.ADJOINT_TIMES
# Generate some mock data
N_params = 10
......
......@@ -40,7 +40,6 @@ class DOFSpace(StructuredDomain):
_needed_for_hash = ["_dvol"]
def __init__(self, dof_weights):
super(DOFSpace, self).__init__()
self._dvol = tuple(dof_weights)
@property
......
......@@ -25,10 +25,6 @@ from ..utilities import NiftyMetaBase
class Domain(NiftyMetaBase()):
"""The abstract class repesenting a (structured or unstructured) domain.
"""
def __init__(self):
self._hash = None
def __repr__(self):
raise NotImplementedError
......@@ -40,7 +36,9 @@ class Domain(NiftyMetaBase()):
Only members that are explicitly added to
:attr:`._needed_for_hash` will be used for hashing.
"""
if self._hash is None:
try:
return self._hash
except AttributeError:
h = 0
for key in self._needed_for_hash:
h ^= hash(vars(self)[key])
......
......@@ -43,8 +43,6 @@ class GLSpace(StructuredDomain):
_needed_for_hash = ["_nlat", "_nlon"]
def __init__(self, nlat, nlon=None):
super(GLSpace, self).__init__()
self._nlat = int(nlat)
if self._nlat < 1:
raise ValueError("nlat must be a positive number.")
......
......@@ -40,7 +40,6 @@ class HPSpace(StructuredDomain):
_needed_for_hash = ["_nside"]
def __init__(self, nside):
super(HPSpace, self).__init__()
self._nside = int(nside)
if self._nside < 1:
raise ValueError("nside must be >=1.")
......
......@@ -48,7 +48,6 @@ class LMSpace(StructuredDomain):
_needed_for_hash = ["_lmax", "_mmax"]
def __init__(self, lmax, mmax=None):
super(LMSpace, self).__init__()
self._lmax = np.int(lmax)
if self._lmax < 0:
raise ValueError("lmax must be >=0.")
......
......@@ -32,8 +32,6 @@ class LogRGSpace(StructuredDomain):
_needed_for_hash = ['_shape', '_bindistances', '_t_0', '_harmonic']
def __init__(self, shape, bindistances, t_0, harmonic=False):
super(LogRGSpace, self).__init__()
self._harmonic = bool(harmonic)
if np.isscalar(shape):
......
......@@ -158,8 +158,6 @@ class PowerSpace(StructuredDomain):
return PowerSpace.linear_binbounds(nbin, lbound, rbound)
def __init__(self, harmonic_partner, binbounds=None):
super(PowerSpace, self).__init__()
if not (isinstance(harmonic_partner, StructuredDomain) and
harmonic_partner.harmonic):
raise ValueError("harmonic_partner must be a harmonic space.")
......
......@@ -51,8 +51,6 @@ class RGSpace(StructuredDomain):
_needed_for_hash = ["_distances", "_shape", "_harmonic"]
def __init__(self, shape, distances=None, harmonic=False):
super(RGSpace, self).__init__()
self._harmonic = bool(harmonic)
if np.isscalar(shape):
shape = (shape,)
......
......@@ -38,7 +38,6 @@ class UnstructuredDomain(Domain):
_needed_for_hash = ["_shape"]
def __init__(self, shape):
super(UnstructuredDomain, self).__init__()
try:
self._shape = tuple([int(i) for i in shape])
except TypeError:
......
......@@ -134,9 +134,8 @@ class LOSResponse(LinearOperator):
"""
def __init__(self, domain, starts, ends, sigmas_low=None, sigmas_up=None):
super(LOSResponse, self).__init__()
self._domain = DomainTuple.make(domain)
self._capability = self.TIMES | self.ADJOINT_TIMES
if ((not isinstance(self.domain[0], RGSpace)) or
(len(self._domain) != 1)):
......@@ -221,10 +220,6 @@ class LOSResponse(LinearOperator):
self._target = DomainTuple.make(UnstructuredDomain(nlos))
@property
def capability(self):
return self.TIMES | self.ADJOINT_TIMES
def apply(self, x, mode):
self._check_input(x, mode)
if mode == self.TIMES:
......
......@@ -42,7 +42,6 @@ class DescentMinimizer(Minimizer):
"""
def __init__(self, controller, line_searcher=LineSearchStrongWolfe()):
super(DescentMinimizer, self).__init__()
self._controller = controller
self.line_searcher = line_searcher
......
......@@ -47,7 +47,6 @@ class Energy(NiftyMetaBase()):
"""
def __init__(self, position):
super(Energy, self).__init__()
self._position = position
self._gradnorm = None
......
......@@ -47,7 +47,6 @@ class GradientNormController(IterationController):
def __init__(self, tol_abs_gradnorm=None, tol_rel_gradnorm=None,
convergence_level=1, iteration_limit=None, name=None):
super(GradientNormController, self).__init__()
self._tol_abs_gradnorm = tol_abs_gradnorm
self._tol_rel_gradnorm = tol_rel_gradnorm
self._convergence_level = convergence_level
......
......@@ -50,7 +50,6 @@ class LineEnergy(object):
"""
def __init__(self, line_position, energy, line_direction, offset=0.):
super(LineEnergy, self).__init__()
self._line_position = float(line_position)
self._line_direction = line_direction
......
......@@ -74,7 +74,6 @@ class ScipyMinimizer(Minimizer):
"""
def __init__(self, method, options, need_hessp, bounds):
super(ScipyMinimizer, self).__init__()
if not dobj.is_numpy():
raise NotImplementedError
self._method = method
......@@ -130,7 +129,6 @@ def L_BFGS_B(ftol, gtol, maxiter, maxcor=10, disp=False, bounds=None):
class ScipyCG(Minimizer):
def __init__(self, tol, maxiter):
super(ScipyCG, self).__init__()
if not dobj.is_numpy():
raise NotImplementedError
self._tol = tol
......
......@@ -33,21 +33,16 @@ class BlockDiagonalOperator(EndomorphicOperator):
dictionary with operators domain names as keys and
LinearOperators as items
"""
super(BlockDiagonalOperator, self).__init__()
if not isinstance(domain, MultiDomain):
raise TypeError("MultiDomain expected")
if not isinstance(operators, tuple):
raise TypeError("tuple expected")
self._domain = domain
self._ops = operators
self._cap = self._all_ops
self._capability = self._all_ops
for op in self._ops:
if op is not None:
self._cap &= op.capability
@property
def capability(self):
return self._cap
self._capability &= op.capability
def apply(self, x, mode):
self._check_input(x, mode)
......
......@@ -16,8 +16,6 @@ from .. import dobj
# highest frequency.
class CentralZeroPadder(LinearOperator):
def __init__(self, domain, new_shape, space=0):
super(CentralZeroPadder, self).__init__()
self._domain = DomainTuple.make(domain)
self._space = utilities.infer_space(self._domain, space)
dom = self._domain[self._space]
......@@ -35,6 +33,7 @@ class CentralZeroPadder(LinearOperator):
self._target = list(self._domain)
self._target[self._space] = tgt
self._target = DomainTuple.make(self._target)
self._capability = self.TIMES | self.ADJOINT_TIMES
slicer = []
axes = self._target.axes[self._space]
......@@ -53,10 +52,6 @@ class CentralZeroPadder(LinearOperator):
self.slicer[i] = tuple(tmp)
self.slicer = tuple(self.slicer)
@property
def capability(self):
return self.TIMES | self.ADJOINT_TIMES
def apply(self, x, mode):
self._check_input(x, mode)
x = x.val
......
......@@ -31,7 +31,6 @@ class ChainOperator(LinearOperator):
def __init__(self, ops, _callingfrommake=False):
if not _callingfrommake:
raise NotImplementedError
super(ChainOperator, self).__init__()
self._ops = ops
self._capability = self._all_ops
for op in ops:
......@@ -132,10 +131,6 @@ class ChainOperator(LinearOperator):
return self.make([op._flip_modes(trafo) for op in self._ops])
raise ValueError("invalid operator transformation")
@property
def capability(self):
return self._capability
def apply(self, x, mode):
self._check_mode(mode)
t_ops = self._ops if mode & self._backwards else reversed(self._ops)
......
......@@ -57,8 +57,6 @@ class DiagonalOperator(EndomorphicOperator):
"""
def __init__(self, diagonal, domain=None, spaces=None):
super(DiagonalOperator, self).__init__()
if not isinstance(diagonal, Field):
raise TypeError("Field object required")
if domain is None:
......@@ -99,6 +97,7 @@ class DiagonalOperator(EndomorphicOperator):
def _fill_rest(self):
self._ldiag.flags.writeable = False
self._complex = utilities.iscomplextype(self._ldiag.dtype)
self._capability = self._all_ops
if not self._complex:
lmin = self._ldiag.min() if self._ldiag.size > 0 else 1.
self._diagmin = dobj.np_allreduce_min(np.array(lmin))[()]
......@@ -150,10 +149,6 @@ class DiagonalOperator(EndomorphicOperator):
return Field.from_local_data(x.domain, x.local_data*xdiag)
return Field.from_local_data(x.domain, x.local_data/xdiag)
@property
def capability(self):
return self._all_ops
def _flip_modes(self, trafo):
xdiag = self._ldiag
if self._complex and (trafo & self.ADJOINT_BIT):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment