Commit cb10770a authored by Martin Reinecke's avatar Martin Reinecke
Browse files

try operators with fixed input domain

parent fae648b9
Pipeline #18611 passed with stage
in 4 minutes and 9 seconds
...@@ -16,17 +16,17 @@ def plot_parameters(m, t, p, p_d): ...@@ -16,17 +16,17 @@ def plot_parameters(m, t, p, p_d):
class AdjointFFTResponse(ift.LinearOperator): class AdjointFFTResponse(ift.LinearOperator):
def __init__(self, FFT, R, default_spaces=None): def __init__(self, FFT, R):
super(AdjointFFTResponse, self).__init__(default_spaces) super(AdjointFFTResponse, self).__init__()
self._domain = FFT.target self._domain = FFT.target
self._target = R.target self._target = R.target
self.R = R self.R = R
self.FFT = FFT self.FFT = FFT
def _times(self, x, spaces=None): def _times(self, x):
return self.R(self.FFT.adjoint_times(x)) return self.R(self.FFT.adjoint_times(x))
def _adjoint_times(self, x, spaces=None): def _adjoint_times(self, x):
return self.FFT(self.R.adjoint_times(x)) return self.FFT(self.R.adjoint_times(x))
@property @property
......
...@@ -35,7 +35,7 @@ if __name__ == "__main__": ...@@ -35,7 +35,7 @@ if __name__ == "__main__":
#mask.val[N10*5:N10*9, N10*5:N10*9] = 0. #mask.val[N10*5:N10*9, N10*5:N10*9] = 0.
R = ift.ResponseOperator(signal_space, sigma=(response_sigma,), exposure=(mask,)) #|\label{code:wf_response}| R = ift.ResponseOperator(signal_space, sigma=(response_sigma,), exposure=(mask,)) #|\label{code:wf_response}|
data_domain = R.target[0] data_domain = R.target[0]
R_harmonic = ift.ComposedOperator([fft, R], default_spaces=[0, 0]) R_harmonic = ift.ComposedOperator([fft, R])
# Setting up the noise covariance and drawing a random noise realization # Setting up the noise covariance and drawing a random noise realization
ndiag = ift.Field(data_domain, mock_signal.var()/signal_to_noise).weight(1) ndiag = ift.Field(data_domain, mock_signal.var()/signal_to_noise).weight(1)
......
...@@ -22,7 +22,17 @@ if __name__ == "__main__": ...@@ -22,7 +22,17 @@ if __name__ == "__main__":
signal_space_1 = ift.RGSpace([N_pixels_1], distances=L_1/N_pixels_1) signal_space_1 = ift.RGSpace([N_pixels_1], distances=L_1/N_pixels_1)
harmonic_space_1 = signal_space_1.get_default_codomain() harmonic_space_1 = signal_space_1.get_default_codomain()
fft_1 = ift.FFTOperator(harmonic_space_1, target=signal_space_1) # Setting up the geometry |\label{code:wf_geometry}|
L_2 = 2. # Total side-length of the domain
N_pixels_2 = 512 # Grid resolution (pixels per axis)
signal_space_2 = ift.RGSpace([N_pixels_2], distances=L_2/N_pixels_2)
harmonic_space_2 = signal_space_2.get_default_codomain()
signal_domain = ift.DomainTuple.make((signal_space_1, signal_space_2))
mid_domain = ift.DomainTuple.make((signal_space_1, harmonic_space_2))
harmonic_domain = ift.DomainTuple.make((harmonic_space_1, harmonic_space_2))
fft_1 = ift.FFTOperator(harmonic_domain, space=0)
power_space_1 = ift.PowerSpace(harmonic_space_1) power_space_1 = ift.PowerSpace(harmonic_space_1)
mock_power_1 = ift.Field(power_space_1, val=power_spectrum_1(power_space_1.k_lengths)) mock_power_1 = ift.Field(power_space_1, val=power_spectrum_1(power_space_1.k_lengths))
...@@ -39,13 +49,7 @@ if __name__ == "__main__": ...@@ -39,13 +49,7 @@ if __name__ == "__main__":
a = 4 * correlation_length_2 * field_variance_2**2 a = 4 * correlation_length_2 * field_variance_2**2
return a / (1 + k * correlation_length_2) ** 2.5 return a / (1 + k * correlation_length_2) ** 2.5
# Setting up the geometry |\label{code:wf_geometry}| fft_2 = ift.FFTOperator(mid_domain, space=1)
L_2 = 2. # Total side-length of the domain
N_pixels_2 = 512 # Grid resolution (pixels per axis)
signal_space_2 = ift.RGSpace([N_pixels_2], distances=L_2/N_pixels_2)
harmonic_space_2 = signal_space_2.get_default_codomain()
fft_2 = ift.FFTOperator(harmonic_space_2, target=signal_space_2)
power_space_2 = ift.PowerSpace(harmonic_space_2) power_space_2 = ift.PowerSpace(harmonic_space_2)
mock_power_2 = ift.Field(power_space_2, val=power_spectrum_2(power_space_2.k_lengths)) mock_power_2 = ift.Field(power_space_2, val=power_spectrum_2(power_space_2.k_lengths))
...@@ -73,11 +77,11 @@ if __name__ == "__main__": ...@@ -73,11 +77,11 @@ if __name__ == "__main__":
mask_2 = ift.Field(signal_space_2, val=1.) mask_2 = ift.Field(signal_space_2, val=1.)
mask_2.val[N2_10*7:N2_10*9] = 0. mask_2.val[N2_10*7:N2_10*9] = 0.
R = ift.ResponseOperator((signal_space_1, signal_space_2), R = ift.ResponseOperator(signal_domain,spaces=(0,1),
sigma=(response_sigma_1, response_sigma_2), sigma=(response_sigma_1, response_sigma_2),
exposure=(mask_1, mask_2)) #|\label{code:wf_response}| exposure=(mask_1, mask_2)) #|\label{code:wf_response}|
data_domain = R.target data_domain = R.target
R_harmonic = ift.ComposedOperator([fft, R], default_spaces=(0, 1, 0, 1)) R_harmonic = ift.ComposedOperator([fft, R])
# Setting up the noise covariance and drawing a random noise realization # Setting up the noise covariance and drawing a random noise realization
ndiag = ift.Field(data_domain, mock_signal.var()/signal_to_noise).weight(1) ndiag = ift.Field(data_domain, mock_signal.var()/signal_to_noise).weight(1)
......
...@@ -34,7 +34,7 @@ if __name__ == "__main__": ...@@ -34,7 +34,7 @@ if __name__ == "__main__":
mask.val[N10*5:N10*9, N10*5:N10*9] = 0. mask.val[N10*5:N10*9, N10*5:N10*9] = 0.
R = ift.ResponseOperator(signal_space, sigma=(response_sigma,), exposure=(mask,)) #|\label{code:wf_response}| R = ift.ResponseOperator(signal_space, sigma=(response_sigma,), exposure=(mask,)) #|\label{code:wf_response}|
data_domain = R.target[0] data_domain = R.target[0]
R_harmonic = ift.ComposedOperator([fft, R], default_spaces=[0, 0]) R_harmonic = ift.ComposedOperator([fft, R])
# Setting up the noise covariance and drawing a random noise realization # Setting up the noise covariance and drawing a random noise realization
ndiag = ift.Field(data_domain, mock_signal.var()/signal_to_noise).weight(1) ndiag = ift.Field(data_domain, mock_signal.var()/signal_to_noise).weight(1)
......
...@@ -45,7 +45,7 @@ if __name__ == "__main__": ...@@ -45,7 +45,7 @@ if __name__ == "__main__":
R = ift.ResponseOperator(signal_space, sigma=(response_sigma,)) R = ift.ResponseOperator(signal_space, sigma=(response_sigma,))
data_domain = R.target[0] data_domain = R.target[0]
R_harmonic = ift.ComposedOperator([fft, R], default_spaces=[0, 0]) R_harmonic = ift.ComposedOperator([fft, R])
N = ift.DiagonalOperator(ift.Field(data_domain,mock_signal.var()/signal_to_noise).weight(1)) N = ift.DiagonalOperator(ift.Field(data_domain,mock_signal.var()/signal_to_noise).weight(1))
noise = ift.Field.from_random(domain=data_domain, noise = ift.Field.from_random(domain=data_domain,
......
...@@ -5,17 +5,17 @@ np.random.seed(42) ...@@ -5,17 +5,17 @@ np.random.seed(42)
class AdjointFFTResponse(ift.LinearOperator): class AdjointFFTResponse(ift.LinearOperator):
def __init__(self, FFT, R, default_spaces=None): def __init__(self, FFT, R):
super(AdjointFFTResponse, self).__init__(default_spaces) super(AdjointFFTResponse, self).__init__()
self._domain = FFT.target self._domain = FFT.target
self._target = R.target self._target = R.target
self.R = R self.R = R
self.FFT = FFT self.FFT = FFT
def _times(self, x, spaces=None): def _times(self, x):
return self.R(self.FFT.adjoint_times(x)) return self.R(self.FFT.adjoint_times(x))
def _adjoint_times(self, x, spaces=None): def _adjoint_times(self, x):
return self.FFT(self.R.adjoint_times(x)) return self.FFT(self.R.adjoint_times(x))
@property @property
......
...@@ -104,3 +104,9 @@ class DomainTuple(object): ...@@ -104,3 +104,9 @@ class DomainTuple(object):
if self is x: if self is x:
return False return False
return self._dom != x._dom return self._dom != x._dom
def __str__(self):
res = "DomainTuple, len: " + str(len(self.domains))
for i in self.domains:
res += "\n" + str(i)
return res
...@@ -510,8 +510,9 @@ class Field(object): ...@@ -510,8 +510,9 @@ class Field(object):
# create a diagonal operator which is capable of taking care of the # create a diagonal operator which is capable of taking care of the
# axes-matching # axes-matching
from .operators.diagonal_operator import DiagonalOperator from .operators.diagonal_operator import DiagonalOperator
diag = DiagonalOperator(y.conjugate(), copy=False) diag = DiagonalOperator(y.conjugate(), self.domain,
dotted = diag(x, spaces=spaces) spaces=spaces, copy=False)
dotted = diag(x)
return fct*dotted.sum(spaces=spaces) return fct*dotted.sum(spaces=spaces)
def norm(self): def norm(self):
......
...@@ -32,7 +32,7 @@ class CriticalPowerCurvature(InvertibleOperatorMixin, EndomorphicOperator): ...@@ -32,7 +32,7 @@ class CriticalPowerCurvature(InvertibleOperatorMixin, EndomorphicOperator):
preconditioner=preconditioner, preconditioner=preconditioner,
**kwargs) **kwargs)
def _times(self, x, spaces): def _times(self, x):
return self.T(x) + self.theta(x) return self.T(x) + self.theta(x)
# ---Mandatory properties and methods--- # ---Mandatory properties and methods---
......
...@@ -58,7 +58,7 @@ class LogNormalWienerFilterCurvature(InvertibleOperatorMixin, ...@@ -58,7 +58,7 @@ class LogNormalWienerFilterCurvature(InvertibleOperatorMixin,
# ---Added properties and methods--- # ---Added properties and methods---
def _times(self, x, spaces): def _times(self, x):
part1 = self.S.inverse_times(x) part1 = self.S.inverse_times(x)
# part2 = self._exppRNRexppd * x # part2 = self._exppRNRexppd * x
part3 = self._fft.adjoint_times(self._expp_sspace * self._fft(x)) part3 = self._fft.adjoint_times(self._expp_sspace * self._fft(x))
......
...@@ -48,7 +48,7 @@ class WienerFilterCurvature(InvertibleOperatorMixin, EndomorphicOperator): ...@@ -48,7 +48,7 @@ class WienerFilterCurvature(InvertibleOperatorMixin, EndomorphicOperator):
# ---Added properties and methods--- # ---Added properties and methods---
def _times(self, x, spaces): def _times(self, x):
res = self.R.adjoint_times(self.N.inverse_times(self.R(x))) res = self.R.adjoint_times(self.N.inverse_times(self.R(x)))
res += self.S.inverse_times(x) res += self.S.inverse_times(x)
return res return res
...@@ -29,9 +29,6 @@ class ComposedOperator(LinearOperator): ...@@ -29,9 +29,6 @@ class ComposedOperator(LinearOperator):
---------- ----------
operators : tuple of NIFTy Operators operators : tuple of NIFTy Operators
The tuple of LinearOperators. The tuple of LinearOperators.
default_spaces : tuple of ints *optional*
Defines on which space(s) of a given field the Operator acts by
default (default: None)
Attributes Attributes
...@@ -48,7 +45,7 @@ class ComposedOperator(LinearOperator): ...@@ -48,7 +45,7 @@ class ComposedOperator(LinearOperator):
TypeError TypeError
Raised if Raised if
* an element of the operator list is not an instance of the * an element of the operator list is not an instance of the
LinearOperator-baseclass. LinearOperator base class.
Notes Notes
----- -----
...@@ -64,8 +61,8 @@ class ComposedOperator(LinearOperator): ...@@ -64,8 +61,8 @@ class ComposedOperator(LinearOperator):
>>> x2 = RGSpace(10) >>> x2 = RGSpace(10)
>>> k1 = RGRGTransformation.get_codomain(x1) >>> k1 = RGRGTransformation.get_codomain(x1)
>>> k2 = RGRGTransformation.get_codomain(x2) >>> k2 = RGRGTransformation.get_codomain(x2)
>>> FFT1 = FFTOperator(domain=x1, target=k1) >>> FFT1 = FFTOperator(domain=(x1,x2), target=(k1,x2), space=0)
>>> FFT2 = FFTOperator(domain=x2, target=k2) >>> FFT2 = FFTOperator(domain=(k1,x2), target=(k1,k2), space=1)
>>> FFT = ComposedOperator((FFT1, FFT2) >>> FFT = ComposedOperator((FFT1, FFT2)
>>> f = Field.from_random('normal', domain=(x1,x2)) >>> f = Field.from_random('normal', domain=(x1,x2))
>>> FFT.times(f) >>> FFT.times(f)
...@@ -73,80 +70,50 @@ class ComposedOperator(LinearOperator): ...@@ -73,80 +70,50 @@ class ComposedOperator(LinearOperator):
""" """
# ---Overwritten properties and methods--- # ---Overwritten properties and methods---
def __init__(self, operators, default_spaces=None): def __init__(self, operators):
super(ComposedOperator, self).__init__(default_spaces) super(ComposedOperator, self).__init__()
for i in range(1, len(operators)):
if operators[i].domain != operators[i-1].target:
raise ValueError("incompatible domains")
self._operator_store = () self._operator_store = ()
for op in operators: for op in operators:
if not isinstance(op, LinearOperator): if not isinstance(op, LinearOperator):
raise TypeError("The elements of the operator list must be" raise TypeError("The elements of the operator list must be"
"instances of the LinearOperator-baseclass") "instances of the LinearOperator base class")
self._operator_store += (op,) self._operator_store += (op,)
def _check_input_compatibility(self, x, spaces, inverse=False):
"""
The input check must be disabled for the ComposedOperator, since it
is not easily forecasteable what the output of an operator-call
will look like.
"""
if spaces is None:
spaces = self.default_spaces
return spaces
# ---Mandatory properties and methods--- # ---Mandatory properties and methods---
@property @property
def domain(self): def domain(self):
if not hasattr(self, '_domain'): return self._operator_store[0].domain
dom = ()
for op in self._operator_store:
dom += op.domain.domains
self._domain = DomainTuple.make(dom)
return self._domain
@property @property
def target(self): def target(self):
if not hasattr(self, '_target'): return self._operator_store[-1].target
tgt = ()
for op in self._operator_store:
tgt += op.target.domains
self._target = DomainTuple.make(tgt)
return self._target
@property @property
def unitary(self): def unitary(self):
return False return False
def _times(self, x, spaces): def _times(self, x):
return self._times_helper(x, spaces, func='times') return self._times_helper(x, func='times')
def _adjoint_times(self, x, spaces): def _adjoint_times(self, x):
return self._inverse_times_helper(x, spaces, func='adjoint_times') return self._inverse_times_helper(x, func='adjoint_times')
def _inverse_times(self, x, spaces): def _inverse_times(self, x):
return self._inverse_times_helper(x, spaces, func='inverse_times') return self._inverse_times_helper(x, func='inverse_times')
def _adjoint_inverse_times(self, x, spaces): def _adjoint_inverse_times(self, x):
return self._times_helper(x, spaces, func='adjoint_inverse_times') return self._times_helper(x, func='adjoint_inverse_times')
def _times_helper(self, x, spaces, func): def _times_helper(self, x, func):
space_index = 0
if spaces is None:
spaces = range(len(self.domain))
for op in self._operator_store: for op in self._operator_store:
active_spaces = spaces[space_index:space_index+len(op.domain)] x = getattr(op, func)(x)
space_index += len(op.domain)
x = getattr(op, func)(x, spaces=active_spaces)
return x return x
def _inverse_times_helper(self, x, spaces, func): def _inverse_times_helper(self, x, func):
space_index = 0
if spaces is None:
spaces = range(len(self.target))
rev_spaces = spaces[::-1]
for op in reversed(self._operator_store): for op in reversed(self._operator_store):
active_spaces = rev_spaces[space_index:space_index+len(op.target)] x = getattr(op, func)(x)
space_index += len(op.target)
x = getattr(op, func)(x, spaces=active_spaces[::-1])
return x return x
...@@ -23,7 +23,7 @@ import numpy as np ...@@ -23,7 +23,7 @@ import numpy as np
from ...field import Field from ...field import Field
from ...domain_tuple import DomainTuple from ...domain_tuple import DomainTuple
from ..endomorphic_operator import EndomorphicOperator from ..endomorphic_operator import EndomorphicOperator
from ...nifty_utilities import cast_iseq_to_tuple
class DiagonalOperator(EndomorphicOperator): class DiagonalOperator(EndomorphicOperator):
""" NIFTY class for diagonal operators. """ NIFTY class for diagonal operators.
...@@ -39,9 +39,6 @@ class DiagonalOperator(EndomorphicOperator): ...@@ -39,9 +39,6 @@ class DiagonalOperator(EndomorphicOperator):
The diagonal entries of the operator. The diagonal entries of the operator.
copy : boolean copy : boolean
Internal copy of the diagonal (default: True) Internal copy of the diagonal (default: True)
default_spaces : tuple of ints *optional*
Defines on which space(s) of a given field the Operator acts by
default (default: None)
Attributes Attributes
---------- ----------
...@@ -55,9 +52,6 @@ class DiagonalOperator(EndomorphicOperator): ...@@ -55,9 +52,6 @@ class DiagonalOperator(EndomorphicOperator):
self_adjoint : boolean self_adjoint : boolean
Indicates whether the operator is self_adjoint or not. Indicates whether the operator is self_adjoint or not.
Raises
------
See Also See Also
-------- --------
EndomorphicOperator EndomorphicOperator
...@@ -66,30 +60,48 @@ class DiagonalOperator(EndomorphicOperator): ...@@ -66,30 +60,48 @@ class DiagonalOperator(EndomorphicOperator):
# ---Overwritten properties and methods--- # ---Overwritten properties and methods---
def __init__(self, diagonal, copy=True, default_spaces=None): def __init__(self, diagonal, domain=None, spaces=None, copy=True):
super(DiagonalOperator, self).__init__(default_spaces) super(DiagonalOperator, self).__init__()
if not isinstance(diagonal, Field): if not isinstance(diagonal, Field):
raise TypeError("Field object required") raise TypeError("Field object required")
if domain is None:
self._domain = diagonal.domain
else:
self._domain = DomainTuple.make(domain)
if spaces is None:
self._spaces = None
if diagonal.domain != self._domain:
raise ValueError("domain mismatch")
else:
self._spaces = cast_iseq_to_tuple(spaces)
nspc = len(self._spaces)
if nspc != len(diagonal.domain.domains):
raise ValueError("spaces and domain must have the same length")
if nspc > len(self._domain.domains):
raise ValueError("too many spaces")
if nspc > len(set(self._spaces)):
raise ValueError("non-unique space indices")
# if nspc==len(self.diagonal.domain.domains, we could do some optimization
for i, j in enumerate(self._spaces):
if diagonal.domain[i] != self._domain[j]:
raise ValueError("domain mismatch")
self._diagonal = diagonal if not copy else diagonal.copy() self._diagonal = diagonal if not copy else diagonal.copy()
self._self_adjoint = None self._self_adjoint = None
self._unitary = None self._unitary = None
def _times(self, x, spaces): def _times(self, x):
return self._times_helper(x, spaces, operation=lambda z: z.__mul__) return self._times_helper(x, lambda z: z.__mul__)
def _adjoint_times(self, x, spaces): def _adjoint_times(self, x):
return self._times_helper(x, spaces, return self._times_helper(x, lambda z: z.conjugate().__mul__)
operation=lambda z: z.conjugate().__mul__)
def _inverse_times(self, x, spaces): def _inverse_times(self, x):
return self._times_helper(x, spaces, return self._times_helper(x, lambda z: z.__rtruediv__)
operation=lambda z: z.__rtruediv__)
def _adjoint_inverse_times(self, x, spaces): def _adjoint_inverse_times(self, x):
return self._times_helper(x, spaces, return self._times_helper(x, lambda z: z.conjugate().__rtruediv__)
operation=lambda z:
z.conjugate().__rtruediv__)
def diagonal(self, copy=True): def diagonal(self, copy=True):
""" Returns the diagonal of the Operator. """ Returns the diagonal of the Operator.
...@@ -111,7 +123,7 @@ class DiagonalOperator(EndomorphicOperator): ...@@ -111,7 +123,7 @@ class DiagonalOperator(EndomorphicOperator):
@property @property
def domain(self): def domain(self):
return self._diagonal.domain return self._domain
@property @property
def self_adjoint(self): def self_adjoint(self):
...@@ -130,18 +142,12 @@ class DiagonalOperator(EndomorphicOperator): ...@@ -130,18 +142,12 @@ class DiagonalOperator(EndomorphicOperator):
# ---Added properties and methods--- # ---Added properties and methods---
def _times_helper(self, x, spaces, operation): def _times_helper(self, x, operation):
# if the domain matches directly if self._spaces is None:
# -> multiply the fields directly
if x.domain == self.domain:
# here the actual multiplication takes place
return operation(self._diagonal)(x) return operation(self._diagonal)(x)
if spaces is None:
active_axes = range(len(x.shape))
else:
active_axes = [] active_axes = []
for space_index in spaces: for space_index in self._spaces:
active_axes += x.domain.axes[space_index] active_axes += x.domain.axes[space_index]
reshaper = [x.shape[i] if i in active_axes else 1 reshaper = [x.shape[i] if i in active_axes else 1
......
...@@ -28,12 +28,6 @@ class EndomorphicOperator(LinearOperator): ...@@ -28,12 +28,6 @@ class EndomorphicOperator(LinearOperator):
LinearOperator. By definition, domain and target are the same in LinearOperator. By definition, domain and target are the same in
EndomorphicOperator. EndomorphicOperator.
Parameters
----------