Commit 12c44bb4 authored by Martin Reinecke's avatar Martin Reinecke
Browse files

... and more cleanups

parent 86d7c515
...@@ -38,13 +38,13 @@ class PolynomialResponse(ift.LinearOperator): ...@@ -38,13 +38,13 @@ class PolynomialResponse(ift.LinearOperator):
""" """
def __init__(self, domain, sampling_points): def __init__(self, domain, sampling_points):
super(PolynomialResponse, self).__init__()
if not (isinstance(domain, ift.UnstructuredDomain) if not (isinstance(domain, ift.UnstructuredDomain)
and isinstance(x, np.ndarray)): and isinstance(x, np.ndarray)):
raise TypeError raise TypeError
self._domain = ift.DomainTuple.make(domain) self._domain = ift.DomainTuple.make(domain)
tgt = ift.UnstructuredDomain(sampling_points.shape) tgt = ift.UnstructuredDomain(sampling_points.shape)
self._target = ift.DomainTuple.make(tgt) self._target = ift.DomainTuple.make(tgt)
self._capability = self.TIMES | self.ADJOINT_TIMES
sh = (self.target.size, domain.size) sh = (self.target.size, domain.size)
self._mat = np.empty(sh) self._mat = np.empty(sh)
...@@ -62,10 +62,6 @@ class PolynomialResponse(ift.LinearOperator): ...@@ -62,10 +62,6 @@ class PolynomialResponse(ift.LinearOperator):
out = self._mat.conj().T.dot(val) out = self._mat.conj().T.dot(val)
return ift.from_global_data(self._tgt(mode), out) return ift.from_global_data(self._tgt(mode), out)
@property
def capability(self):
return self.TIMES | self.ADJOINT_TIMES
# Generate some mock data # Generate some mock data
N_params = 10 N_params = 10
......
...@@ -40,7 +40,6 @@ class DOFSpace(StructuredDomain): ...@@ -40,7 +40,6 @@ class DOFSpace(StructuredDomain):
_needed_for_hash = ["_dvol"] _needed_for_hash = ["_dvol"]
def __init__(self, dof_weights): def __init__(self, dof_weights):
super(DOFSpace, self).__init__()
self._dvol = tuple(dof_weights) self._dvol = tuple(dof_weights)
@property @property
......
...@@ -25,10 +25,6 @@ from ..utilities import NiftyMetaBase ...@@ -25,10 +25,6 @@ from ..utilities import NiftyMetaBase
class Domain(NiftyMetaBase()): class Domain(NiftyMetaBase()):
"""The abstract class repesenting a (structured or unstructured) domain. """The abstract class repesenting a (structured or unstructured) domain.
""" """
def __init__(self):
self._hash = None
def __repr__(self): def __repr__(self):
raise NotImplementedError raise NotImplementedError
...@@ -40,7 +36,9 @@ class Domain(NiftyMetaBase()): ...@@ -40,7 +36,9 @@ class Domain(NiftyMetaBase()):
Only members that are explicitly added to Only members that are explicitly added to
:attr:`._needed_for_hash` will be used for hashing. :attr:`._needed_for_hash` will be used for hashing.
""" """
if self._hash is None: try:
return self._hash
except AttributeError:
h = 0 h = 0
for key in self._needed_for_hash: for key in self._needed_for_hash:
h ^= hash(vars(self)[key]) h ^= hash(vars(self)[key])
......
...@@ -43,8 +43,6 @@ class GLSpace(StructuredDomain): ...@@ -43,8 +43,6 @@ class GLSpace(StructuredDomain):
_needed_for_hash = ["_nlat", "_nlon"] _needed_for_hash = ["_nlat", "_nlon"]
def __init__(self, nlat, nlon=None): def __init__(self, nlat, nlon=None):
super(GLSpace, self).__init__()
self._nlat = int(nlat) self._nlat = int(nlat)
if self._nlat < 1: if self._nlat < 1:
raise ValueError("nlat must be a positive number.") raise ValueError("nlat must be a positive number.")
......
...@@ -40,7 +40,6 @@ class HPSpace(StructuredDomain): ...@@ -40,7 +40,6 @@ class HPSpace(StructuredDomain):
_needed_for_hash = ["_nside"] _needed_for_hash = ["_nside"]
def __init__(self, nside): def __init__(self, nside):
super(HPSpace, self).__init__()
self._nside = int(nside) self._nside = int(nside)
if self._nside < 1: if self._nside < 1:
raise ValueError("nside must be >=1.") raise ValueError("nside must be >=1.")
......
...@@ -48,7 +48,6 @@ class LMSpace(StructuredDomain): ...@@ -48,7 +48,6 @@ class LMSpace(StructuredDomain):
_needed_for_hash = ["_lmax", "_mmax"] _needed_for_hash = ["_lmax", "_mmax"]
def __init__(self, lmax, mmax=None): def __init__(self, lmax, mmax=None):
super(LMSpace, self).__init__()
self._lmax = np.int(lmax) self._lmax = np.int(lmax)
if self._lmax < 0: if self._lmax < 0:
raise ValueError("lmax must be >=0.") raise ValueError("lmax must be >=0.")
......
...@@ -32,8 +32,6 @@ class LogRGSpace(StructuredDomain): ...@@ -32,8 +32,6 @@ class LogRGSpace(StructuredDomain):
_needed_for_hash = ['_shape', '_bindistances', '_t_0', '_harmonic'] _needed_for_hash = ['_shape', '_bindistances', '_t_0', '_harmonic']
def __init__(self, shape, bindistances, t_0, harmonic=False): def __init__(self, shape, bindistances, t_0, harmonic=False):
super(LogRGSpace, self).__init__()
self._harmonic = bool(harmonic) self._harmonic = bool(harmonic)
if np.isscalar(shape): if np.isscalar(shape):
......
...@@ -158,8 +158,6 @@ class PowerSpace(StructuredDomain): ...@@ -158,8 +158,6 @@ class PowerSpace(StructuredDomain):
return PowerSpace.linear_binbounds(nbin, lbound, rbound) return PowerSpace.linear_binbounds(nbin, lbound, rbound)
def __init__(self, harmonic_partner, binbounds=None): def __init__(self, harmonic_partner, binbounds=None):
super(PowerSpace, self).__init__()
if not (isinstance(harmonic_partner, StructuredDomain) and if not (isinstance(harmonic_partner, StructuredDomain) and
harmonic_partner.harmonic): harmonic_partner.harmonic):
raise ValueError("harmonic_partner must be a harmonic space.") raise ValueError("harmonic_partner must be a harmonic space.")
......
...@@ -51,8 +51,6 @@ class RGSpace(StructuredDomain): ...@@ -51,8 +51,6 @@ class RGSpace(StructuredDomain):
_needed_for_hash = ["_distances", "_shape", "_harmonic"] _needed_for_hash = ["_distances", "_shape", "_harmonic"]
def __init__(self, shape, distances=None, harmonic=False): def __init__(self, shape, distances=None, harmonic=False):
super(RGSpace, self).__init__()
self._harmonic = bool(harmonic) self._harmonic = bool(harmonic)
if np.isscalar(shape): if np.isscalar(shape):
shape = (shape,) shape = (shape,)
......
...@@ -38,7 +38,6 @@ class UnstructuredDomain(Domain): ...@@ -38,7 +38,6 @@ class UnstructuredDomain(Domain):
_needed_for_hash = ["_shape"] _needed_for_hash = ["_shape"]
def __init__(self, shape): def __init__(self, shape):
super(UnstructuredDomain, self).__init__()
try: try:
self._shape = tuple([int(i) for i in shape]) self._shape = tuple([int(i) for i in shape])
except TypeError: except TypeError:
......
...@@ -134,9 +134,8 @@ class LOSResponse(LinearOperator): ...@@ -134,9 +134,8 @@ class LOSResponse(LinearOperator):
""" """
def __init__(self, domain, starts, ends, sigmas_low=None, sigmas_up=None): def __init__(self, domain, starts, ends, sigmas_low=None, sigmas_up=None):
super(LOSResponse, self).__init__()
self._domain = DomainTuple.make(domain) self._domain = DomainTuple.make(domain)
self._capability = self.TIMES | self.ADJOINT_TIMES
if ((not isinstance(self.domain[0], RGSpace)) or if ((not isinstance(self.domain[0], RGSpace)) or
(len(self._domain) != 1)): (len(self._domain) != 1)):
...@@ -221,10 +220,6 @@ class LOSResponse(LinearOperator): ...@@ -221,10 +220,6 @@ class LOSResponse(LinearOperator):
self._target = DomainTuple.make(UnstructuredDomain(nlos)) self._target = DomainTuple.make(UnstructuredDomain(nlos))
@property
def capability(self):
return self.TIMES | self.ADJOINT_TIMES
def apply(self, x, mode): def apply(self, x, mode):
self._check_input(x, mode) self._check_input(x, mode)
if mode == self.TIMES: if mode == self.TIMES:
......
...@@ -42,7 +42,6 @@ class DescentMinimizer(Minimizer): ...@@ -42,7 +42,6 @@ class DescentMinimizer(Minimizer):
""" """
def __init__(self, controller, line_searcher=LineSearchStrongWolfe()): def __init__(self, controller, line_searcher=LineSearchStrongWolfe()):
super(DescentMinimizer, self).__init__()
self._controller = controller self._controller = controller
self.line_searcher = line_searcher self.line_searcher = line_searcher
......
...@@ -47,7 +47,6 @@ class Energy(NiftyMetaBase()): ...@@ -47,7 +47,6 @@ class Energy(NiftyMetaBase()):
""" """
def __init__(self, position): def __init__(self, position):
super(Energy, self).__init__()
self._position = position self._position = position
self._gradnorm = None self._gradnorm = None
......
...@@ -47,7 +47,6 @@ class GradientNormController(IterationController): ...@@ -47,7 +47,6 @@ class GradientNormController(IterationController):
def __init__(self, tol_abs_gradnorm=None, tol_rel_gradnorm=None, def __init__(self, tol_abs_gradnorm=None, tol_rel_gradnorm=None,
convergence_level=1, iteration_limit=None, name=None): convergence_level=1, iteration_limit=None, name=None):
super(GradientNormController, self).__init__()
self._tol_abs_gradnorm = tol_abs_gradnorm self._tol_abs_gradnorm = tol_abs_gradnorm
self._tol_rel_gradnorm = tol_rel_gradnorm self._tol_rel_gradnorm = tol_rel_gradnorm
self._convergence_level = convergence_level self._convergence_level = convergence_level
......
...@@ -50,7 +50,6 @@ class LineEnergy(object): ...@@ -50,7 +50,6 @@ class LineEnergy(object):
""" """
def __init__(self, line_position, energy, line_direction, offset=0.): def __init__(self, line_position, energy, line_direction, offset=0.):
super(LineEnergy, self).__init__()
self._line_position = float(line_position) self._line_position = float(line_position)
self._line_direction = line_direction self._line_direction = line_direction
......
...@@ -74,7 +74,6 @@ class ScipyMinimizer(Minimizer): ...@@ -74,7 +74,6 @@ class ScipyMinimizer(Minimizer):
""" """
def __init__(self, method, options, need_hessp, bounds): def __init__(self, method, options, need_hessp, bounds):
super(ScipyMinimizer, self).__init__()
if not dobj.is_numpy(): if not dobj.is_numpy():
raise NotImplementedError raise NotImplementedError
self._method = method self._method = method
...@@ -130,7 +129,6 @@ def L_BFGS_B(ftol, gtol, maxiter, maxcor=10, disp=False, bounds=None): ...@@ -130,7 +129,6 @@ def L_BFGS_B(ftol, gtol, maxiter, maxcor=10, disp=False, bounds=None):
class ScipyCG(Minimizer): class ScipyCG(Minimizer):
def __init__(self, tol, maxiter): def __init__(self, tol, maxiter):
super(ScipyCG, self).__init__()
if not dobj.is_numpy(): if not dobj.is_numpy():
raise NotImplementedError raise NotImplementedError
self._tol = tol self._tol = tol
......
...@@ -33,21 +33,16 @@ class BlockDiagonalOperator(EndomorphicOperator): ...@@ -33,21 +33,16 @@ class BlockDiagonalOperator(EndomorphicOperator):
dictionary with operators domain names as keys and dictionary with operators domain names as keys and
LinearOperators as items LinearOperators as items
""" """
super(BlockDiagonalOperator, self).__init__()
if not isinstance(domain, MultiDomain): if not isinstance(domain, MultiDomain):
raise TypeError("MultiDomain expected") raise TypeError("MultiDomain expected")
if not isinstance(operators, tuple): if not isinstance(operators, tuple):
raise TypeError("tuple expected") raise TypeError("tuple expected")
self._domain = domain self._domain = domain
self._ops = operators self._ops = operators
self._cap = self._all_ops self._capability = self._all_ops
for op in self._ops: for op in self._ops:
if op is not None: if op is not None:
self._cap &= op.capability self._capability &= op.capability
@property
def capability(self):
return self._cap
def apply(self, x, mode): def apply(self, x, mode):
self._check_input(x, mode) self._check_input(x, mode)
......
...@@ -16,8 +16,6 @@ from .. import dobj ...@@ -16,8 +16,6 @@ from .. import dobj
# highest frequency. # highest frequency.
class CentralZeroPadder(LinearOperator): class CentralZeroPadder(LinearOperator):
def __init__(self, domain, new_shape, space=0): def __init__(self, domain, new_shape, space=0):
super(CentralZeroPadder, self).__init__()
self._domain = DomainTuple.make(domain) self._domain = DomainTuple.make(domain)
self._space = utilities.infer_space(self._domain, space) self._space = utilities.infer_space(self._domain, space)
dom = self._domain[self._space] dom = self._domain[self._space]
...@@ -35,6 +33,7 @@ class CentralZeroPadder(LinearOperator): ...@@ -35,6 +33,7 @@ class CentralZeroPadder(LinearOperator):
self._target = list(self._domain) self._target = list(self._domain)
self._target[self._space] = tgt self._target[self._space] = tgt
self._target = DomainTuple.make(self._target) self._target = DomainTuple.make(self._target)
self._capability = self.TIMES | self.ADJOINT_TIMES
slicer = [] slicer = []
axes = self._target.axes[self._space] axes = self._target.axes[self._space]
...@@ -53,10 +52,6 @@ class CentralZeroPadder(LinearOperator): ...@@ -53,10 +52,6 @@ class CentralZeroPadder(LinearOperator):
self.slicer[i] = tuple(tmp) self.slicer[i] = tuple(tmp)
self.slicer = tuple(self.slicer) self.slicer = tuple(self.slicer)
@property
def capability(self):
return self.TIMES | self.ADJOINT_TIMES
def apply(self, x, mode): def apply(self, x, mode):
self._check_input(x, mode) self._check_input(x, mode)
x = x.val x = x.val
......
...@@ -31,7 +31,6 @@ class ChainOperator(LinearOperator): ...@@ -31,7 +31,6 @@ class ChainOperator(LinearOperator):
def __init__(self, ops, _callingfrommake=False): def __init__(self, ops, _callingfrommake=False):
if not _callingfrommake: if not _callingfrommake:
raise NotImplementedError raise NotImplementedError
super(ChainOperator, self).__init__()
self._ops = ops self._ops = ops
self._capability = self._all_ops self._capability = self._all_ops
for op in ops: for op in ops:
...@@ -132,10 +131,6 @@ class ChainOperator(LinearOperator): ...@@ -132,10 +131,6 @@ class ChainOperator(LinearOperator):
return self.make([op._flip_modes(trafo) for op in self._ops]) return self.make([op._flip_modes(trafo) for op in self._ops])
raise ValueError("invalid operator transformation") raise ValueError("invalid operator transformation")
@property
def capability(self):
return self._capability
def apply(self, x, mode): def apply(self, x, mode):
self._check_mode(mode) self._check_mode(mode)
t_ops = self._ops if mode & self._backwards else reversed(self._ops) t_ops = self._ops if mode & self._backwards else reversed(self._ops)
......
...@@ -57,8 +57,6 @@ class DiagonalOperator(EndomorphicOperator): ...@@ -57,8 +57,6 @@ class DiagonalOperator(EndomorphicOperator):
""" """
def __init__(self, diagonal, domain=None, spaces=None): def __init__(self, diagonal, domain=None, spaces=None):
super(DiagonalOperator, self).__init__()
if not isinstance(diagonal, Field): if not isinstance(diagonal, Field):
raise TypeError("Field object required") raise TypeError("Field object required")
if domain is None: if domain is None:
...@@ -99,6 +97,7 @@ class DiagonalOperator(EndomorphicOperator): ...@@ -99,6 +97,7 @@ class DiagonalOperator(EndomorphicOperator):
def _fill_rest(self): def _fill_rest(self):
self._ldiag.flags.writeable = False self._ldiag.flags.writeable = False
self._complex = utilities.iscomplextype(self._ldiag.dtype) self._complex = utilities.iscomplextype(self._ldiag.dtype)
self._capability = self._all_ops
if not self._complex: if not self._complex:
lmin = self._ldiag.min() if self._ldiag.size > 0 else 1. lmin = self._ldiag.min() if self._ldiag.size > 0 else 1.
self._diagmin = dobj.np_allreduce_min(np.array(lmin))[()] self._diagmin = dobj.np_allreduce_min(np.array(lmin))[()]
...@@ -150,10 +149,6 @@ class DiagonalOperator(EndomorphicOperator): ...@@ -150,10 +149,6 @@ class DiagonalOperator(EndomorphicOperator):
return Field.from_local_data(x.domain, x.local_data*xdiag) return Field.from_local_data(x.domain, x.local_data*xdiag)
return Field.from_local_data(x.domain, x.local_data/xdiag) return Field.from_local_data(x.domain, x.local_data/xdiag)
@property
def capability(self):
return self._all_ops
def _flip_modes(self, trafo): def _flip_modes(self, trafo):
xdiag = self._ldiag xdiag = self._ldiag
if self._complex and (trafo & self.ADJOINT_BIT): if self._complex and (trafo & self.ADJOINT_BIT):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment