Skip to content
Snippets Groups Projects
Commit 664d97fb authored by Theo Steininger's avatar Theo Steininger
Browse files

Added copy method to Operators.

parent baa71d88
No related branches found
No related tags found
1 merge request!196Nightly
Pipeline #
Showing
with 209 additions and 18 deletions
......@@ -33,6 +33,23 @@ class CriticalPowerCurvature(InvertibleOperatorMixin, EndomorphicOperator):
preconditioner=preconditioner,
**kwargs)
def _add_attributes_to_copy(self, copy, **kwargs):
copy._domain = self._domain
if 'theta' in kwargs:
theta = kwargs['theta']
copy.theta = DiagonalOperator(theta.domain, diagonal=theta)
else:
copy.theta = self.theta.copy()
if 'T' in kwargs:
copy.T = kwargs['T']
else:
copy.T = self.T
copy = super(CriticalPowerCurvature,
self)._add_attributes_to_copy(copy, **kwargs)
return copy
def _times(self, x, spaces):
return self.T(x) + self.theta(x)
......
......@@ -54,7 +54,8 @@ class CriticalPowerEnergy(Energy):
# ---Overwritten properties and methods---
def __init__(self, position, m, D=None, alpha=1.0, q=0.,
smoothness_prior=0., logarithmic=True, samples=3, w=None):
smoothness_prior=0., logarithmic=True, samples=3, w=None,
old_curvature=None):
super(CriticalPowerEnergy, self).__init__(position=position)
self.m = m
self.D = D
......@@ -66,6 +67,8 @@ class CriticalPowerEnergy(Energy):
logarithmic=logarithmic)
self.rho = self.position.domain[0].rho
self._w = w if w is not None else None
self._old_curvature = old_curvature
self._curvature = None
# ---Mandatory properties and methods---
......@@ -73,9 +76,11 @@ class CriticalPowerEnergy(Energy):
return self.__class__(position, self.m, D=self.D, alpha=self.alpha,
q=self.q, smoothness_prior=self.smoothness_prior,
logarithmic=self.logarithmic,
w=self.w, samples=self.samples)
w=self.w, samples=self.samples,
old_curvature=self._curvature)
@property
@memo
def value(self):
energy = self._theta.sum()
energy += self.position.vdot(self._rho_prime, bare=True)
......@@ -83,6 +88,7 @@ class CriticalPowerEnergy(Energy):
return energy.real
@property
@memo
def gradient(self):
gradient = -self._theta.weight(-1)
gradient += (self._rho_prime).weight(-1)
......@@ -92,9 +98,14 @@ class CriticalPowerEnergy(Energy):
@property
def curvature(self):
curvature = CriticalPowerCurvature(theta=self._theta.weight(-1),
T=self.T)
return curvature
if self._curvature is None:
if self._old_curvature is None:
self._curvature = CriticalPowerCurvature(
theta=self._theta.weight(-1), T=self.T)
else:
self._curvature = self._old_curvature.copy(
theta=self._theta.weight(-1), T=self.T)
return self._curvature
# ---Added properties and methods---
......
......@@ -48,6 +48,23 @@ class LogNormalWienerFilterCurvature(InvertibleOperatorMixin,
preconditioner=preconditioner,
**kwargs)
def _add_attributes_to_copy(self, copy, **kwargs):
copy._cache = {}
copy._domain = self._domain
copy.R = self.R.copy()
copy.N = self.N.copy()
copy.S = self.S.copy()
copy.d = self.d.copy()
if 'position' in kwargs:
copy.position = kwargs['position']
else:
copy.position = self.position.copy()
copy._fft = self._fft
copy = super(LogNormalWienerFilterCurvature,
self)._add_attributes_to_copy(copy, **kwargs)
return copy
@property
def domain(self):
return self._domain
......
......@@ -24,7 +24,7 @@ class LogNormalWienerFilterEnergy(Energy):
The prior signal covariance in harmonic space.
"""
def __init__(self, position, d, R, N, S, fft4exp=None):
def __init__(self, position, d, R, N, S, fft4exp=None, old_curvature=None):
super(LogNormalWienerFilterEnergy, self).__init__(position=position)
self.d = d
self.R = R
......@@ -37,9 +37,13 @@ class LogNormalWienerFilterEnergy(Energy):
else:
self._fft = fft4exp
self._old_curvature = old_curvature
self._curvature = None
def at(self, position):
return self.__class__(position=position, d=self.d, R=self.R, N=self.N,
S=self.S, fft4exp=self._fft)
S=self.S, fft4exp=self._fft,
old_curvature=self._curvature)
@property
@memo
......@@ -53,11 +57,20 @@ class LogNormalWienerFilterEnergy(Energy):
return self._Sp + self._exppRNRexppd
@property
@memo
def curvature(self):
return LogNormalWienerFilterCurvature(R=self.R, N=self.N, S=self.S,
d=self.d, position=self.position,
fft4exp=self._fft)
if self._curvature is None:
if self._old_curvature is None:
self._curvature = LogNormalWienerFilterCurvature(
R=self.R,
N=self.N,
S=self.S,
d=self.d,
position=self.position,
fft4exp=self._fft)
else:
self._curvature = \
self._old_curvature.copy(position=self.position)
return self._curvature
@property
def _expp(self):
......
......@@ -35,6 +35,15 @@ class WienerFilterCurvature(InvertibleOperatorMixin, EndomorphicOperator):
preconditioner=preconditioner,
**kwargs)
def _add_attributes_to_copy(self, copy, **kwargs):
copy._domain = self._domain
copy.R = self.R.copy()
copy.N = self.N.copy()
copy.S = self.S.copy()
copy = super(WienerFilterCurvature, self)._add_attributes_to_copy(
copy, **kwargs)
return copy
@property
def domain(self):
return self._domain
......
......@@ -23,16 +23,17 @@ class WienerFilterEnergy(Energy):
The prior signal covariance in harmonic space.
"""
def __init__(self, position, d, R, N, S):
def __init__(self, position, d, R, N, S, old_curvature=None):
super(WienerFilterEnergy, self).__init__(position=position)
self.d = d
self.R = R
self.N = N
self.S = S
self._curvature = old_curvature
def at(self, position):
return self.__class__(position=position, d=self.d, R=self.R, N=self.N,
S=self.S)
S=self.S, old_curvature=self.curvature)
@property
@memo
......@@ -45,9 +46,12 @@ class WienerFilterEnergy(Energy):
return self._Dx - self._j
@property
@memo
def curvature(self):
return WienerFilterCurvature(R=self.R, N=self.N, S=self.S)
if self._curvature is None:
self._curvature = WienerFilterCurvature(R=self.R,
N=self.N,
S=self.S)
return self._curvature
@property
@memo
......
......@@ -91,6 +91,14 @@ class ComposedOperator(LinearOperator):
"instances of the LinearOperator-baseclass")
self._operator_store += (op,)
def _add_attributes_to_copy(self, copy, **kwargs):
copy._operator_store = ()
for op in self._operator_store:
copy._operator_store += (op.copy(),)
copy = super(ComposedOperator, self)._add_attributes_to_copy(
copy, **kwargs)
return copy
def _check_input_compatibility(self, x, spaces, inverse=False):
"""
The input check must be disabled for the ComposedOperator, since it
......
......@@ -117,8 +117,20 @@ class DiagonalOperator(EndomorphicOperator):
distribution_strategy=distribution_strategy,
val=diagonal)
self._self_adjoint = None
self._unitary = None
self.set_diagonal(diagonal=diagonal, bare=bare, copy=copy)
def _add_attributes_to_copy(self, copy, **kwargs):
copy._domain = self._domain
copy._distribution_strategy = self._distribution_strategy
copy.set_diagonal(diagonal=self.diagonal(bare=True), bare=True)
copy._self_adjoint = self._self_adjoint
copy._unitary = self._unitary
copy = super(DiagonalOperator, self)._add_attributes_to_copy(copy,
**kwargs)
return copy
def _times(self, x, spaces):
return self._times_helper(x, spaces, operation=lambda z: z.__mul__)
......@@ -127,7 +139,8 @@ class DiagonalOperator(EndomorphicOperator):
operation=lambda z: z.adjoint().__mul__)
def _inverse_times(self, x, spaces):
return self._times_helper(x, spaces, operation=lambda z: z.__rtruediv__)
return self._times_helper(
x, spaces, operation=lambda z: z.__rtruediv__)
def _adjoint_inverse_times(self, x, spaces):
return self._times_helper(x, spaces,
......
......@@ -148,6 +148,17 @@ class FFTOperator(LinearOperator):
self.target_dtype = \
None if target_dtype is None else np.dtype(target_dtype)
def _add_attributes_to_copy(self, copy, **kwargs):
copy._domain = self._domain
copy._target = self._target
copy._forward_transformation = self._forward_transformation
copy._backward_transformation = self._backward_transformation
copy.domain_dtype = self.domain_dtype
copy.target_dtype = self.target_dtype
copy = super(FFTOperator, self)._add_attributes_to_copy(copy,
**kwargs)
return copy
def _times(self, x, spaces):
spaces = utilities.cast_axis_to_tuple(spaces, len(x.domain))
if spaces is None:
......
......@@ -73,6 +73,23 @@ class InvertibleOperatorMixin(object):
self.__backward_x0 = backward_x0
super(InvertibleOperatorMixin, self).__init__(*args, **kwargs)
def _add_attributes_to_copy(self, copy, **kwargs):
copy.__preconditioner = self.__preconditioner
copy.__inverter = self.__inverter
try:
copy.__forward_x0 = self.__forward_x0.copy()
except AttributeError:
copy.__forward_x0 = self.__forward_x0
try:
copy.__backward_x0 = self.__backward_x0.copy()
except AttributeError:
copy.__backward_x0 = self.__backward_x0
copy = super(InvertibleOperatorMixin, self)._add_attributes_to_copy(
copy, **kwargs)
return copy
def _times(self, x, spaces):
if self.__forward_x0 is not None:
x0 = self.__forward_x0
......
......@@ -64,6 +64,15 @@ class LaplaceOperator(EndomorphicOperator):
self._dposc[1:] += self._dpos
self._dposc *= 0.5
def _add_attributes_to_copy(self, copy, **kwargs):
copy._domain = self._domain
copy._logarithmic = self._logarithmic
copy._dpos = self._dpos
copy._dposc = self._dposc
copy = super(LaplaceOperator, self)._add_attributes_to_copy(copy,
**kwargs)
return copy
@property
def target(self):
return self._domain
......
......@@ -26,7 +26,8 @@ from ... import nifty_utilities as utilities
from future.utils import with_metaclass
class LinearOperator(with_metaclass(NiftyMeta, type('NewBase', (Loggable, object), {}))):
class LinearOperator(
with_metaclass(NiftyMeta, type('NewBase', (Loggable, object), {}))):
"""NIFTY base class for linear operators.
The base NIFTY operator class is an abstract class from which
......@@ -75,6 +76,20 @@ class LinearOperator(with_metaclass(NiftyMeta, type('NewBase', (Loggable, object
def __init__(self, default_spaces=None):
self._default_spaces = default_spaces
def copy(self, **kwargs):
class EmptyCopy(self.__class__):
def __init__(self):
pass
result = EmptyCopy()
result.__class__ = self.__class__
result = self._add_attributes_to_copy(result, **kwargs)
return result
def _add_attributes_to_copy(self, copy, **kwargs):
copy._default_spaces = self.default_spaces
return copy
@staticmethod
def _parse_domain(domain):
return utilities.parse_domain(domain)
......
......@@ -87,6 +87,13 @@ class ProjectionOperator(EndomorphicOperator):
self._projection_field = projection_field
self._unitary = None
def _add_attributes_to_copy(self, copy, **kwargs):
copy._projection_field = self._projection_field
copy._unitary = self._unitary
copy = super(ProjectionOperator, self)._add_attributes_to_copy(
copy, **kwargs)
return copy
def _times(self, x, spaces):
# if the domain matches directly
# -> multiply the fields directly
......
......@@ -82,7 +82,7 @@ class ResponseOperator(LinearOperator):
for ii in range(len(kernel_smoothing)):
kernel_smoothing[ii] = SmoothingOperator.make(self._domain[ii],
sigma=sigma[ii])
sigma=sigma[ii])
kernel_exposure[ii] = DiagonalOperator(self._domain[ii],
diagonal=exposure[ii])
......@@ -95,6 +95,15 @@ class ResponseOperator(LinearOperator):
self._target = self._parse_domain(target_list)
def _add_attributes_to_copy(self, copy, **kwargs):
copy._domain = self._domain
copy._target = self._target
copy._composed_kernel = self._composed_kernel.copy()
copy._composed_exposure = self._composed_exposure.copy()
copy = super(DiagonalOperator, self)._add_attributes_to_copy(copy,
**kwargs)
return copy
@property
def domain(self):
return self._domain
......
......@@ -18,6 +18,13 @@ class DirectSmoothingOperator(SmoothingOperator):
default_spaces)
self.effective_smoothing_width = 3.01
def _add_attributes_to_copy(self, copy, **kwargs):
copy.effective_smoothing_width = self.effective_smoothing_width
copy = super(DirectSmoothingOperator, self)._add_attributes_to_copy(
copy, **kwargs)
return copy
def _precompute(self, x, sigma, dxmax=None):
""" Does precomputations for Gaussian smoothing on a 1D irregular grid.
......
......@@ -19,6 +19,13 @@ class FFTSmoothingOperator(SmoothingOperator):
default_spaces=default_spaces)
self._transformator_cache = {}
def _add_attributes_to_copy(self, copy, **kwargs):
copy._transformator_cache = self._transformator_cache
copy = super(FFTSmoothingOperator, self)._add_attributes_to_copy(
copy, **kwargs)
return copy
def _smooth(self, x, spaces, inverse):
# transform to the (global-)default codomain and perform all remaining
# steps therein
......
......@@ -131,6 +131,14 @@ class SmoothingOperator(EndomorphicOperator):
self._sigma = sigma
self._log_distances = log_distances
def _add_attributes_to_copy(self, copy, **kwargs):
copy._domain = self._domain
copy._sigma = self._sigma
copy._log_distances = self._log_distances
copy = super(SmoothingOperator, self)._add_attributes_to_copy(copy,
**kwargs)
return copy
def _inverse_times(self, x, spaces):
if self.sigma == 0:
return x.copy()
......
......@@ -48,6 +48,15 @@ class SmoothnessOperator(EndomorphicOperator):
self._laplace = LaplaceOperator(domain=self.domain,
logarithmic=logarithmic)
def _add_attributes_to_copy(self, copy, **kwargs):
copy._domain = self._domain
copy._strength = self._strength
copy._laplace = self._laplace.copy()
copy = super(SmoothnessOperator, self)._add_attributes_to_copy(
copy, **kwargs)
return copy
# ---Mandatory properties and methods---
@property
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment