Commit 664d97fb authored by Theo Steininger's avatar Theo Steininger

Added copy method to Operators.

parent baa71d88
Pipeline #17403 passed with stages
in 25 minutes and 18 seconds
......@@ -33,6 +33,23 @@ class CriticalPowerCurvature(InvertibleOperatorMixin, EndomorphicOperator):
preconditioner=preconditioner,
**kwargs)
def _add_attributes_to_copy(self, copy, **kwargs):
copy._domain = self._domain
if 'theta' in kwargs:
theta = kwargs['theta']
copy.theta = DiagonalOperator(theta.domain, diagonal=theta)
else:
copy.theta = self.theta.copy()
if 'T' in kwargs:
copy.T = kwargs['T']
else:
copy.T = self.T
copy = super(CriticalPowerCurvature,
self)._add_attributes_to_copy(copy, **kwargs)
return copy
def _times(self, x, spaces):
return self.T(x) + self.theta(x)
......
......@@ -54,7 +54,8 @@ class CriticalPowerEnergy(Energy):
# ---Overwritten properties and methods---
def __init__(self, position, m, D=None, alpha=1.0, q=0.,
smoothness_prior=0., logarithmic=True, samples=3, w=None):
smoothness_prior=0., logarithmic=True, samples=3, w=None,
old_curvature=None):
super(CriticalPowerEnergy, self).__init__(position=position)
self.m = m
self.D = D
......@@ -66,6 +67,8 @@ class CriticalPowerEnergy(Energy):
logarithmic=logarithmic)
self.rho = self.position.domain[0].rho
self._w = w if w is not None else None
self._old_curvature = old_curvature
self._curvature = None
# ---Mandatory properties and methods---
......@@ -73,9 +76,11 @@ class CriticalPowerEnergy(Energy):
return self.__class__(position, self.m, D=self.D, alpha=self.alpha,
q=self.q, smoothness_prior=self.smoothness_prior,
logarithmic=self.logarithmic,
w=self.w, samples=self.samples)
w=self.w, samples=self.samples,
old_curvature=self._curvature)
@property
@memo
def value(self):
energy = self._theta.sum()
energy += self.position.vdot(self._rho_prime, bare=True)
......@@ -83,6 +88,7 @@ class CriticalPowerEnergy(Energy):
return energy.real
@property
@memo
def gradient(self):
gradient = -self._theta.weight(-1)
gradient += (self._rho_prime).weight(-1)
......@@ -92,9 +98,14 @@ class CriticalPowerEnergy(Energy):
@property
def curvature(self):
curvature = CriticalPowerCurvature(theta=self._theta.weight(-1),
T=self.T)
return curvature
if self._curvature is None:
if self._old_curvature is None:
self._curvature = CriticalPowerCurvature(
theta=self._theta.weight(-1), T=self.T)
else:
self._curvature = self._old_curvature.copy(
theta=self._theta.weight(-1), T=self.T)
return self._curvature
# ---Added properties and methods---
......
......@@ -48,6 +48,23 @@ class LogNormalWienerFilterCurvature(InvertibleOperatorMixin,
preconditioner=preconditioner,
**kwargs)
def _add_attributes_to_copy(self, copy, **kwargs):
copy._cache = {}
copy._domain = self._domain
copy.R = self.R.copy()
copy.N = self.N.copy()
copy.S = self.S.copy()
copy.d = self.d.copy()
if 'position' in kwargs:
copy.position = kwargs['position']
else:
copy.position = self.position.copy()
copy._fft = self._fft
copy = super(LogNormalWienerFilterCurvature,
self)._add_attributes_to_copy(copy, **kwargs)
return copy
@property
def domain(self):
return self._domain
......
......@@ -24,7 +24,7 @@ class LogNormalWienerFilterEnergy(Energy):
The prior signal covariance in harmonic space.
"""
def __init__(self, position, d, R, N, S, fft4exp=None):
def __init__(self, position, d, R, N, S, fft4exp=None, old_curvature=None):
super(LogNormalWienerFilterEnergy, self).__init__(position=position)
self.d = d
self.R = R
......@@ -37,9 +37,13 @@ class LogNormalWienerFilterEnergy(Energy):
else:
self._fft = fft4exp
self._old_curvature = old_curvature
self._curvature = None
def at(self, position):
return self.__class__(position=position, d=self.d, R=self.R, N=self.N,
S=self.S, fft4exp=self._fft)
S=self.S, fft4exp=self._fft,
old_curvature=self._curvature)
@property
@memo
......@@ -53,11 +57,20 @@ class LogNormalWienerFilterEnergy(Energy):
return self._Sp + self._exppRNRexppd
@property
@memo
def curvature(self):
return LogNormalWienerFilterCurvature(R=self.R, N=self.N, S=self.S,
d=self.d, position=self.position,
fft4exp=self._fft)
if self._curvature is None:
if self._old_curvature is None:
self._curvature = LogNormalWienerFilterCurvature(
R=self.R,
N=self.N,
S=self.S,
d=self.d,
position=self.position,
fft4exp=self._fft)
else:
self._curvature = \
self._old_curvature.copy(position=self.position)
return self._curvature
@property
def _expp(self):
......
......@@ -35,6 +35,15 @@ class WienerFilterCurvature(InvertibleOperatorMixin, EndomorphicOperator):
preconditioner=preconditioner,
**kwargs)
def _add_attributes_to_copy(self, copy, **kwargs):
copy._domain = self._domain
copy.R = self.R.copy()
copy.N = self.N.copy()
copy.S = self.S.copy()
copy = super(WienerFilterCurvature, self)._add_attributes_to_copy(
copy, **kwargs)
return copy
@property
def domain(self):
return self._domain
......
......@@ -23,16 +23,17 @@ class WienerFilterEnergy(Energy):
The prior signal covariance in harmonic space.
"""
def __init__(self, position, d, R, N, S):
def __init__(self, position, d, R, N, S, old_curvature=None):
super(WienerFilterEnergy, self).__init__(position=position)
self.d = d
self.R = R
self.N = N
self.S = S
self._curvature = old_curvature
def at(self, position):
return self.__class__(position=position, d=self.d, R=self.R, N=self.N,
S=self.S)
S=self.S, old_curvature=self.curvature)
@property
@memo
......@@ -45,9 +46,12 @@ class WienerFilterEnergy(Energy):
return self._Dx - self._j
@property
@memo
def curvature(self):
return WienerFilterCurvature(R=self.R, N=self.N, S=self.S)
if self._curvature is None:
self._curvature = WienerFilterCurvature(R=self.R,
N=self.N,
S=self.S)
return self._curvature
@property
@memo
......
......@@ -91,6 +91,14 @@ class ComposedOperator(LinearOperator):
"instances of the LinearOperator-baseclass")
self._operator_store += (op,)
def _add_attributes_to_copy(self, copy, **kwargs):
copy._operator_store = ()
for op in self._operator_store:
copy._operator_store += (op.copy(),)
copy = super(ComposedOperator, self)._add_attributes_to_copy(
copy, **kwargs)
return copy
def _check_input_compatibility(self, x, spaces, inverse=False):
"""
The input check must be disabled for the ComposedOperator, since it
......
......@@ -117,8 +117,20 @@ class DiagonalOperator(EndomorphicOperator):
distribution_strategy=distribution_strategy,
val=diagonal)
self._self_adjoint = None
self._unitary = None
self.set_diagonal(diagonal=diagonal, bare=bare, copy=copy)
def _add_attributes_to_copy(self, copy, **kwargs):
copy._domain = self._domain
copy._distribution_strategy = self._distribution_strategy
copy.set_diagonal(diagonal=self.diagonal(bare=True), bare=True)
copy._self_adjoint = self._self_adjoint
copy._unitary = self._unitary
copy = super(DiagonalOperator, self)._add_attributes_to_copy(copy,
**kwargs)
return copy
def _times(self, x, spaces):
return self._times_helper(x, spaces, operation=lambda z: z.__mul__)
......@@ -127,7 +139,8 @@ class DiagonalOperator(EndomorphicOperator):
operation=lambda z: z.adjoint().__mul__)
def _inverse_times(self, x, spaces):
return self._times_helper(x, spaces, operation=lambda z: z.__rtruediv__)
return self._times_helper(
x, spaces, operation=lambda z: z.__rtruediv__)
def _adjoint_inverse_times(self, x, spaces):
return self._times_helper(x, spaces,
......
......@@ -148,6 +148,17 @@ class FFTOperator(LinearOperator):
self.target_dtype = \
None if target_dtype is None else np.dtype(target_dtype)
def _add_attributes_to_copy(self, copy, **kwargs):
copy._domain = self._domain
copy._target = self._target
copy._forward_transformation = self._forward_transformation
copy._backward_transformation = self._backward_transformation
copy.domain_dtype = self.domain_dtype
copy.target_dtype = self.target_dtype
copy = super(FFTOperator, self)._add_attributes_to_copy(copy,
**kwargs)
return copy
def _times(self, x, spaces):
spaces = utilities.cast_axis_to_tuple(spaces, len(x.domain))
if spaces is None:
......
......@@ -73,6 +73,23 @@ class InvertibleOperatorMixin(object):
self.__backward_x0 = backward_x0
super(InvertibleOperatorMixin, self).__init__(*args, **kwargs)
def _add_attributes_to_copy(self, copy, **kwargs):
copy.__preconditioner = self.__preconditioner
copy.__inverter = self.__inverter
try:
copy.__forward_x0 = self.__forward_x0.copy()
except AttributeError:
copy.__forward_x0 = self.__forward_x0
try:
copy.__backward_x0 = self.__backward_x0.copy()
except AttributeError:
copy.__backward_x0 = self.__backward_x0
copy = super(InvertibleOperatorMixin, self)._add_attributes_to_copy(
copy, **kwargs)
return copy
def _times(self, x, spaces):
if self.__forward_x0 is not None:
x0 = self.__forward_x0
......
......@@ -64,6 +64,15 @@ class LaplaceOperator(EndomorphicOperator):
self._dposc[1:] += self._dpos
self._dposc *= 0.5
def _add_attributes_to_copy(self, copy, **kwargs):
copy._domain = self._domain
copy._logarithmic = self._logarithmic
copy._dpos = self._dpos
copy._dposc = self._dposc
copy = super(LaplaceOperator, self)._add_attributes_to_copy(copy,
**kwargs)
return copy
@property
def target(self):
return self._domain
......
......@@ -26,7 +26,8 @@ from ... import nifty_utilities as utilities
from future.utils import with_metaclass
class LinearOperator(with_metaclass(NiftyMeta, type('NewBase', (Loggable, object), {}))):
class LinearOperator(
with_metaclass(NiftyMeta, type('NewBase', (Loggable, object), {}))):
"""NIFTY base class for linear operators.
The base NIFTY operator class is an abstract class from which
......@@ -75,6 +76,20 @@ class LinearOperator(with_metaclass(NiftyMeta, type('NewBase', (Loggable, object
def __init__(self, default_spaces=None):
self._default_spaces = default_spaces
def copy(self, **kwargs):
class EmptyCopy(self.__class__):
def __init__(self):
pass
result = EmptyCopy()
result.__class__ = self.__class__
result = self._add_attributes_to_copy(result, **kwargs)
return result
def _add_attributes_to_copy(self, copy, **kwargs):
copy._default_spaces = self.default_spaces
return copy
@staticmethod
def _parse_domain(domain):
return utilities.parse_domain(domain)
......
......@@ -87,6 +87,13 @@ class ProjectionOperator(EndomorphicOperator):
self._projection_field = projection_field
self._unitary = None
def _add_attributes_to_copy(self, copy, **kwargs):
copy._projection_field = self._projection_field
copy._unitary = self._unitary
copy = super(ProjectionOperator, self)._add_attributes_to_copy(
copy, **kwargs)
return copy
def _times(self, x, spaces):
# if the domain matches directly
# -> multiply the fields directly
......
......@@ -82,7 +82,7 @@ class ResponseOperator(LinearOperator):
for ii in range(len(kernel_smoothing)):
kernel_smoothing[ii] = SmoothingOperator.make(self._domain[ii],
sigma=sigma[ii])
sigma=sigma[ii])
kernel_exposure[ii] = DiagonalOperator(self._domain[ii],
diagonal=exposure[ii])
......@@ -95,6 +95,15 @@ class ResponseOperator(LinearOperator):
self._target = self._parse_domain(target_list)
def _add_attributes_to_copy(self, copy, **kwargs):
copy._domain = self._domain
copy._target = self._target
copy._composed_kernel = self._composed_kernel.copy()
copy._composed_exposure = self._composed_exposure.copy()
copy = super(DiagonalOperator, self)._add_attributes_to_copy(copy,
**kwargs)
return copy
@property
def domain(self):
return self._domain
......
......@@ -18,6 +18,13 @@ class DirectSmoothingOperator(SmoothingOperator):
default_spaces)
self.effective_smoothing_width = 3.01
def _add_attributes_to_copy(self, copy, **kwargs):
copy.effective_smoothing_width = self.effective_smoothing_width
copy = super(DirectSmoothingOperator, self)._add_attributes_to_copy(
copy, **kwargs)
return copy
def _precompute(self, x, sigma, dxmax=None):
""" Does precomputations for Gaussian smoothing on a 1D irregular grid.
......
......@@ -19,6 +19,13 @@ class FFTSmoothingOperator(SmoothingOperator):
default_spaces=default_spaces)
self._transformator_cache = {}
def _add_attributes_to_copy(self, copy, **kwargs):
copy._transformator_cache = self._transformator_cache
copy = super(FFTSmoothingOperator, self)._add_attributes_to_copy(
copy, **kwargs)
return copy
def _smooth(self, x, spaces, inverse):
# transform to the (global-)default codomain and perform all remaining
# steps therein
......
......@@ -131,6 +131,14 @@ class SmoothingOperator(EndomorphicOperator):
self._sigma = sigma
self._log_distances = log_distances
def _add_attributes_to_copy(self, copy, **kwargs):
copy._domain = self._domain
copy._sigma = self._sigma
copy._log_distances = self._log_distances
copy = super(SmoothingOperator, self)._add_attributes_to_copy(copy,
**kwargs)
return copy
def _inverse_times(self, x, spaces):
if self.sigma == 0:
return x.copy()
......
......@@ -48,6 +48,15 @@ class SmoothnessOperator(EndomorphicOperator):
self._laplace = LaplaceOperator(domain=self.domain,
logarithmic=logarithmic)
def _add_attributes_to_copy(self, copy, **kwargs):
copy._domain = self._domain
copy._strength = self._strength
copy._laplace = self._laplace.copy()
copy = super(SmoothnessOperator, self)._add_attributes_to_copy(
copy, **kwargs)
return copy
# ---Mandatory properties and methods---
@property
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment