Commit 96ea43f4 authored by Martin Reinecke's avatar Martin Reinecke
Browse files

no more abc

parent d0275f10
...@@ -18,8 +18,6 @@ ...@@ -18,8 +18,6 @@
from __future__ import absolute_import, division, print_function from __future__ import absolute_import, division, print_function
import abc
from ..compat import * from ..compat import *
from ..utilities import NiftyMetaBase from ..utilities import NiftyMetaBase
...@@ -31,7 +29,6 @@ class Domain(NiftyMetaBase()): ...@@ -31,7 +29,6 @@ class Domain(NiftyMetaBase()):
def __init__(self): def __init__(self):
self._hash = None self._hash = None
@abc.abstractmethod
def __repr__(self): def __repr__(self):
raise NotImplementedError raise NotImplementedError
...@@ -84,7 +81,7 @@ class Domain(NiftyMetaBase()): ...@@ -84,7 +81,7 @@ class Domain(NiftyMetaBase()):
"""Returns the opposite of :meth:`.__eq__()`""" """Returns the opposite of :meth:`.__eq__()`"""
return not self.__eq__(x) return not self.__eq__(x)
@abc.abstractproperty @property
def shape(self): def shape(self):
"""tuple of int: number of pixels along each axis """tuple of int: number of pixels along each axis
...@@ -103,7 +100,7 @@ class Domain(NiftyMetaBase()): ...@@ -103,7 +100,7 @@ class Domain(NiftyMetaBase()):
from ..dobj import local_shape from ..dobj import local_shape
return local_shape(self.shape) return local_shape(self.shape)
@abc.abstractproperty @property
def size(self): def size(self):
"""int: total number of pixels. """int: total number of pixels.
......
...@@ -18,8 +18,6 @@ ...@@ -18,8 +18,6 @@
from __future__ import absolute_import, division, print_function from __future__ import absolute_import, division, print_function
import abc
import numpy as np import numpy as np
from ..compat import * from ..compat import *
...@@ -34,7 +32,7 @@ class StructuredDomain(Domain): ...@@ -34,7 +32,7 @@ class StructuredDomain(Domain):
are needed for power spectrum analysis and smoothing. are needed for power spectrum analysis and smoothing.
""" """
@abc.abstractproperty @property
def scalar_dvol(self): def scalar_dvol(self):
"""float or None : uniform cell volume, if applicable """float or None : uniform cell volume, if applicable
...@@ -63,7 +61,7 @@ class StructuredDomain(Domain): ...@@ -63,7 +61,7 @@ class StructuredDomain(Domain):
tmp = self.dvol tmp = self.dvol
return self.size * tmp if np.isscalar(tmp) else np.sum(tmp) return self.size * tmp if np.isscalar(tmp) else np.sum(tmp)
@abc.abstractproperty @property
def harmonic(self): def harmonic(self):
"""bool : True iff this domain is a harmonic domain.""" """bool : True iff this domain is a harmonic domain."""
raise NotImplementedError raise NotImplementedError
......
...@@ -18,8 +18,6 @@ ...@@ -18,8 +18,6 @@
from __future__ import absolute_import, division, print_function from __future__ import absolute_import, division, print_function
import abc
from ..compat import * from ..compat import *
from ..logger import logger from ..logger import logger
from .line_search_strong_wolfe import LineSearchStrongWolfe from .line_search_strong_wolfe import LineSearchStrongWolfe
...@@ -108,7 +106,6 @@ class DescentMinimizer(Minimizer): ...@@ -108,7 +106,6 @@ class DescentMinimizer(Minimizer):
def reset(self): def reset(self):
pass pass
@abc.abstractmethod
def get_descent_direction(self, energy): def get_descent_direction(self, energy):
""" Calculates the next descent direction. """ Calculates the next descent direction.
......
...@@ -18,8 +18,6 @@ ...@@ -18,8 +18,6 @@
from __future__ import absolute_import, division, print_function from __future__ import absolute_import, division, print_function
import abc
from ..compat import * from ..compat import *
from ..utilities import NiftyMetaBase from ..utilities import NiftyMetaBase
...@@ -44,7 +42,6 @@ class IterationController(NiftyMetaBase()): ...@@ -44,7 +42,6 @@ class IterationController(NiftyMetaBase()):
CONVERGED, CONTINUE, ERROR = list(range(3)) CONVERGED, CONTINUE, ERROR = list(range(3))
@abc.abstractmethod
def start(self, energy): def start(self, energy):
"""Starts the iteration. """Starts the iteration.
...@@ -59,7 +56,6 @@ class IterationController(NiftyMetaBase()): ...@@ -59,7 +56,6 @@ class IterationController(NiftyMetaBase()):
""" """
raise NotImplementedError raise NotImplementedError
@abc.abstractmethod
def check(self, energy): def check(self, energy):
"""Checks the state of the iteration. Called after every step. """Checks the state of the iteration. Called after every step.
......
...@@ -18,8 +18,6 @@ ...@@ -18,8 +18,6 @@
from __future__ import absolute_import, division, print_function from __future__ import absolute_import, division, print_function
import abc
from ..compat import * from ..compat import *
from ..utilities import NiftyMetaBase from ..utilities import NiftyMetaBase
...@@ -37,7 +35,6 @@ class LineSearch(NiftyMetaBase()): ...@@ -37,7 +35,6 @@ class LineSearch(NiftyMetaBase()):
def __init__(self, preferred_initial_step_size=None): def __init__(self, preferred_initial_step_size=None):
self.preferred_initial_step_size = preferred_initial_step_size self.preferred_initial_step_size = preferred_initial_step_size
@abc.abstractmethod
def perform_line_search(self, energy, pk, f_k_minus_1=None): def perform_line_search(self, energy, pk, f_k_minus_1=None):
"""Find step size and advance. """Find step size and advance.
......
...@@ -18,8 +18,6 @@ ...@@ -18,8 +18,6 @@
from __future__ import absolute_import, division, print_function from __future__ import absolute_import, division, print_function
import abc
from ..compat import * from ..compat import *
from ..utilities import NiftyMetaBase from ..utilities import NiftyMetaBase
...@@ -28,7 +26,6 @@ class Minimizer(NiftyMetaBase()): ...@@ -28,7 +26,6 @@ class Minimizer(NiftyMetaBase()):
""" A base class used by all minimizers.""" """ A base class used by all minimizers."""
# MR FIXME: the docstring is partially ignored by Sphinx. Why? # MR FIXME: the docstring is partially ignored by Sphinx. Why?
@abc.abstractmethod
def __call__(self, energy, preconditioner=None): def __call__(self, energy, preconditioner=None):
""" Performs the minimization of the provided Energy functional. """ Performs the minimization of the provided Energy functional.
......
...@@ -18,8 +18,6 @@ ...@@ -18,8 +18,6 @@
from __future__ import absolute_import, division, print_function from __future__ import absolute_import, division, print_function
import abc
import numpy as np import numpy as np
from ..compat import * from ..compat import *
...@@ -179,7 +177,7 @@ class LinearOperator(Operator): ...@@ -179,7 +177,7 @@ class LinearOperator(Operator):
other = self._toOperator(other, self.domain) other = self._toOperator(other, self.domain)
return SumOperator.make([other, self], [False, True]) return SumOperator.make([other, self], [False, True])
@abc.abstractproperty @property
def capability(self): def capability(self):
"""int : the supported operation modes """int : the supported operation modes
...@@ -189,7 +187,6 @@ class LinearOperator(Operator): ...@@ -189,7 +187,6 @@ class LinearOperator(Operator):
""" """
raise NotImplementedError raise NotImplementedError
@abc.abstractmethod
def apply(self, x, mode): def apply(self, x, mode):
""" Applies the Operator to a given `x`, in a specified `mode`. """ Applies the Operator to a given `x`, in a specified `mode`.
......
from __future__ import absolute_import, division, print_function from __future__ import absolute_import, division, print_function
import abc
from ..compat import * from ..compat import *
from ..utilities import NiftyMetaBase from ..utilities import NiftyMetaBase
...@@ -10,14 +9,14 @@ class Operator(NiftyMetaBase()): ...@@ -10,14 +9,14 @@ class Operator(NiftyMetaBase()):
domain, and can also provide the Jacobian. domain, and can also provide the Jacobian.
""" """
@abc.abstractproperty @property
def domain(self): def domain(self):
"""DomainTuple or MultiDomain : the operator's input domain """DomainTuple or MultiDomain : the operator's input domain
The domain on which the Operator's input Field lives.""" The domain on which the Operator's input Field lives."""
raise NotImplementedError raise NotImplementedError
@abc.abstractproperty @property
def target(self): def target(self):
"""DomainTuple or MultiDomain : the operator's output domain """DomainTuple or MultiDomain : the operator's output domain
...@@ -157,3 +156,46 @@ class _OpSum(_CombinedOperator): ...@@ -157,3 +156,46 @@ class _OpSum(_CombinedOperator):
def __call__(self, x): def __call__(self, x):
raise NotImplementedError raise NotImplementedError
class SquaredNormOperator(Operator):
def __init__(self, domain):
super(SquaredNormOperator, self).__init__()
self._domain = domain
self._target = DomainTuple.scalar_domain()
@property
def domain(self):
return self._domain
@property
def target(self):
return self._target
def __call__(self, x):
return Field(self._target, x.vdot(x))
class QuadraticFormOperator(Operator):
def __init__(self, op):
from .endomorphic_operator import EndomorphicOperator
super(QuadraticFormOperator, self).__init__()
if not isinstance(op, EndomorphicOperator):
raise TypeError("op must be an EndomorphicOperator")
self._op = op
self._target = DomainTuple.scalar_domain()
@property
def domain(self):
return self._op.domain
@property
def target(self):
return self._target
def __call__(self, x):
if isinstance(x, Linearization):
jac = self._op(x)
val = Field(self._target, 0.5 * x.vdot(jac))
return Linearization(val, jac)
return Field(self._target, 0.5 * x.vdot(self._op(x)))
...@@ -39,6 +39,12 @@ class RelaxedSumOperator(LinearOperator): ...@@ -39,6 +39,12 @@ class RelaxedSumOperator(LinearOperator):
self._capability = self.TIMES | self.ADJOINT_TIMES self._capability = self.TIMES | self.ADJOINT_TIMES
for op in ops: for op in ops:
self._capability &= op.capability self._capability &= op.capability
#self._ops = []
#for op in ops:
# if isinstance(op, RelaxedSumOperator):
# self._ops += op._ops
# else:
# self._ops += [op]
@property @property
def domain(self): def domain(self):
......
...@@ -24,6 +24,7 @@ from ..compat import * ...@@ -24,6 +24,7 @@ from ..compat import *
from ..domain_tuple import DomainTuple from ..domain_tuple import DomainTuple
from ..domains.unstructured_domain import UnstructuredDomain from ..domains.unstructured_domain import UnstructuredDomain
from .linear_operator import LinearOperator from .linear_operator import LinearOperator
from .endomorphic_operator import EndomorphicOperator
from ..sugar import full from ..sugar import full
from ..field import Field from ..field import Field
...@@ -76,3 +77,39 @@ class SumReductionOperator(LinearOperator): ...@@ -76,3 +77,39 @@ class SumReductionOperator(LinearOperator):
if mode == self.TIMES: if mode == self.TIMES:
return Field(self._target, x.sum()) return Field(self._target, x.sum())
return full(self._domain, x.local_data[()]) return full(self._domain, x.local_data[()])
class ConjugationOperator(EndomorphicOperator):
def __init__(self, domain):
super(ConjugationOperator, self).__init__()
self._domain = domain
@property
def domain(self):
return self._domain
@property
def capability(self):
return self._all_ops
def apply(self, x, mode):
self._check_input(x, mode)
return x.conjugate()
class Realizer(EndomorphicOperator):
def __init__(self, domain):
super(Realizer, self).__init__()
self._domain = domain
@property
def domain(self):
return self._domain
@property
def capability(self):
return self.TIMES | self.ADJOINT_TIMES
def apply(self, x, mode):
self._check_input(x, mode)
return x.real
...@@ -18,7 +18,6 @@ ...@@ -18,7 +18,6 @@
from __future__ import absolute_import, division, print_function from __future__ import absolute_import, division, print_function
import abc
import collections import collections
from itertools import product from itertools import product
...@@ -178,7 +177,7 @@ class _DocStringInheritor(type): ...@@ -178,7 +177,7 @@ class _DocStringInheritor(type):
bases, clsdict) bases, clsdict)
class NiftyMeta(_DocStringInheritor, abc.ABCMeta): class NiftyMeta(_DocStringInheritor):
pass pass
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment