diff --git a/demos/critical_filtering.py b/demos/critical_filtering.py index a296521ed651a15adf7c40ad9de088f80ee1cd37..1acc225074326b43d237285c2d46c024afaaa75e 100644 --- a/demos/critical_filtering.py +++ b/demos/critical_filtering.py @@ -16,17 +16,17 @@ def plot_parameters(m, t, p, p_d): class AdjointFFTResponse(ift.LinearOperator): - def __init__(self, FFT, R, default_spaces=None): - super(AdjointFFTResponse, self).__init__(default_spaces) + def __init__(self, FFT, R): + super(AdjointFFTResponse, self).__init__() self._domain = FFT.target self._target = R.target self.R = R self.FFT = FFT - def _times(self, x, spaces=None): + def _times(self, x): return self.R(self.FFT.adjoint_times(x)) - def _adjoint_times(self, x, spaces=None): + def _adjoint_times(self, x): return self.FFT(self.R.adjoint_times(x)) @property diff --git a/demos/log_normal_wiener_filter.py b/demos/log_normal_wiener_filter.py index 6a47088485498dc7dd5bc11a8e46640106e8a029..3935bbd504b7152e9a77461b5f0aa45b356923f3 100644 --- a/demos/log_normal_wiener_filter.py +++ b/demos/log_normal_wiener_filter.py @@ -35,7 +35,7 @@ if __name__ == "__main__": #mask.val[N10*5:N10*9, N10*5:N10*9] = 0. R = ift.ResponseOperator(signal_space, sigma=(response_sigma,), exposure=(mask,)) #|\label{code:wf_response}| data_domain = R.target[0] - R_harmonic = ift.ComposedOperator([fft, R], default_spaces=[0, 0]) + R_harmonic = ift.ComposedOperator([fft, R]) # Setting up the noise covariance and drawing a random noise realization ndiag = ift.Field(data_domain, mock_signal.var()/signal_to_noise).weight(1) diff --git a/demos/paper_demos/cartesian_wiener_filter.py b/demos/paper_demos/cartesian_wiener_filter.py index c49887b47f034015697c60a3e9903a980165a570..2cc253683a5aae09edc7109010489bb124ef0118 100644 --- a/demos/paper_demos/cartesian_wiener_filter.py +++ b/demos/paper_demos/cartesian_wiener_filter.py @@ -22,7 +22,17 @@ if __name__ == "__main__": signal_space_1 = ift.RGSpace([N_pixels_1], distances=L_1/N_pixels_1) harmonic_space_1 = signal_space_1.get_default_codomain() - fft_1 = ift.FFTOperator(harmonic_space_1, target=signal_space_1) + # Setting up the geometry |\label{code:wf_geometry}| + L_2 = 2. # Total side-length of the domain + N_pixels_2 = 512 # Grid resolution (pixels per axis) + signal_space_2 = ift.RGSpace([N_pixels_2], distances=L_2/N_pixels_2) + harmonic_space_2 = signal_space_2.get_default_codomain() + + signal_domain = ift.DomainTuple.make((signal_space_1, signal_space_2)) + mid_domain = ift.DomainTuple.make((signal_space_1, harmonic_space_2)) + harmonic_domain = ift.DomainTuple.make((harmonic_space_1, harmonic_space_2)) + + fft_1 = ift.FFTOperator(harmonic_domain, space=0) power_space_1 = ift.PowerSpace(harmonic_space_1) mock_power_1 = ift.Field(power_space_1, val=power_spectrum_1(power_space_1.k_lengths)) @@ -39,13 +49,7 @@ if __name__ == "__main__": a = 4 * correlation_length_2 * field_variance_2**2 return a / (1 + k * correlation_length_2) ** 2.5 - # Setting up the geometry |\label{code:wf_geometry}| - L_2 = 2. # Total side-length of the domain - N_pixels_2 = 512 # Grid resolution (pixels per axis) - - signal_space_2 = ift.RGSpace([N_pixels_2], distances=L_2/N_pixels_2) - harmonic_space_2 = signal_space_2.get_default_codomain() - fft_2 = ift.FFTOperator(harmonic_space_2, target=signal_space_2) + fft_2 = ift.FFTOperator(mid_domain, space=1) power_space_2 = ift.PowerSpace(harmonic_space_2) mock_power_2 = ift.Field(power_space_2, val=power_spectrum_2(power_space_2.k_lengths)) @@ -73,11 +77,11 @@ if __name__ == "__main__": mask_2 = ift.Field(signal_space_2, val=1.) mask_2.val[N2_10*7:N2_10*9] = 0. - R = ift.ResponseOperator((signal_space_1, signal_space_2), + R = ift.ResponseOperator(signal_domain,spaces=(0,1), sigma=(response_sigma_1, response_sigma_2), exposure=(mask_1, mask_2)) #|\label{code:wf_response}| data_domain = R.target - R_harmonic = ift.ComposedOperator([fft, R], default_spaces=(0, 1, 0, 1)) + R_harmonic = ift.ComposedOperator([fft, R]) # Setting up the noise covariance and drawing a random noise realization ndiag = ift.Field(data_domain, mock_signal.var()/signal_to_noise).weight(1) diff --git a/demos/paper_demos/wiener_filter.py b/demos/paper_demos/wiener_filter.py index abafd70fa6b20c58155acbe6a76e3c6d170e676d..dcbed1b40f504a03af8fa69adfa173e8e1ae9928 100644 --- a/demos/paper_demos/wiener_filter.py +++ b/demos/paper_demos/wiener_filter.py @@ -34,7 +34,7 @@ if __name__ == "__main__": mask.val[N10*5:N10*9, N10*5:N10*9] = 0. R = ift.ResponseOperator(signal_space, sigma=(response_sigma,), exposure=(mask,)) #|\label{code:wf_response}| data_domain = R.target[0] - R_harmonic = ift.ComposedOperator([fft, R], default_spaces=[0, 0]) + R_harmonic = ift.ComposedOperator([fft, R]) # Setting up the noise covariance and drawing a random noise realization ndiag = ift.Field(data_domain, mock_signal.var()/signal_to_noise).weight(1) diff --git a/demos/wiener_filter_via_curvature.py b/demos/wiener_filter_via_curvature.py index 03e585bc7c194e09501329fb94d644fd90d34e12..858def89e19bf6134c22f45813b51a9510662fe6 100644 --- a/demos/wiener_filter_via_curvature.py +++ b/demos/wiener_filter_via_curvature.py @@ -45,7 +45,7 @@ if __name__ == "__main__": R = ift.ResponseOperator(signal_space, sigma=(response_sigma,)) data_domain = R.target[0] - R_harmonic = ift.ComposedOperator([fft, R], default_spaces=[0, 0]) + R_harmonic = ift.ComposedOperator([fft, R]) N = ift.DiagonalOperator(ift.Field(data_domain,mock_signal.var()/signal_to_noise).weight(1)) noise = ift.Field.from_random(domain=data_domain, diff --git a/demos/wiener_filter_via_hamiltonian.py b/demos/wiener_filter_via_hamiltonian.py index bafd2950b5617b024d42e1a3e107b71591f6e32e..53e0a07785f7c6d1bdedc26e81138df9f2386002 100644 --- a/demos/wiener_filter_via_hamiltonian.py +++ b/demos/wiener_filter_via_hamiltonian.py @@ -5,17 +5,17 @@ np.random.seed(42) class AdjointFFTResponse(ift.LinearOperator): - def __init__(self, FFT, R, default_spaces=None): - super(AdjointFFTResponse, self).__init__(default_spaces) + def __init__(self, FFT, R): + super(AdjointFFTResponse, self).__init__() self._domain = FFT.target self._target = R.target self.R = R self.FFT = FFT - def _times(self, x, spaces=None): + def _times(self, x): return self.R(self.FFT.adjoint_times(x)) - def _adjoint_times(self, x, spaces=None): + def _adjoint_times(self, x): return self.FFT(self.R.adjoint_times(x)) @property diff --git a/nifty/domain_tuple.py b/nifty/domain_tuple.py index b66d6167e88ac973120c4e4ec9b85ca3401c4f4b..b95f2f8f322b6349633b4428a2ef6b7875d64a28 100644 --- a/nifty/domain_tuple.py +++ b/nifty/domain_tuple.py @@ -104,3 +104,9 @@ class DomainTuple(object): if self is x: return False return self._dom != x._dom + + def __str__(self): + res = "DomainTuple, len: " + str(len(self.domains)) + for i in self.domains: + res += "\n" + str(i) + return res diff --git a/nifty/field.py b/nifty/field.py index e1685c24a7524889760a86845f625104c04ace03..ed6a1a75f5dfa16b10bc692be25d29883a1091de 100644 --- a/nifty/field.py +++ b/nifty/field.py @@ -510,8 +510,9 @@ class Field(object): # create a diagonal operator which is capable of taking care of the # axes-matching from .operators.diagonal_operator import DiagonalOperator - diag = DiagonalOperator(y.conjugate(), copy=False) - dotted = diag(x, spaces=spaces) + diag = DiagonalOperator(y.conjugate(), self.domain, + spaces=spaces, copy=False) + dotted = diag(x) return fct*dotted.sum(spaces=spaces) def norm(self): diff --git a/nifty/library/critical_filter/critical_power_curvature.py b/nifty/library/critical_filter/critical_power_curvature.py index 9e57920e4665f196f3c3be34d8e7cba9649b7601..a401ef2fecc61b518ea13b2f4e256032e5608d0d 100644 --- a/nifty/library/critical_filter/critical_power_curvature.py +++ b/nifty/library/critical_filter/critical_power_curvature.py @@ -32,7 +32,7 @@ class CriticalPowerCurvature(InvertibleOperatorMixin, EndomorphicOperator): preconditioner=preconditioner, **kwargs) - def _times(self, x, spaces): + def _times(self, x): return self.T(x) + self.theta(x) # ---Mandatory properties and methods--- diff --git a/nifty/library/log_normal_wiener_filter/log_normal_wiener_filter_curvature.py b/nifty/library/log_normal_wiener_filter/log_normal_wiener_filter_curvature.py index 6c71ad9e8270c1b38ebcca50272b36964da31e11..5e14818110a0a2efd5664afebff9fa772051cb9c 100644 --- a/nifty/library/log_normal_wiener_filter/log_normal_wiener_filter_curvature.py +++ b/nifty/library/log_normal_wiener_filter/log_normal_wiener_filter_curvature.py @@ -58,7 +58,7 @@ class LogNormalWienerFilterCurvature(InvertibleOperatorMixin, # ---Added properties and methods--- - def _times(self, x, spaces): + def _times(self, x): part1 = self.S.inverse_times(x) # part2 = self._exppRNRexppd * x part3 = self._fft.adjoint_times(self._expp_sspace * self._fft(x)) diff --git a/nifty/library/wiener_filter/wiener_filter_curvature.py b/nifty/library/wiener_filter/wiener_filter_curvature.py index 2c2e82d1932aef7a99a32adfe73efddd03523a23..5edcd3fa2fd2eecdce086b06e5dfc6d1650d7a04 100644 --- a/nifty/library/wiener_filter/wiener_filter_curvature.py +++ b/nifty/library/wiener_filter/wiener_filter_curvature.py @@ -48,7 +48,7 @@ class WienerFilterCurvature(InvertibleOperatorMixin, EndomorphicOperator): # ---Added properties and methods--- - def _times(self, x, spaces): + def _times(self, x): res = self.R.adjoint_times(self.N.inverse_times(self.R(x))) res += self.S.inverse_times(x) return res diff --git a/nifty/operators/composed_operator/composed_operator.py b/nifty/operators/composed_operator/composed_operator.py index ade9fd85c6d79a712a22fcf7bef25513f45f0e96..f4a38a02fe0ba0ac7585069528a0db74f310acb4 100644 --- a/nifty/operators/composed_operator/composed_operator.py +++ b/nifty/operators/composed_operator/composed_operator.py @@ -29,9 +29,6 @@ class ComposedOperator(LinearOperator): ---------- operators : tuple of NIFTy Operators The tuple of LinearOperators. - default_spaces : tuple of ints *optional* - Defines on which space(s) of a given field the Operator acts by - default (default: None) Attributes @@ -48,7 +45,7 @@ class ComposedOperator(LinearOperator): TypeError Raised if * an element of the operator list is not an instance of the - LinearOperator-baseclass. + LinearOperator base class. Notes ----- @@ -64,8 +61,8 @@ class ComposedOperator(LinearOperator): >>> x2 = RGSpace(10) >>> k1 = RGRGTransformation.get_codomain(x1) >>> k2 = RGRGTransformation.get_codomain(x2) - >>> FFT1 = FFTOperator(domain=x1, target=k1) - >>> FFT2 = FFTOperator(domain=x2, target=k2) + >>> FFT1 = FFTOperator(domain=(x1,x2), target=(k1,x2), space=0) + >>> FFT2 = FFTOperator(domain=(k1,x2), target=(k1,k2), space=1) >>> FFT = ComposedOperator((FFT1, FFT2) >>> f = Field.from_random('normal', domain=(x1,x2)) >>> FFT.times(f) @@ -73,80 +70,50 @@ class ComposedOperator(LinearOperator): """ # ---Overwritten properties and methods--- - def __init__(self, operators, default_spaces=None): - super(ComposedOperator, self).__init__(default_spaces) + def __init__(self, operators): + super(ComposedOperator, self).__init__() + for i in range(1, len(operators)): + if operators[i].domain != operators[i-1].target: + raise ValueError("incompatible domains") self._operator_store = () for op in operators: if not isinstance(op, LinearOperator): raise TypeError("The elements of the operator list must be" - "instances of the LinearOperator-baseclass") + "instances of the LinearOperator base class") self._operator_store += (op,) - def _check_input_compatibility(self, x, spaces, inverse=False): - """ - The input check must be disabled for the ComposedOperator, since it - is not easily forecasteable what the output of an operator-call - will look like. - """ - if spaces is None: - spaces = self.default_spaces - return spaces - # ---Mandatory properties and methods--- @property def domain(self): - if not hasattr(self, '_domain'): - dom = () - for op in self._operator_store: - dom += op.domain.domains - self._domain = DomainTuple.make(dom) - return self._domain + return self._operator_store[0].domain @property def target(self): - if not hasattr(self, '_target'): - tgt = () - for op in self._operator_store: - tgt += op.target.domains - self._target = DomainTuple.make(tgt) - return self._target + return self._operator_store[-1].target @property def unitary(self): return False - def _times(self, x, spaces): - return self._times_helper(x, spaces, func='times') + def _times(self, x): + return self._times_helper(x, func='times') - def _adjoint_times(self, x, spaces): - return self._inverse_times_helper(x, spaces, func='adjoint_times') + def _adjoint_times(self, x): + return self._inverse_times_helper(x, func='adjoint_times') - def _inverse_times(self, x, spaces): - return self._inverse_times_helper(x, spaces, func='inverse_times') + def _inverse_times(self, x): + return self._inverse_times_helper(x, func='inverse_times') - def _adjoint_inverse_times(self, x, spaces): - return self._times_helper(x, spaces, func='adjoint_inverse_times') + def _adjoint_inverse_times(self, x): + return self._times_helper(x, func='adjoint_inverse_times') - def _times_helper(self, x, spaces, func): - space_index = 0 - if spaces is None: - spaces = range(len(self.domain)) + def _times_helper(self, x, func): for op in self._operator_store: - active_spaces = spaces[space_index:space_index+len(op.domain)] - space_index += len(op.domain) - - x = getattr(op, func)(x, spaces=active_spaces) + x = getattr(op, func)(x) return x - def _inverse_times_helper(self, x, spaces, func): - space_index = 0 - if spaces is None: - spaces = range(len(self.target)) - rev_spaces = spaces[::-1] + def _inverse_times_helper(self, x, func): for op in reversed(self._operator_store): - active_spaces = rev_spaces[space_index:space_index+len(op.target)] - space_index += len(op.target) - - x = getattr(op, func)(x, spaces=active_spaces[::-1]) + x = getattr(op, func)(x) return x diff --git a/nifty/operators/diagonal_operator/diagonal_operator.py b/nifty/operators/diagonal_operator/diagonal_operator.py index 9f4855c6c752803dbb108b016c2ff423fec5e45a..9f815ff2f09a344e9dbceeb548d0d56cc57f3cfa 100644 --- a/nifty/operators/diagonal_operator/diagonal_operator.py +++ b/nifty/operators/diagonal_operator/diagonal_operator.py @@ -23,7 +23,7 @@ import numpy as np from ...field import Field from ...domain_tuple import DomainTuple from ..endomorphic_operator import EndomorphicOperator - +from ...nifty_utilities import cast_iseq_to_tuple class DiagonalOperator(EndomorphicOperator): """ NIFTY class for diagonal operators. @@ -39,9 +39,6 @@ class DiagonalOperator(EndomorphicOperator): The diagonal entries of the operator. copy : boolean Internal copy of the diagonal (default: True) - default_spaces : tuple of ints *optional* - Defines on which space(s) of a given field the Operator acts by - default (default: None) Attributes ---------- @@ -55,9 +52,6 @@ class DiagonalOperator(EndomorphicOperator): self_adjoint : boolean Indicates whether the operator is self_adjoint or not. - Raises - ------ - See Also -------- EndomorphicOperator @@ -66,30 +60,48 @@ class DiagonalOperator(EndomorphicOperator): # ---Overwritten properties and methods--- - def __init__(self, diagonal, copy=True, default_spaces=None): - super(DiagonalOperator, self).__init__(default_spaces) + def __init__(self, diagonal, domain=None, spaces=None, copy=True): + super(DiagonalOperator, self).__init__() if not isinstance(diagonal, Field): raise TypeError("Field object required") + if domain is None: + self._domain = diagonal.domain + else: + self._domain = DomainTuple.make(domain) + if spaces is None: + self._spaces = None + if diagonal.domain != self._domain: + raise ValueError("domain mismatch") + else: + self._spaces = cast_iseq_to_tuple(spaces) + nspc = len(self._spaces) + if nspc != len(diagonal.domain.domains): + raise ValueError("spaces and domain must have the same length") + if nspc > len(self._domain.domains): + raise ValueError("too many spaces") + if nspc > len(set(self._spaces)): + raise ValueError("non-unique space indices") + # if nspc==len(self.diagonal.domain.domains, we could do some optimization + for i, j in enumerate(self._spaces): + if diagonal.domain[i] != self._domain[j]: + raise ValueError("domain mismatch") + self._diagonal = diagonal if not copy else diagonal.copy() self._self_adjoint = None self._unitary = None - def _times(self, x, spaces): - return self._times_helper(x, spaces, operation=lambda z: z.__mul__) + def _times(self, x): + return self._times_helper(x, lambda z: z.__mul__) - def _adjoint_times(self, x, spaces): - return self._times_helper(x, spaces, - operation=lambda z: z.conjugate().__mul__) + def _adjoint_times(self, x): + return self._times_helper(x, lambda z: z.conjugate().__mul__) - def _inverse_times(self, x, spaces): - return self._times_helper(x, spaces, - operation=lambda z: z.__rtruediv__) + def _inverse_times(self, x): + return self._times_helper(x, lambda z: z.__rtruediv__) - def _adjoint_inverse_times(self, x, spaces): - return self._times_helper(x, spaces, - operation=lambda z: - z.conjugate().__rtruediv__) + def _adjoint_inverse_times(self, x): + return self._times_helper(x, lambda z: z.conjugate().__rtruediv__) def diagonal(self, copy=True): """ Returns the diagonal of the Operator. @@ -111,7 +123,7 @@ class DiagonalOperator(EndomorphicOperator): @property def domain(self): - return self._diagonal.domain + return self._domain @property def self_adjoint(self): @@ -130,19 +142,13 @@ class DiagonalOperator(EndomorphicOperator): # ---Added properties and methods--- - def _times_helper(self, x, spaces, operation): - # if the domain matches directly - # -> multiply the fields directly - if x.domain == self.domain: - # here the actual multiplication takes place + def _times_helper(self, x, operation): + if self._spaces is None: return operation(self._diagonal)(x) - if spaces is None: - active_axes = range(len(x.shape)) - else: - active_axes = [] - for space_index in spaces: - active_axes += x.domain.axes[space_index] + active_axes = [] + for space_index in self._spaces: + active_axes += x.domain.axes[space_index] reshaper = [x.shape[i] if i in active_axes else 1 for i in range(len(x.shape))] diff --git a/nifty/operators/endomorphic_operator/endomorphic_operator.py b/nifty/operators/endomorphic_operator/endomorphic_operator.py index 26e61afc4558507ebf62857be3a6057d6e3c871e..9e75ef2715c5f71c1d17e2833e945f15e2f81680 100644 --- a/nifty/operators/endomorphic_operator/endomorphic_operator.py +++ b/nifty/operators/endomorphic_operator/endomorphic_operator.py @@ -28,12 +28,6 @@ class EndomorphicOperator(LinearOperator): LinearOperator. By definition, domain and target are the same in EndomorphicOperator. - Parameters - ---------- - default_spaces : tuple of ints *optional* - Defines on which space(s) of a given field the Operator acts by - default (default: None) - Attributes ---------- domain : tuple of DomainObjects, i.e. Spaces and FieldTypes @@ -56,37 +50,29 @@ class EndomorphicOperator(LinearOperator): # ---Overwritten properties and methods--- - def inverse_times(self, x, spaces=None): + def inverse_times(self, x): if self.self_adjoint and self.unitary: - return self.times(x, spaces) + return self.times(x) else: - return super(EndomorphicOperator, self).inverse_times( - x=x, - spaces=spaces) + return super(EndomorphicOperator, self).inverse_times(x) - def adjoint_times(self, x, spaces=None): + def adjoint_times(self, x): if self.self_adjoint: - return self.times(x, spaces) + return self.times(x) else: - return super(EndomorphicOperator, self).adjoint_times( - x=x, - spaces=spaces) + return super(EndomorphicOperator, self).adjoint_times(x) - def adjoint_inverse_times(self, x, spaces=None): + def adjoint_inverse_times(self, x): if self.self_adjoint: - return self.inverse_times(x, spaces) + return self.inverse_times(x) else: - return super(EndomorphicOperator, self).adjoint_inverse_times( - x=x, - spaces=spaces) + return super(EndomorphicOperator, self).adjoint_inverse_times(x) - def inverse_adjoint_times(self, x, spaces=None): + def inverse_adjoint_times(self, x): if self.self_adjoint: - return self.inverse_times(x, spaces) + return self.inverse_times(x) else: - return super(EndomorphicOperator, self).inverse_adjoint_times( - x=x, - spaces=spaces) + return super(EndomorphicOperator, self).inverse_adjoint_times(x) # ---Mandatory properties and methods--- diff --git a/nifty/operators/fft_operator/fft_operator.py b/nifty/operators/fft_operator/fft_operator.py index 19f4023042785c64eb843573b8521801baf0d984..b3bc0bd40acba6af60cb8355523b8cef733c4331 100644 --- a/nifty/operators/fft_operator/fft_operator.py +++ b/nifty/operators/fft_operator/fft_operator.py @@ -46,6 +46,8 @@ class FFTOperator(LinearOperator): domain: Space or single-element tuple of Spaces The domain of the data that is input by "times" and output by "adjoint_times". + space: the index of the space on which the operator should act + If None, it is set to 0 if domain contains exactly one space target: Space or single-element tuple of Spaces (optional) The domain of the data that is output by "times" and input by "adjoint_times". @@ -58,10 +60,10 @@ class FFTOperator(LinearOperator): Attributes ---------- - domain: Tuple of Spaces (with one entry) + domain: Tuple of Spaces The domain of the data that is input by "times" and output by "adjoint_times". - target: Tuple of Spaces (with one entry) + target: Tuple of Spaces The domain of the data that is output by "times" and input by "adjoint_times". unitary: bool @@ -72,7 +74,6 @@ class FFTOperator(LinearOperator): ------ ValueError: if "domain" or "target" are not of the proper type. - """ # ---Class attributes--- @@ -92,62 +93,53 @@ class FFTOperator(LinearOperator): # ---Overwritten properties and methods--- - def __init__(self, domain, target=None, default_spaces=None): - super(FFTOperator, self).__init__(default_spaces) + def __init__(self, domain, target=None, space=None): + super(FFTOperator, self).__init__() # Initialize domain and target self._domain = DomainTuple.make(domain) - if len(self.domain) != 1: - raise ValueError("TransformationOperator accepts only exactly one " - "space as input domain.") - + if space is None: + if len(self._domain.domains) != 1: + raise ValueError("need a Field with exactly one domain") + space = 0 + space = int(space) + if (space<0) or space>=len(self._domain.domains): + raise ValueError("space index out of range") + self._space = space + + adom = self.domain[self._space] if target is None: - target = (self.domain[0].get_default_codomain(), ) + target = [ dom for dom in self.domain ] + target[self._space] = adom.get_default_codomain() + self._target = DomainTuple.make(target) - if len(self.target) != 1: - raise ValueError("TransformationOperator accepts only exactly one " - "space as output target.") - self.domain[0].check_codomain(self.target[0]) - self.target[0].check_codomain(self.domain[0]) + atgt = self._target[self._space] + adom.check_codomain(atgt) + atgt.check_codomain(adom) # Create transformation instances forward_class = self.transformation_dictionary[ - (self.domain[0].__class__, self.target[0].__class__)] + (adom.__class__, atgt.__class__)] backward_class = self.transformation_dictionary[ - (self.target[0].__class__, self.domain[0].__class__)] - - self._forward_transformation = forward_class( - self.domain[0], self.target[0]) - - self._backward_transformation = backward_class( - self.target[0], self.domain[0]) - - def _times_helper(self, x, spaces, other, trafo): - if spaces is None: - # this case means that x lives on only one space, which is - # identical to the space in the domain of `self`. Otherwise the - # input check of LinearOperator would have failed. - axes = x.domain.axes[0] - result_domain = other - else: - spaces = utilities.cast_iseq_to_tuple(spaces) - result_domain = list(x.domain) - result_domain[spaces[0]] = other[0] - axes = x.domain.axes[spaces[0]] + (atgt.__class__, adom.__class__)] + + self._forward_transformation = forward_class(adom, atgt) + self._backward_transformation = backward_class(atgt, adom) + + def _times_helper(self, x, other, trafo): + axes = x.domain.axes[self._space] new_val, fct = trafo.transform(x.val, axes=axes) - res = Field(result_domain, new_val, copy=False) + res = Field(other, new_val, copy=False) if fct != 1.: res *= fct return res - def _times(self, x, spaces): - return self._times_helper(x, spaces, self.target, - self._forward_transformation) + def _times(self, x): + return self._times_helper(x, self.target, self._forward_transformation) - def _adjoint_times(self, x, spaces): - return self._times_helper(x, spaces, self.domain, - self._backward_transformation) + def _adjoint_times(self, x): + return self._times_helper(x, self.domain, self._backward_transformation) # ---Mandatory properties and methods--- diff --git a/nifty/operators/invertible_operator_mixin/invertible_operator_mixin.py b/nifty/operators/invertible_operator_mixin/invertible_operator_mixin.py index f048b406bca4d9224b1a8b8345a5d39013a5847f..3b0454ca15bed85bec1d7adc415610a9694995ce 100644 --- a/nifty/operators/invertible_operator_mixin/invertible_operator_mixin.py +++ b/nifty/operators/invertible_operator_mixin/invertible_operator_mixin.py @@ -46,7 +46,7 @@ class InvertibleOperatorMixin(object): self.__backward_x0 = backward_x0 super(InvertibleOperatorMixin, self).__init__(*args, **kwargs) - def _times(self, x, spaces): + def _times(self, x): if self.__forward_x0 is not None: x0 = self.__forward_x0 else: @@ -58,7 +58,7 @@ class InvertibleOperatorMixin(object): preconditioner=self._preconditioner) return result.position - def _adjoint_times(self, x, spaces): + def _adjoint_times(self, x): if self.__backward_x0 is not None: x0 = self.__backward_x0 else: @@ -70,7 +70,7 @@ class InvertibleOperatorMixin(object): preconditioner=self._preconditioner) return result.position - def _inverse_times(self, x, spaces): + def _inverse_times(self, x): if self.__backward_x0 is not None: x0 = self.__backward_x0 else: @@ -82,7 +82,7 @@ class InvertibleOperatorMixin(object): preconditioner=self._preconditioner) return result.position - def _adjoint_inverse_times(self, x, spaces): + def _adjoint_inverse_times(self, x): if self.__forward_x0 is not None: x0 = self.__forward_x0 else: diff --git a/nifty/operators/laplace_operator/laplace_operator.py b/nifty/operators/laplace_operator/laplace_operator.py index 75b77874bf5e72918820d02d8b5b5afbe765268f..80feddc330e185e8e94eafad9fa8a75b914c1cea 100644 --- a/nifty/operators/laplace_operator/laplace_operator.py +++ b/nifty/operators/laplace_operator/laplace_operator.py @@ -25,7 +25,7 @@ from ... import nifty_utilities as utilities class LaplaceOperator(EndomorphicOperator): - """A irregular LaplaceOperator with free boundary and excluding monopole. + """An irregular LaplaceOperator with free boundary and excluding monopole. This LaplaceOperator implements the second derivative of a Field in PowerSpace on logarithmic or linear scale with vanishing curvature at the @@ -39,15 +39,24 @@ class LaplaceOperator(EndomorphicOperator): default : True """ - def __init__(self, domain, default_spaces=None, logarithmic=True): - super(LaplaceOperator, self).__init__(default_spaces) + def __init__(self, domain, space=None, logarithmic=True): + super(LaplaceOperator, self).__init__() self._domain = DomainTuple.make(domain) - if len(self.domain) != 1 or not isinstance(self.domain[0], PowerSpace): - raise ValueError("The domain must contain exactly one PowerSpace.") + if space is None: + if len(self._domain.domains) != 1: + raise ValueError("need a Field with exactly one domain") + space = 0 + space = int(space) + if (space<0) or space>=len(self._domain.domains): + raise ValueError("space index out of range") + self._space = space + + if not isinstance(self._domain[self._space], PowerSpace): + raise ValueError("Operator must act on a PowerSpace.") self._logarithmic = bool(logarithmic) - pos = self.domain[0].k_lengths.copy() + pos = self.domain[self._space].k_lengths.copy() if self.logarithmic: pos[1:] = np.log(pos[1:]) pos[0] = pos[1]-1. @@ -85,15 +94,8 @@ class LaplaceOperator(EndomorphicOperator): def logarithmic(self): return self._logarithmic - def _times(self, x, spaces): - if spaces is None: - # this case means that x lives on only one space, which is - # identical to the space in the domain of `self`. Otherwise the - # input check of LinearOperator would have failed. - axes = x.domain.axes[0] - else: - spaces = utilities.cast_iseq_to_tuple(spaces) - axes = x.domain.axes[spaces[0]] + def _times(self, x): + axes = x.domain.axes[self._space] axis = axes[0] nval = len(self._dposc) prefix = (slice(None),) * axis @@ -109,17 +111,10 @@ class LaplaceOperator(EndomorphicOperator): ret /= np.sqrt(dposc) ret[prefix + (slice(None, 2),)] = 0. ret[prefix + (-1,)] = 0. - return Field(self.domain, val=ret).weight(power=-0.5, spaces=spaces) - - def _adjoint_times(self, x, spaces): - if spaces is None: - # this case means that x lives on only one space, which is - # identical to the space in the domain of `self`. Otherwise the - # input check of LinearOperator would have failed. - axes = x.domain.axes[0] - else: - spaces = utilities.cast_iseq_to_tuple(spaces) - axes = x.domain.axes[spaces[0]] + return Field(self.domain, val=ret).weight(-0.5, spaces=(self._space,)) + + def _adjoint_times(self, x): + axes = x.domain.axes[self._space] axis = axes[0] nval = len(self._dposc) prefix = (slice(None),) * axis @@ -127,7 +122,7 @@ class LaplaceOperator(EndomorphicOperator): sl_r = prefix + (slice(1, None),) # "right" slice dpos = self._dpos.reshape((1,)*axis + (nval-1,)) dposc = self._dposc.reshape((1,)*axis + (nval,)) - y = x.copy().weight(power=0.5).val + y = x.copy().weight(power=0.5, spaces=(self._space,)).val y /= np.sqrt(dposc) y[prefix + (slice(None, 2),)] = 0. y[prefix + (-1,)] = 0. @@ -136,4 +131,4 @@ class LaplaceOperator(EndomorphicOperator): ret[sl_l] = deriv ret[prefix + (-1,)] = 0. ret[sl_r] -= deriv - return Field(self.domain, val=ret).weight(-1, spaces=spaces) + return Field(self.domain, val=ret).weight(-1, spaces=(self._space,)) diff --git a/nifty/operators/linear_operator/linear_operator.py b/nifty/operators/linear_operator/linear_operator.py index 39b1487b2d8f8ead6ba18eff3c9166f109738395..cc0b225fff681798446192dfaaa1af78b3b2a3be 100644 --- a/nifty/operators/linear_operator/linear_operator.py +++ b/nifty/operators/linear_operator/linear_operator.py @@ -30,16 +30,8 @@ class LinearOperator(with_metaclass( """NIFTY base class for linear operators. The base NIFTY operator class is an abstract class from which - other specific operator subclasses, including those preimplemented - in NIFTY (e.g. the EndomorphicOperator, ProjectionOperator, - DiagonalOperator, SmoothingOperator, ResponseOperator, - PropagatorOperator, ComposedOperator) are derived. + other specific operator subclasses are derived. - Parameters - ---------- - default_spaces : tuple of ints *optional* - Defines on which space(s) of a given field the Operator acts by - default (default: None) Attributes ---------- @@ -57,17 +49,10 @@ class LinearOperator(with_metaclass( * domain is not defined * target is not defined * unitary is not set to (True/False) - - Notes - ----- - All Operators wihtin NIFTy are linear and must therefore be a subclasses of - the LinearOperator. A LinearOperator must have the attributes domain, - target and unitary to be properly defined. - """ - def __init__(self, default_spaces=None): - self._default_spaces = default_spaces + def __init__(self): + pass @abc.abstractproperty def domain(self): @@ -78,7 +63,6 @@ class LinearOperator(with_metaclass( base class must have this attribute. """ - raise NotImplementedError @abc.abstractproperty @@ -88,9 +72,7 @@ class LinearOperator(with_metaclass( The domain on which the Operator's output Field lives. Every Operator which inherits from the abstract LinearOperator base class must have this attribute. - """ - raise NotImplementedError @abc.abstractproperty @@ -100,19 +82,13 @@ class LinearOperator(with_metaclass( States whether the Operator is unitary or not. Every Operator which inherits from the abstract LinearOperator base class must have this attribute. - """ - raise NotImplementedError - @property - def default_spaces(self): - return self._default_spaces - - def __call__(self, x, spaces=None): - return self.times(x, spaces) + def __call__(self, x): + return self.times(x) - def times(self, x, spaces=None): + def times(self, x): """ Applies the Operator to a given Field. Operator and Field have to live over the same domain. @@ -121,21 +97,16 @@ class LinearOperator(with_metaclass( ---------- x : Field The input Field. - spaces : tuple of ints - Defines on which space(s) of the given Field the Operator acts. Returns ------- out : Field The processed Field living on the target-domain. - """ + self._check_input_compatibility(x) + return self._times(x) - spaces = self._check_input_compatibility(x, spaces) - y = self._times(x, spaces) - return y - - def inverse_times(self, x, spaces=None): + def inverse_times(self, x): """ Applies the inverse-Operator to a given Field. Operator and Field have to live over the same domain. @@ -144,28 +115,23 @@ class LinearOperator(with_metaclass( ---------- x : Field The input Field. - spaces : tuple of ints - Defines on which space(s) of the given Field the Operator acts. Returns ------- out : Field The processed Field living on the target-domain. - """ - - spaces = self._check_input_compatibility(x, spaces, inverse=True) - + self._check_input_compatibility(x, inverse=True) try: - y = self._inverse_times(x, spaces) + y = self._inverse_times(x) except(NotImplementedError): if (self.unitary): - y = self._adjoint_times(x, spaces) + y = self._adjoint_times(x) else: raise return y - def adjoint_times(self, x, spaces=None): + def adjoint_times(self, x): """ Applies the adjoint-Operator to a given Field. Operator and Field have to live over the same domain. @@ -174,31 +140,27 @@ class LinearOperator(with_metaclass( ---------- x : Field applies the Operator to the given Field - spaces : tuple of ints - defines on which space of the given Field the Operator acts Returns ------- out : Field The processed Field living on the target-domain. - """ if self.unitary: - return self.inverse_times(x, spaces) - - spaces = self._check_input_compatibility(x, spaces, inverse=True) + return self.inverse_times(x) + self._check_input_compatibility(x, inverse=True) try: - y = self._adjoint_times(x, spaces) + y = self._adjoint_times(x) except(NotImplementedError): if (self.unitary): - y = self._inverse_times(x, spaces) + y = self._inverse_times(x) else: raise return y - def adjoint_inverse_times(self, x, spaces=None): + def adjoint_inverse_times(self, x): """ Applies the adjoint-inverse Operator to a given Field. Operator and Field have to live over the same domain. @@ -207,8 +169,6 @@ class LinearOperator(with_metaclass( ---------- x : Field applies the Operator to the given Field - spaces : tuple of ints - defines on which space of the given Field the Operator acts Returns ------- @@ -219,74 +179,43 @@ class LinearOperator(with_metaclass( ----- If the operator has an `inverse` then the inverse adjoint is identical to the adjoint inverse. We provide both names for convenience. - """ - - spaces = self._check_input_compatibility(x, spaces) - + self._check_input_compatibility(x) try: - y = self._adjoint_inverse_times(x, spaces) + y = self._adjoint_inverse_times(x) except(NotImplementedError): if self.unitary: - y = self._times(x, spaces) + y = self._times(x) else: raise return y - def inverse_adjoint_times(self, x, spaces=None): - return self.adjoint_inverse_times(x, spaces) + def inverse_adjoint_times(self, x): + return self.adjoint_inverse_times(x) - def _times(self, x, spaces): + def _times(self, x): raise NotImplementedError( "no generic instance method 'times'.") - def _adjoint_times(self, x, spaces): + def _adjoint_times(self, x): raise NotImplementedError( "no generic instance method 'adjoint_times'.") - def _inverse_times(self, x, spaces): + def _inverse_times(self, x): raise NotImplementedError( "no generic instance method 'inverse_times'.") - def _adjoint_inverse_times(self, x, spaces): + def _adjoint_inverse_times(self, x): raise NotImplementedError( "no generic instance method 'adjoint_inverse_times'.") - def _check_input_compatibility(self, x, spaces, inverse=False): + def _check_input_compatibility(self, x, inverse=False): if not isinstance(x, Field): raise ValueError("supplied object is not a `Field`.") - if spaces is None and self.default_spaces is not None: - if not inverse: - spaces = self.default_spaces - else: - spaces = self.default_spaces[::-1] - - # sanitize the `spaces` input - if spaces is not None: - spaces = utilities.cast_iseq_to_tuple(spaces) - - # if the operator's domain is set to something, there are two valid - # cases: - # 1. Case: - # The user specifies with `spaces` that the operators domain should - # be applied to certain spaces in the domain-tuple of x. - # 2. Case: - # The domains of self and x match completely. - - self_domain = self.target if inverse else self.domain - - if spaces is None: - if self_domain != x.domain: - raise ValueError("The operator's and and field's domains " - "don't match.") - else: - for i, space_index in enumerate(spaces): - if x.domain[space_index] != self_domain[i]: - raise ValueError("The operator's and and field's domains " - "don't match.") - - return spaces + if x.domain != (self.target if inverse else self.domain): + raise ValueError("The operator's and and field's domains " + "don't match.") def __repr__(self): return str(self.__class__) diff --git a/nifty/operators/response_operator/response_operator.py b/nifty/operators/response_operator/response_operator.py index d215b0928efe57599f22419c898bc0c45da81e21..e6fc1ce8ec7346bd77d746b3341e9272374ff40e 100644 --- a/nifty/operators/response_operator/response_operator.py +++ b/nifty/operators/response_operator/response_operator.py @@ -43,12 +43,10 @@ class ResponseOperator(LinearOperator): ValueError: raised if: * len of sigma-list and exposure-list are not equal - """ - def __init__(self, domain, sigma=[1.], exposure=[1.], - default_spaces=None): - super(ResponseOperator, self).__init__(default_spaces) + def __init__(self, domain, sigma=[1.], exposure=[1.], spaces=None): + super(ResponseOperator, self).__init__() if len(sigma) != len(exposure): raise ValueError("Length of smoothing kernel and length of" @@ -57,15 +55,22 @@ class ResponseOperator(LinearOperator): self._domain = DomainTuple.make(domain) - kernel_smoothing = [FFTSmoothingOperator(self._domain[x], sigma[x]) + if spaces is None: + spaces = range(len(self._domain)) + + kernel_smoothing = [FFTSmoothingOperator(self._domain, sigma[x], + space=spaces[x]) for x in range(nsigma)] - kernel_exposure = [DiagonalOperator(Field(self._domain[x],exposure[x])) + kernel_exposure = [DiagonalOperator(Field(self._domain[spaces[x]], + exposure[x]), + domain=self._domain, + spaces=(spaces[x],)) for x in range(nsigma)] self._composed_kernel = ComposedOperator(kernel_smoothing) self._composed_exposure = ComposedOperator(kernel_exposure) - target_list = [FieldArray(x.shape) for x in self.domain] + target_list = [FieldArray(self._domain[i].shape) for i in spaces] self._target = DomainTuple.make(target_list) @property @@ -80,15 +85,15 @@ class ResponseOperator(LinearOperator): def unitary(self): return False - def _times(self, x, spaces): - res = self._composed_kernel.times(x, spaces) - res = self._composed_exposure.times(res, spaces) + def _times(self, x): + res = self._composed_kernel.times(x) + res = self._composed_exposure.times(res) # removing geometric information return Field(self._target, val=res.val) - def _adjoint_times(self, x, spaces): + def _adjoint_times(self, x): # setting correct spaces res = Field(self.domain, val=x.val) - res = self._composed_exposure.adjoint_times(res, spaces) + res = self._composed_exposure.adjoint_times(res) res = res.weight(power=-1) - return self._composed_kernel.adjoint_times(res, spaces) + return self._composed_kernel.adjoint_times(res) diff --git a/nifty/operators/smoothing_operator/direct_smoothing_operator.py b/nifty/operators/smoothing_operator/direct_smoothing_operator.py index 0efa637a4af9031bf1a6afc0946583e3f8881884..8a99b2368d9f97543e82b7c7c20fe41e19b40502 100644 --- a/nifty/operators/smoothing_operator/direct_smoothing_operator.py +++ b/nifty/operators/smoothing_operator/direct_smoothing_operator.py @@ -11,27 +11,28 @@ from ... import Field, DomainTuple class DirectSmoothingOperator(EndomorphicOperator): def __init__(self, domain, sigma, log_distances=False, - default_spaces=None): - super(DirectSmoothingOperator, self).__init__(default_spaces) + space=None): + super(DirectSmoothingOperator, self).__init__() self._domain = DomainTuple.make(domain) - if len(self._domain) != 1: - raise ValueError("DirectSmoothingOperator only accepts exactly one" - " space as input domain.") + if space is None: + if len(self._domain.domains) != 1: + raise ValueError("need a Field with exactly one domain") + space = 0 + space = int(space) + if (space<0) or space>=len(self._domain.domains): + raise ValueError("space index out of range") + self._space = space self._sigma = float(sigma) self._log_distances = log_distances self._effective_smoothing_width = 3.01 - def _times(self, x, spaces): + def _times(self, x): if self._sigma == 0: return x.copy() - # the domain of the smoothing operator contains exactly one space. - # Hence, if spaces is None, but we passed LinearOperator's - # _check_input_compatibility, we know that x is also solely defined - # on that space - return self._smooth(x, (0,) if spaces is None else spaces) + return self._smooth(x) # ---Mandatory properties and methods--- @property @@ -90,16 +91,12 @@ class DirectSmoothingOperator(EndomorphicOperator): return ibegin, nval, wgt - def _smooth(self, x, spaces): + def _smooth(self, x): # infer affected axes - # we rely on the knowledge that `spaces` is a tuple with length 1. - affected_axes = x.domain.axes[spaces[0]] - if len(affected_axes) != 1: - raise ValueError("By this implementation only one-dimensional " - "spaces can be smoothed directly.") + affected_axes = x.domain.axes[self._space] axis = affected_axes[0] - distances = x.domain[spaces[0]].get_k_length_array() + distances = x.domain[self._space].get_k_length_array() if self._log_distances: distances = np.log(np.maximum(distances, 1e-15)) diff --git a/nifty/operators/smoothing_operator/fft_smoothing_operator.py b/nifty/operators/smoothing_operator/fft_smoothing_operator.py index 6a257ac834ed31f08e57c955455c50f34e9d9e15..3384c26a755ed313fc330592e1ae0ca9ed89aa47 100644 --- a/nifty/operators/smoothing_operator/fft_smoothing_operator.py +++ b/nifty/operators/smoothing_operator/fft_smoothing_operator.py @@ -8,40 +8,36 @@ from ..fft_operator import FFTOperator from ... import DomainTuple class FFTSmoothingOperator(EndomorphicOperator): + def __init__(self, domain, sigma, space=None): + super(FFTSmoothingOperator, self).__init__() - def __init__(self, domain, sigma, - default_spaces=None): - super(FFTSmoothingOperator, self).__init__(default_spaces) - - self._domain = DomainTuple.make(domain) - if len(self._domain) != 1: - raise ValueError("SmoothingOperator only accepts exactly one " - "space as input domain.") - + dom = DomainTuple.make(domain) self._sigma = float(sigma) - if self._sigma == 0.: - return - - self._transformator = FFTOperator(self._domain) - codomain = self._domain[0].get_default_codomain() + if space is None: + if len(dom.domains) != 1: + raise ValueError("need a Field with exactly one domain") + space = 0 + space = int(space) + if (space<0) or space>=len(dom.domains): + raise ValueError("space index out of range") + self._space = space + + self._transformator = FFTOperator(dom, space=space) + codomain = self._transformator.domain[space].get_default_codomain() self._kernel = codomain.get_k_length_array() smoother = codomain.get_fft_smoothing_kernel_function(self._sigma) self._kernel = smoother(self._kernel) - def _times(self, x, spaces): + def _times(self, x): if self._sigma == 0: return x.copy() - # the domain of the smoothing operator contains exactly one space. - # Hence, if spaces is None, but we passed LinearOperator's - # _check_input_compatibility, we know that x is also solely defined - # on that space - return self._smooth(x, (0,) if spaces is None else spaces) + return self._smooth(x) # ---Mandatory properties and methods--- @property def domain(self): - return self._domain + return self._transformator.domain @property def self_adjoint(self): @@ -53,11 +49,11 @@ class FFTSmoothingOperator(EndomorphicOperator): # ---Added properties and methods--- - def _smooth(self, x, spaces): + def _smooth(self, x): # transform to the (global-)default codomain and perform all remaining # steps therein - transformed_x = self._transformator(x, spaces=spaces) - coaxes = transformed_x.domain.axes[spaces[0]] + transformed_x = self._transformator(x) + coaxes = transformed_x.domain.axes[self._space] # now, apply the kernel to transformed_x # this is done node-locally utilizing numpy's reshaping in order to @@ -68,4 +64,4 @@ class FFTSmoothingOperator(EndomorphicOperator): transformed_x *= np.reshape(self._kernel, reshaper) - return self._transformator.adjoint_times(transformed_x, spaces=spaces) + return self._transformator.adjoint_times(transformed_x) diff --git a/nifty/operators/smoothness_operator/smoothness_operator.py b/nifty/operators/smoothness_operator/smoothness_operator.py index cc27f9153aa3b45253f2156ce7bc36b80593c65c..9aa01c99a07ec1ee591164aa1e3ec72196a17b73 100644 --- a/nifty/operators/smoothness_operator/smoothness_operator.py +++ b/nifty/operators/smoothness_operator/smoothness_operator.py @@ -29,35 +29,20 @@ class SmoothnessOperator(EndomorphicOperator): # ---Overwritten properties and methods--- - def __init__(self, domain, strength=1., logarithmic=True, - default_spaces=None): - - super(SmoothnessOperator, self).__init__(default_spaces=default_spaces) - - self._domain = DomainTuple.make(domain) - if len(self.domain) != 1: - raise ValueError("The domain must contain exactly one PowerSpace.") - - if not isinstance(self.domain[0], PowerSpace): - raise TypeError("The domain must contain exactly one PowerSpace.") + def __init__(self, domain, strength=1., logarithmic=True, space=None): + super(SmoothnessOperator, self).__init__() + self._laplace = LaplaceOperator(domain, + logarithmic=logarithmic, space=space) if strength <= 0: raise ValueError("ERROR: invalid sigma.") - self._strength = strength - self._laplace = LaplaceOperator(domain=self.domain, - logarithmic=logarithmic) - # ---Mandatory properties and methods--- - @property - def target(self): - return self._domain - @property def domain(self): - return self._domain + return self._laplace._domain @property def unitary(self): @@ -71,10 +56,9 @@ class SmoothnessOperator(EndomorphicOperator): def self_adjoint(self): return False - def _times(self, x, spaces): + def _times(self, x): if self._strength != 0: - result = self._laplace.adjoint_times(self._laplace(x, spaces), - spaces) + result = self._laplace.adjoint_times(self._laplace(x)) result *= self._strength**2 else: result = Field(x.domain, 0., x.dtype) diff --git a/nifty/sugar.py b/nifty/sugar.py index 9902bc55c7b23c3f8af6abb0be33a65c4b8cdce3..4122e52cddabce8b97e2f294728e9762eef6b461 100644 --- a/nifty/sugar.py +++ b/nifty/sugar.py @@ -115,10 +115,10 @@ def generate_posterior_sample(mean, covariance): def create_composed_fft_operator(domain, codomain=None, all_to='other'): fft_op_list = [] - space_index_list = [] if codomain is None: codomain = [None]*len(domain) + interdomain = list(domain.domains) for i, space in enumerate(domain): cospace = codomain[i] if not isinstance(space, Space): @@ -126,7 +126,11 @@ def create_composed_fft_operator(domain, codomain=None, all_to='other'): if (all_to == 'other' or (all_to == 'position' and space.harmonic) or (all_to == 'harmonic' and not space.harmonic)): - fft_op_list += [FFTOperator(domain=space, target=cospace)] - space_index_list += [i] - result = ComposedOperator(fft_op_list, default_spaces=space_index_list) - return result + if codomain[i] is None: + interdomain[i] = domain[i].get_default_codomain() + else: + interdomain[i] = codomain[i] + fft_op_list += [FFTOperator(domain=domain, target=interdomain, + space=i)] + domain = interdomain + return ComposedOperator(fft_op_list) diff --git a/test/test_operators/test_composed_operator.py b/test/test_operators/test_composed_operator.py index 27f97606d3c674215b727acd678a5d50cb9f0dc0..fe801c3834c9d9b020f016bd64d7cfe3737e4a39 100644 --- a/test/test_operators/test_composed_operator.py +++ b/test/test_operators/test_composed_operator.py @@ -16,24 +16,13 @@ from test.common import expand class ComposedOperator_Tests(unittest.TestCase): spaces = generate_spaces() - @expand(product(spaces, spaces)) - def test_property(self, space1, space2): - rand1 = Field.from_random('normal', domain=space1) - rand2 = Field.from_random('normal', domain=space2) - op1 = DiagonalOperator(rand1) - op2 = DiagonalOperator(rand2) - op = ComposedOperator((op1, op2)) - if op.domain != (op1.domain[0], op2.domain[0]): - raise TypeError - if op.unitary != False: - raise ValueError - @expand(product(spaces,spaces)) def test_times_adjoint_times(self, space1, space2): + cspace = (space1, space2) diag1 = Field.from_random('normal', domain=space1) diag2 = Field.from_random('normal', domain=space2) - op1 = DiagonalOperator(diag1) - op2 = DiagonalOperator(diag2) + op1 = DiagonalOperator(diag1, cspace, spaces=(0,)) + op2 = DiagonalOperator(diag2, cspace, spaces=(1,)) op = ComposedOperator((op1, op2)) @@ -46,10 +35,11 @@ class ComposedOperator_Tests(unittest.TestCase): @expand(product(spaces, spaces)) def test_times_inverse_times(self, space1, space2): + cspace = (space1, space2) diag1 = Field.from_random('normal', domain=space1) diag2 = Field.from_random('normal', domain=space2) - op1 = DiagonalOperator(diag1) - op2 = DiagonalOperator(diag2) + op1 = DiagonalOperator(diag1, cspace, spaces=(0,)) + op2 = DiagonalOperator(diag2, cspace, spaces=(1,)) op = ComposedOperator((op1, op2)) diff --git a/test/test_operators/test_fft_operator.py b/test/test_operators/test_fft_operator.py index 38b65ccf61a96106dd9903636f8b1af4043b8106..5e03b1fac853de5c096f3e68772732e726a876e7 100644 --- a/test/test_operators/test_fft_operator.py +++ b/test/test_operators/test_fft_operator.py @@ -82,10 +82,7 @@ class FFTOperatorTests(unittest.TestCase): base): tol = _get_rtol(dtype) a = [a1, a2, a3] = [RGSpace((32,)), RGSpace((4, 4)), RGSpace((5, 6))] - fft = FFTOperator(domain=a[index], - default_spaces=(index,)) - fft._forward_transformation.harmonic_base = base - fft._backward_transformation.harmonic_base = base + fft = FFTOperator(domain=a, space=index) inp = Field.from_random(domain=(a1, a2, a3), random_type='normal', std=7, mean=3, dtype=dtype)