Commit cb10770a authored by Martin Reinecke's avatar Martin Reinecke

try operators with fixed input domain

parent fae648b9
Pipeline #18611 passed with stage
in 4 minutes and 9 seconds
......@@ -16,17 +16,17 @@ def plot_parameters(m, t, p, p_d):
class AdjointFFTResponse(ift.LinearOperator):
def __init__(self, FFT, R, default_spaces=None):
super(AdjointFFTResponse, self).__init__(default_spaces)
def __init__(self, FFT, R):
super(AdjointFFTResponse, self).__init__()
self._domain = FFT.target
self._target = R.target
self.R = R
self.FFT = FFT
def _times(self, x, spaces=None):
def _times(self, x):
return self.R(self.FFT.adjoint_times(x))
def _adjoint_times(self, x, spaces=None):
def _adjoint_times(self, x):
return self.FFT(self.R.adjoint_times(x))
@property
......
......@@ -35,7 +35,7 @@ if __name__ == "__main__":
#mask.val[N10*5:N10*9, N10*5:N10*9] = 0.
R = ift.ResponseOperator(signal_space, sigma=(response_sigma,), exposure=(mask,)) #|\label{code:wf_response}|
data_domain = R.target[0]
R_harmonic = ift.ComposedOperator([fft, R], default_spaces=[0, 0])
R_harmonic = ift.ComposedOperator([fft, R])
# Setting up the noise covariance and drawing a random noise realization
ndiag = ift.Field(data_domain, mock_signal.var()/signal_to_noise).weight(1)
......
......@@ -22,7 +22,17 @@ if __name__ == "__main__":
signal_space_1 = ift.RGSpace([N_pixels_1], distances=L_1/N_pixels_1)
harmonic_space_1 = signal_space_1.get_default_codomain()
fft_1 = ift.FFTOperator(harmonic_space_1, target=signal_space_1)
# Setting up the geometry |\label{code:wf_geometry}|
L_2 = 2. # Total side-length of the domain
N_pixels_2 = 512 # Grid resolution (pixels per axis)
signal_space_2 = ift.RGSpace([N_pixels_2], distances=L_2/N_pixels_2)
harmonic_space_2 = signal_space_2.get_default_codomain()
signal_domain = ift.DomainTuple.make((signal_space_1, signal_space_2))
mid_domain = ift.DomainTuple.make((signal_space_1, harmonic_space_2))
harmonic_domain = ift.DomainTuple.make((harmonic_space_1, harmonic_space_2))
fft_1 = ift.FFTOperator(harmonic_domain, space=0)
power_space_1 = ift.PowerSpace(harmonic_space_1)
mock_power_1 = ift.Field(power_space_1, val=power_spectrum_1(power_space_1.k_lengths))
......@@ -39,13 +49,7 @@ if __name__ == "__main__":
a = 4 * correlation_length_2 * field_variance_2**2
return a / (1 + k * correlation_length_2) ** 2.5
# Setting up the geometry |\label{code:wf_geometry}|
L_2 = 2. # Total side-length of the domain
N_pixels_2 = 512 # Grid resolution (pixels per axis)
signal_space_2 = ift.RGSpace([N_pixels_2], distances=L_2/N_pixels_2)
harmonic_space_2 = signal_space_2.get_default_codomain()
fft_2 = ift.FFTOperator(harmonic_space_2, target=signal_space_2)
fft_2 = ift.FFTOperator(mid_domain, space=1)
power_space_2 = ift.PowerSpace(harmonic_space_2)
mock_power_2 = ift.Field(power_space_2, val=power_spectrum_2(power_space_2.k_lengths))
......@@ -73,11 +77,11 @@ if __name__ == "__main__":
mask_2 = ift.Field(signal_space_2, val=1.)
mask_2.val[N2_10*7:N2_10*9] = 0.
R = ift.ResponseOperator((signal_space_1, signal_space_2),
R = ift.ResponseOperator(signal_domain,spaces=(0,1),
sigma=(response_sigma_1, response_sigma_2),
exposure=(mask_1, mask_2)) #|\label{code:wf_response}|
data_domain = R.target
R_harmonic = ift.ComposedOperator([fft, R], default_spaces=(0, 1, 0, 1))
R_harmonic = ift.ComposedOperator([fft, R])
# Setting up the noise covariance and drawing a random noise realization
ndiag = ift.Field(data_domain, mock_signal.var()/signal_to_noise).weight(1)
......
......@@ -34,7 +34,7 @@ if __name__ == "__main__":
mask.val[N10*5:N10*9, N10*5:N10*9] = 0.
R = ift.ResponseOperator(signal_space, sigma=(response_sigma,), exposure=(mask,)) #|\label{code:wf_response}|
data_domain = R.target[0]
R_harmonic = ift.ComposedOperator([fft, R], default_spaces=[0, 0])
R_harmonic = ift.ComposedOperator([fft, R])
# Setting up the noise covariance and drawing a random noise realization
ndiag = ift.Field(data_domain, mock_signal.var()/signal_to_noise).weight(1)
......
......@@ -45,7 +45,7 @@ if __name__ == "__main__":
R = ift.ResponseOperator(signal_space, sigma=(response_sigma,))
data_domain = R.target[0]
R_harmonic = ift.ComposedOperator([fft, R], default_spaces=[0, 0])
R_harmonic = ift.ComposedOperator([fft, R])
N = ift.DiagonalOperator(ift.Field(data_domain,mock_signal.var()/signal_to_noise).weight(1))
noise = ift.Field.from_random(domain=data_domain,
......
......@@ -5,17 +5,17 @@ np.random.seed(42)
class AdjointFFTResponse(ift.LinearOperator):
def __init__(self, FFT, R, default_spaces=None):
super(AdjointFFTResponse, self).__init__(default_spaces)
def __init__(self, FFT, R):
super(AdjointFFTResponse, self).__init__()
self._domain = FFT.target
self._target = R.target
self.R = R
self.FFT = FFT
def _times(self, x, spaces=None):
def _times(self, x):
return self.R(self.FFT.adjoint_times(x))
def _adjoint_times(self, x, spaces=None):
def _adjoint_times(self, x):
return self.FFT(self.R.adjoint_times(x))
@property
......
......@@ -104,3 +104,9 @@ class DomainTuple(object):
if self is x:
return False
return self._dom != x._dom
def __str__(self):
res = "DomainTuple, len: " + str(len(self.domains))
for i in self.domains:
res += "\n" + str(i)
return res
......@@ -510,8 +510,9 @@ class Field(object):
# create a diagonal operator which is capable of taking care of the
# axes-matching
from .operators.diagonal_operator import DiagonalOperator
diag = DiagonalOperator(y.conjugate(), copy=False)
dotted = diag(x, spaces=spaces)
diag = DiagonalOperator(y.conjugate(), self.domain,
spaces=spaces, copy=False)
dotted = diag(x)
return fct*dotted.sum(spaces=spaces)
def norm(self):
......
......@@ -32,7 +32,7 @@ class CriticalPowerCurvature(InvertibleOperatorMixin, EndomorphicOperator):
preconditioner=preconditioner,
**kwargs)
def _times(self, x, spaces):
def _times(self, x):
return self.T(x) + self.theta(x)
# ---Mandatory properties and methods---
......
......@@ -58,7 +58,7 @@ class LogNormalWienerFilterCurvature(InvertibleOperatorMixin,
# ---Added properties and methods---
def _times(self, x, spaces):
def _times(self, x):
part1 = self.S.inverse_times(x)
# part2 = self._exppRNRexppd * x
part3 = self._fft.adjoint_times(self._expp_sspace * self._fft(x))
......
......@@ -48,7 +48,7 @@ class WienerFilterCurvature(InvertibleOperatorMixin, EndomorphicOperator):
# ---Added properties and methods---
def _times(self, x, spaces):
def _times(self, x):
res = self.R.adjoint_times(self.N.inverse_times(self.R(x)))
res += self.S.inverse_times(x)
return res
......@@ -29,9 +29,6 @@ class ComposedOperator(LinearOperator):
----------
operators : tuple of NIFTy Operators
The tuple of LinearOperators.
default_spaces : tuple of ints *optional*
Defines on which space(s) of a given field the Operator acts by
default (default: None)
Attributes
......@@ -48,7 +45,7 @@ class ComposedOperator(LinearOperator):
TypeError
Raised if
* an element of the operator list is not an instance of the
LinearOperator-baseclass.
LinearOperator base class.
Notes
-----
......@@ -64,8 +61,8 @@ class ComposedOperator(LinearOperator):
>>> x2 = RGSpace(10)
>>> k1 = RGRGTransformation.get_codomain(x1)
>>> k2 = RGRGTransformation.get_codomain(x2)
>>> FFT1 = FFTOperator(domain=x1, target=k1)
>>> FFT2 = FFTOperator(domain=x2, target=k2)
>>> FFT1 = FFTOperator(domain=(x1,x2), target=(k1,x2), space=0)
>>> FFT2 = FFTOperator(domain=(k1,x2), target=(k1,k2), space=1)
>>> FFT = ComposedOperator((FFT1, FFT2)
>>> f = Field.from_random('normal', domain=(x1,x2))
>>> FFT.times(f)
......@@ -73,80 +70,50 @@ class ComposedOperator(LinearOperator):
"""
# ---Overwritten properties and methods---
def __init__(self, operators, default_spaces=None):
super(ComposedOperator, self).__init__(default_spaces)
def __init__(self, operators):
super(ComposedOperator, self).__init__()
for i in range(1, len(operators)):
if operators[i].domain != operators[i-1].target:
raise ValueError("incompatible domains")
self._operator_store = ()
for op in operators:
if not isinstance(op, LinearOperator):
raise TypeError("The elements of the operator list must be"
"instances of the LinearOperator-baseclass")
"instances of the LinearOperator base class")
self._operator_store += (op,)
def _check_input_compatibility(self, x, spaces, inverse=False):
"""
The input check must be disabled for the ComposedOperator, since it
is not easily forecasteable what the output of an operator-call
will look like.
"""
if spaces is None:
spaces = self.default_spaces
return spaces
# ---Mandatory properties and methods---
@property
def domain(self):
if not hasattr(self, '_domain'):
dom = ()
for op in self._operator_store:
dom += op.domain.domains
self._domain = DomainTuple.make(dom)
return self._domain
return self._operator_store[0].domain
@property
def target(self):
if not hasattr(self, '_target'):
tgt = ()
for op in self._operator_store:
tgt += op.target.domains
self._target = DomainTuple.make(tgt)
return self._target
return self._operator_store[-1].target
@property
def unitary(self):
return False
def _times(self, x, spaces):
return self._times_helper(x, spaces, func='times')
def _times(self, x):
return self._times_helper(x, func='times')
def _adjoint_times(self, x, spaces):
return self._inverse_times_helper(x, spaces, func='adjoint_times')
def _adjoint_times(self, x):
return self._inverse_times_helper(x, func='adjoint_times')
def _inverse_times(self, x, spaces):
return self._inverse_times_helper(x, spaces, func='inverse_times')
def _inverse_times(self, x):
return self._inverse_times_helper(x, func='inverse_times')
def _adjoint_inverse_times(self, x, spaces):
return self._times_helper(x, spaces, func='adjoint_inverse_times')
def _adjoint_inverse_times(self, x):
return self._times_helper(x, func='adjoint_inverse_times')
def _times_helper(self, x, spaces, func):
space_index = 0
if spaces is None:
spaces = range(len(self.domain))
def _times_helper(self, x, func):
for op in self._operator_store:
active_spaces = spaces[space_index:space_index+len(op.domain)]
space_index += len(op.domain)
x = getattr(op, func)(x, spaces=active_spaces)
x = getattr(op, func)(x)
return x
def _inverse_times_helper(self, x, spaces, func):
space_index = 0
if spaces is None:
spaces = range(len(self.target))
rev_spaces = spaces[::-1]
def _inverse_times_helper(self, x, func):
for op in reversed(self._operator_store):
active_spaces = rev_spaces[space_index:space_index+len(op.target)]
space_index += len(op.target)
x = getattr(op, func)(x, spaces=active_spaces[::-1])
x = getattr(op, func)(x)
return x
......@@ -23,7 +23,7 @@ import numpy as np
from ...field import Field
from ...domain_tuple import DomainTuple
from ..endomorphic_operator import EndomorphicOperator
from ...nifty_utilities import cast_iseq_to_tuple
class DiagonalOperator(EndomorphicOperator):
""" NIFTY class for diagonal operators.
......@@ -39,9 +39,6 @@ class DiagonalOperator(EndomorphicOperator):
The diagonal entries of the operator.
copy : boolean
Internal copy of the diagonal (default: True)
default_spaces : tuple of ints *optional*
Defines on which space(s) of a given field the Operator acts by
default (default: None)
Attributes
----------
......@@ -55,9 +52,6 @@ class DiagonalOperator(EndomorphicOperator):
self_adjoint : boolean
Indicates whether the operator is self_adjoint or not.
Raises
------
See Also
--------
EndomorphicOperator
......@@ -66,30 +60,48 @@ class DiagonalOperator(EndomorphicOperator):
# ---Overwritten properties and methods---
def __init__(self, diagonal, copy=True, default_spaces=None):
super(DiagonalOperator, self).__init__(default_spaces)
def __init__(self, diagonal, domain=None, spaces=None, copy=True):
super(DiagonalOperator, self).__init__()
if not isinstance(diagonal, Field):
raise TypeError("Field object required")
if domain is None:
self._domain = diagonal.domain
else:
self._domain = DomainTuple.make(domain)
if spaces is None:
self._spaces = None
if diagonal.domain != self._domain:
raise ValueError("domain mismatch")
else:
self._spaces = cast_iseq_to_tuple(spaces)
nspc = len(self._spaces)
if nspc != len(diagonal.domain.domains):
raise ValueError("spaces and domain must have the same length")
if nspc > len(self._domain.domains):
raise ValueError("too many spaces")
if nspc > len(set(self._spaces)):
raise ValueError("non-unique space indices")
# if nspc==len(self.diagonal.domain.domains, we could do some optimization
for i, j in enumerate(self._spaces):
if diagonal.domain[i] != self._domain[j]:
raise ValueError("domain mismatch")
self._diagonal = diagonal if not copy else diagonal.copy()
self._self_adjoint = None
self._unitary = None
def _times(self, x, spaces):
return self._times_helper(x, spaces, operation=lambda z: z.__mul__)
def _times(self, x):
return self._times_helper(x, lambda z: z.__mul__)
def _adjoint_times(self, x, spaces):
return self._times_helper(x, spaces,
operation=lambda z: z.conjugate().__mul__)
def _adjoint_times(self, x):
return self._times_helper(x, lambda z: z.conjugate().__mul__)
def _inverse_times(self, x, spaces):
return self._times_helper(x, spaces,
operation=lambda z: z.__rtruediv__)
def _inverse_times(self, x):
return self._times_helper(x, lambda z: z.__rtruediv__)
def _adjoint_inverse_times(self, x, spaces):
return self._times_helper(x, spaces,
operation=lambda z:
z.conjugate().__rtruediv__)
def _adjoint_inverse_times(self, x):
return self._times_helper(x, lambda z: z.conjugate().__rtruediv__)
def diagonal(self, copy=True):
""" Returns the diagonal of the Operator.
......@@ -111,7 +123,7 @@ class DiagonalOperator(EndomorphicOperator):
@property
def domain(self):
return self._diagonal.domain
return self._domain
@property
def self_adjoint(self):
......@@ -130,19 +142,13 @@ class DiagonalOperator(EndomorphicOperator):
# ---Added properties and methods---
def _times_helper(self, x, spaces, operation):
# if the domain matches directly
# -> multiply the fields directly
if x.domain == self.domain:
# here the actual multiplication takes place
def _times_helper(self, x, operation):
if self._spaces is None:
return operation(self._diagonal)(x)
if spaces is None:
active_axes = range(len(x.shape))
else:
active_axes = []
for space_index in spaces:
active_axes += x.domain.axes[space_index]
active_axes = []
for space_index in self._spaces:
active_axes += x.domain.axes[space_index]
reshaper = [x.shape[i] if i in active_axes else 1
for i in range(len(x.shape))]
......
......@@ -28,12 +28,6 @@ class EndomorphicOperator(LinearOperator):
LinearOperator. By definition, domain and target are the same in
EndomorphicOperator.
Parameters
----------
default_spaces : tuple of ints *optional*
Defines on which space(s) of a given field the Operator acts by
default (default: None)
Attributes
----------
domain : tuple of DomainObjects, i.e. Spaces and FieldTypes
......@@ -56,37 +50,29 @@ class EndomorphicOperator(LinearOperator):
# ---Overwritten properties and methods---
def inverse_times(self, x, spaces=None):
def inverse_times(self, x):
if self.self_adjoint and self.unitary:
return self.times(x, spaces)
return self.times(x)
else:
return super(EndomorphicOperator, self).inverse_times(
x=x,
spaces=spaces)
return super(EndomorphicOperator, self).inverse_times(x)
def adjoint_times(self, x, spaces=None):
def adjoint_times(self, x):
if self.self_adjoint:
return self.times(x, spaces)
return self.times(x)
else:
return super(EndomorphicOperator, self).adjoint_times(
x=x,
spaces=spaces)
return super(EndomorphicOperator, self).adjoint_times(x)
def adjoint_inverse_times(self, x, spaces=None):
def adjoint_inverse_times(self, x):
if self.self_adjoint:
return self.inverse_times(x, spaces)
return self.inverse_times(x)
else:
return super(EndomorphicOperator, self).adjoint_inverse_times(
x=x,
spaces=spaces)
return super(EndomorphicOperator, self).adjoint_inverse_times(x)
def inverse_adjoint_times(self, x, spaces=None):
def inverse_adjoint_times(self, x):
if self.self_adjoint:
return self.inverse_times(x, spaces)
return self.inverse_times(x)
else:
return super(EndomorphicOperator, self).inverse_adjoint_times(
x=x,
spaces=spaces)
return super(EndomorphicOperator, self).inverse_adjoint_times(x)
# ---Mandatory properties and methods---
......
......@@ -46,6 +46,8 @@ class FFTOperator(LinearOperator):
domain: Space or single-element tuple of Spaces
The domain of the data that is input by "times" and output by
"adjoint_times".
space: the index of the space on which the operator should act
If None, it is set to 0 if domain contains exactly one space
target: Space or single-element tuple of Spaces (optional)
The domain of the data that is output by "times" and input by
"adjoint_times".
......@@ -58,10 +60,10 @@ class FFTOperator(LinearOperator):
Attributes
----------
domain: Tuple of Spaces (with one entry)
domain: Tuple of Spaces
The domain of the data that is input by "times" and output by
"adjoint_times".
target: Tuple of Spaces (with one entry)
target: Tuple of Spaces
The domain of the data that is output by "times" and input by
"adjoint_times".
unitary: bool
......@@ -72,7 +74,6 @@ class FFTOperator(LinearOperator):
------
ValueError:
if "domain" or "target" are not of the proper type.
"""
# ---Class attributes---
......@@ -92,62 +93,53 @@ class FFTOperator(LinearOperator):
# ---Overwritten properties and methods---
def __init__(self, domain, target=None, default_spaces=None):
super(FFTOperator, self).__init__(default_spaces)
def __init__(self, domain, target=None, space=None):
super(FFTOperator, self).__init__()
# Initialize domain and target
self._domain = DomainTuple.make(domain)
if len(self.domain) != 1:
raise ValueError("TransformationOperator accepts only exactly one "
"space as input domain.")
if space is None:
if len(self._domain.domains) != 1:
raise ValueError("need a Field with exactly one domain")
space = 0
space = int(space)
if (space<0) or space>=len(self._domain.domains):
raise ValueError("space index out of range")
self._space = space
adom = self.domain[self._space]
if target is None:
target = (self.domain[0].get_default_codomain(), )
target = [ dom for dom in self.domain ]
target[self._space] = adom.get_default_codomain()
self._target = DomainTuple.make(target)
if len(self.target) != 1:
raise ValueError("TransformationOperator accepts only exactly one "
"space as output target.")
self.domain[0].check_codomain(self.target[0])
self.target[0].check_codomain(self.domain[0])
atgt = self._target[self._space]
adom.check_codomain(atgt)
atgt.check_codomain(adom)
# Create transformation instances
forward_class = self.transformation_dictionary[
(self.domain[0].__class__, self.target[0].__class__)]
(adom.__class__, atgt.__class__)]
backward_class = self.transformation_dictionary[
(self.target[0].__class__, self.domain[0].__class__)]
self._forward_transformation = forward_class(
self.domain[0], self.target[0])
self._backward_transformation = backward_class(
self.target[0], self.domain[0])
def _times_helper(self, x, spaces, other, trafo):
if spaces is None:
# this case means that x lives on only one space, which is
# identical to the space in the domain of `self`. Otherwise the
# input check of LinearOperator would have failed.
axes = x.domain.axes[0]
result_domain