diff --git a/demos/wiener_filter_easy.py b/demos/wiener_filter_easy.py index ab9d181965636f0f8bb96435ff59cfb329fe2473..87c2831215187ec33863bc44c88a17ea9320f3cd 100644 --- a/demos/wiener_filter_easy.py +++ b/demos/wiener_filter_easy.py @@ -2,36 +2,6 @@ import numpy as np import nifty2go as ift -# Note that the constructor of PropagatorOperator takes as arguments the -# response R and noise covariance N operating on signal space and signal -# covariance operating on harmonic space. -class PropagatorOperator(ift.InversionEnabler, ift.EndomorphicOperator): - def __init__(self, R, N, Sh, inverter): - ift.InversionEnabler.__init__(self, inverter) - ift.EndomorphicOperator.__init__(self) - - self.R = R - self.N = N - self.Sh = Sh - self.fft = ift.FFTOperator(R.domain, target=Sh.domain[0]) - - def _inverse_times(self, x): - return self.R.adjoint_times(self.N.inverse_times(self.R(x))) \ - + self.fft.adjoint_times(self.Sh.inverse_times(self.fft(x))) - - @property - def domain(self): - return self.R.domain - - @property - def unitary(self): - return False - - @property - def self_adjoint(self): - return True - - if __name__ == "__main__": # Set up physical constants # Total length of interval or volume the field lives on, e.g. in meters @@ -85,6 +55,6 @@ if __name__ == "__main__": j = R.adjoint_times(N.inverse_times(d)) IC = ift.GradientNormController(iteration_limit=500, tol_abs_gradnorm=0.1) inverter = ift.ConjugateGradient(controller=IC) - D = PropagatorOperator(Sh=Sh, N=N, R=R, inverter=inverter) - + D = (R.adjoint*N.inverse*R + fft.adjoint*Sh.inverse*fft).inverse + D = ift.InversionEnabler(D, inverter) m = D(j) diff --git a/demos/wiener_filter_via_hamiltonian.py b/demos/wiener_filter_via_hamiltonian.py index 354cbb76aa8dfaaff156bafbf634348c857da39f..5676d2ef330b9530a89bd9bf7e0be8217f9ecf11 100644 --- a/demos/wiener_filter_via_hamiltonian.py +++ b/demos/wiener_filter_via_hamiltonian.py @@ -71,7 +71,7 @@ if __name__ == "__main__": n_samples = 50 for i in range(n_samples): - sample = fft(D.generate_posterior_sample(m)) + sample = fft(D.generate_posterior_sample() + m) sample_variance += sample**2 sample_mean += sample sample_mean /= n_samples diff --git a/nifty/library/critical_power_curvature.py b/nifty/library/critical_power_curvature.py index d84b3f91c7c45667d8fb3162c78fc28c037d9d0c..788312ff6491654aad05d37c4ed8a45d810420ae 100644 --- a/nifty/library/critical_power_curvature.py +++ b/nifty/library/critical_power_curvature.py @@ -1,7 +1,7 @@ from ..operators import EndomorphicOperator, InversionEnabler, DiagonalOperator -class CriticalPowerCurvature(InversionEnabler, EndomorphicOperator): +class CriticalPowerCurvature(EndomorphicOperator): """The curvature of the CriticalPowerEnergy. This operator implements the second derivative of the @@ -17,22 +17,17 @@ class CriticalPowerCurvature(InversionEnabler, EndomorphicOperator): """ def __init__(self, theta, T, inverter): - EndomorphicOperator.__init__(self) - self._theta = DiagonalOperator(theta) - InversionEnabler.__init__(self, inverter, self._theta.inverse_times) - self._T = T - - def _times(self, x): - return self._T(x) + self._theta(x) + super(CriticalPowerCurvature, self).__init__() + theta = DiagonalOperator(theta) + self._op = InversionEnabler(T+theta, inverter, theta.inverse_times) @property - def domain(self): - return self._theta.domain + def capability(self): + return self._op.capability - @property - def self_adjoint(self): - return True + def apply(self, x, mode): + return self._op.apply(x, mode) @property - def unitary(self): - return False + def domain(self): + return self._op.domain diff --git a/nifty/library/critical_power_energy.py b/nifty/library/critical_power_energy.py index 6d007fb1079f06037e01b23999dc6138297b141f..3253b12d8beadb7be6a39c067e3187ca1c9dd761 100644 --- a/nifty/library/critical_power_energy.py +++ b/nifty/library/critical_power_energy.py @@ -70,7 +70,7 @@ class CriticalPowerEnergy(Energy): if self.D is not None: w = Field.zeros(self.position.domain, dtype=self.m.dtype) for i in range(self.samples): - sample = self.D.generate_posterior_sample(self.m) + sample = self.D.generate_posterior_sample() + self.m w += P(abs(sample)**2) w *= 1./self.samples diff --git a/nifty/library/log_normal_wiener_filter_curvature.py b/nifty/library/log_normal_wiener_filter_curvature.py index 6446a136a52d3668d145e6b87162bb1269049993..92cbf76b76ad94e23fee8a26f3a38e8b71efdd12 100644 --- a/nifty/library/log_normal_wiener_filter_curvature.py +++ b/nifty/library/log_normal_wiener_filter_curvature.py @@ -3,7 +3,7 @@ from ..utilities import memo from ..field import exp -class LogNormalWienerFilterCurvature(InversionEnabler, EndomorphicOperator): +class LogNormalWienerFilterCurvature(EndomorphicOperator): """The curvature of the LogNormalWienerFilterEnergy. This operator implements the second derivative of the @@ -21,33 +21,54 @@ class LogNormalWienerFilterCurvature(InversionEnabler, EndomorphicOperator): The prior signal covariance """ + class _Helper(EndomorphicOperator): + def __init__(self, R, N, S, position, fft4exp): + super(LogNormalWienerFilterCurvature._Helper, self).__init__() + self.R = R + self.N = N + self.S = S + self.position = position + self._fft = fft4exp + + @property + def domain(self): + return self.S.domain + + @property + def capability(self): + return self.TIMES | self.ADJOINT_TIMES + + def apply(self, x, mode): + self._check_input(x, mode) + part1 = self.S.inverse_times(x) + part3 = self._fft.adjoint_times(self._expp_sspace * self._fft(x)) + part3 = self._fft.adjoint_times( + self._expp_sspace * + self._fft(self.R.adjoint_times( + self.N.inverse_times(self.R(part3))))) + return part1 + part3 + + @property + @memo + def _expp_sspace(self): + return exp(self._fft(self.position)) + def __init__(self, R, N, S, position, fft4exp, inverter): - InversionEnabler.__init__(self, inverter) - EndomorphicOperator.__init__(self) - self.R = R - self.N = N - self.S = S - self.position = position - self._fft = fft4exp + super(LogNormalWienerFilterCurvature, self).__init__() + self._op = self._Helper(R, N, S, position, fft4exp) + self._op = InversionEnabler(self._op, inverter) @property def domain(self): - return self.S.domain + return self._op.domain @property - def self_adjoint(self): - return True - - def _times(self, x): - part1 = self.S.inverse_times(x) - part3 = self._fft.adjoint_times(self._expp_sspace * self._fft(x)) - part3 = self._fft.adjoint_times( - self._expp_sspace * - self._fft(self.R.adjoint_times( - self.N.inverse_times(self.R(part3))))) - return part1 + part3 + def capability(self): + return self._op.capability @property - @memo def _expp_sspace(self): - return exp(self._fft(self.position)) + return self._op._op._expp_sspace + + def apply(self, x, mode): + return self._op.apply(x, mode) diff --git a/nifty/library/noise_energy.py b/nifty/library/noise_energy.py index 292508d1594767090a80fa15b729c614d02b73b9..61c2630520877913204c4f3e6ad7d3451b5c9374 100644 --- a/nifty/library/noise_energy.py +++ b/nifty/library/noise_energy.py @@ -27,7 +27,7 @@ class NoiseEnergy(Energy): if samples is None or samples == 0: sample_list = [m] else: - sample_list = [D.generate_posterior_sample(m) + sample_list = [D.generate_posterior_sample() + m for _ in range(samples)] self.sample_list = sample_list self.inverter = inverter diff --git a/nifty/library/nonlinear_power_curvature.py b/nifty/library/nonlinear_power_curvature.py index 8e30ef54dc4bb769e4102ae08bd90c99d2f8096c..26ee1bc3a1c093a2907f8a5456acc0d0219b005f 100644 --- a/nifty/library/nonlinear_power_curvature.py +++ b/nifty/library/nonlinear_power_curvature.py @@ -2,42 +2,60 @@ from ..operators import EndomorphicOperator, InversionEnabler from .response_operators import LinearizedPowerResponse -class NonlinearPowerCurvature(InversionEnabler, EndomorphicOperator): +class NonlinearPowerCurvature(EndomorphicOperator): + class _Helper(EndomorphicOperator): + def __init__(self, position, FFT, Instrument, nonlinearity, + Projection, N, T, sample_list): + super(NonlinearPowerCurvature._Helper, self).__init__() + self.N = N + self.FFT = FFT + self.Instrument = Instrument + self.T = T + self.sample_list = sample_list + self.position = position + self.Projection = Projection + self.nonlinearity = nonlinearity + + @property + def domain(self): + return self.position.domain + + @property + def capability(self): + return self.TIMES + + def apply(self, x, mode): + self._check_input(x, mode) + result = None + for sample in self.sample_list: + if result is None: + result = self._sample_times(x, sample) + else: + result += self._sample_times(x, sample) + result *= 1./len(self.sample_list) + return result + self.T(x) + + def _sample_times(self, x, sample): + LinearizedResponse = LinearizedPowerResponse( + self.Instrument, self.nonlinearity, self.FFT, self.Projection, + self.position, sample) + return LinearizedResponse.adjoint_times( + self.N.inverse_times(LinearizedResponse(x))) def __init__(self, position, FFT, Instrument, nonlinearity, Projection, N, T, sample_list, inverter): - InversionEnabler.__init__(self, inverter) - EndomorphicOperator.__init__(self) - self.N = N - self.FFT = FFT - self.Instrument = Instrument - self.T = T - self.sample_list = sample_list - self.position = position - self.Projection = Projection - self.nonlinearity = nonlinearity + super(NonlinearPowerCurvature, self).__init__() + self._op = self._Helper(position, FFT, Instrument, nonlinearity, + Projection, N, T, sample_list) + self._op = InversionEnabler(self._op, inverter) @property def domain(self): - return self.position.domain + return self._op.domain @property - def self_adjoint(self): - return True - - def _times(self, x): - result = None - for sample in self.sample_list: - if result is None: - result = self._sample_times(x, sample) - else: - result += self._sample_times(x, sample) - result *= 1./len(self.sample_list) - return result + self.T(x) - - def _sample_times(self, x, sample): - LinearizedResponse = LinearizedPowerResponse( - self.Instrument, self.nonlinearity, self.FFT, self.Projection, - self.position, sample) - return LinearizedResponse.adjoint_times( - self.N.inverse_times(LinearizedResponse(x))) + def capability(self): + return self._op.capability + + def apply(self, x, mode): + return self._op.apply(x, mode) diff --git a/nifty/library/nonlinear_power_energy.py b/nifty/library/nonlinear_power_energy.py index 2bffe3e12704a2f77c937fe57cea896213c145e4..97ec9f78670680dcfb401020d59e91a35526811f 100644 --- a/nifty/library/nonlinear_power_energy.py +++ b/nifty/library/nonlinear_power_energy.py @@ -53,7 +53,7 @@ class NonlinearPowerEnergy(Energy): if samples is None or samples == 0: sample_list = [m] else: - sample_list = [D.generate_posterior_sample(m) + sample_list = [D.generate_posterior_sample() + m for _ in range(samples)] self.sample_list = sample_list self.inverter = inverter diff --git a/nifty/library/response_operators.py b/nifty/library/response_operators.py index 436efb8838d55e6b1f9cb592bb03687b25fb9842..98c0671aaefd7ad0c251e35c3343b042bfe19546 100644 --- a/nifty/library/response_operators.py +++ b/nifty/library/response_operators.py @@ -31,6 +31,14 @@ class LinearizedSignalResponse(LinearOperator): def target(self): return self.Instrument.target + @property + def capability(self): + return self.TIMES | self.ADJOINT_TIMES + + def apply(self, x, mode): + self._check_input(x, mode) + return self._times(x) if mode & self.TIMES else self._adjoint_times(x) + class LinearizedPowerResponse(LinearOperator): def __init__(self, Instrument, nonlinearity, FFT, Projection, t, m): @@ -70,3 +78,11 @@ class LinearizedPowerResponse(LinearOperator): @property def target(self): return self.Instrument.target + + @property + def capability(self): + return self.TIMES | self.ADJOINT_TIMES + + def apply(self, x, mode): + self._check_input(x, mode) + return self._times(x) if mode & self.TIMES else self._adjoint_times(x) diff --git a/nifty/library/wiener_filter_curvature.py b/nifty/library/wiener_filter_curvature.py index 8a6008f129305c4583cb2705b9b0a5b5ea6b83db..b6d72dfe1ff276f2e63b2f7141331c6be29204c3 100644 --- a/nifty/library/wiener_filter_curvature.py +++ b/nifty/library/wiener_filter_curvature.py @@ -3,7 +3,7 @@ from ..field import Field, sqrt from ..sugar import power_analyze, power_synthesize -class WienerFilterCurvature(InversionEnabler, EndomorphicOperator): +class WienerFilterCurvature(EndomorphicOperator): """The curvature of the WienerFilterEnergy. This operator implements the second derivative of the @@ -22,26 +22,25 @@ class WienerFilterCurvature(InversionEnabler, EndomorphicOperator): """ def __init__(self, R, N, S, inverter): - EndomorphicOperator.__init__(self) - InversionEnabler.__init__(self, inverter, S.times) + super(WienerFilterCurvature, self).__init__() self.R = R self.N = N self.S = S + op = R.adjoint*N.inverse*R + S.inverse + self._op = InversionEnabler(op, inverter, S.times) @property def domain(self): return self.S.domain @property - def self_adjoint(self): - return True + def capability(self): + return self._op.capability - def _times(self, x): - res = self.R.adjoint_times(self.N.inverse_times(self.R(x))) - res += self.S.inverse_times(x) - return res + def apply(self, x, mode): + return self._op.apply(x, mode) - def generate_posterior_sample(self, mean): + def generate_posterior_sample(self): """ Generates a posterior sample from a Gaussian distribution with given mean and covariance. @@ -49,11 +48,6 @@ class WienerFilterCurvature(InversionEnabler, EndomorphicOperator): reconstruction of a mock signal in order to obtain residuals of the right correlation which are added to the given mean. - Parameters - ---------- - mean : Field - the mean of the posterior Gaussian distribution - Returns ------- sample : Field @@ -74,5 +68,5 @@ class WienerFilterCurvature(InversionEnabler, EndomorphicOperator): mock_j = self.R.adjoint_times(self.N.inverse_times(mock_data)) mock_m = self.inverse_times(mock_j) - sample = mock_signal - mock_m + mean + sample = mock_signal - mock_m return sample diff --git a/nifty/library/wiener_filter_energy.py b/nifty/library/wiener_filter_energy.py index 3bf42e8046ce2087f929101422757ce23cc5283d..5f43b2347b581b74bff06ac1d79e6fda6c439f14 100644 --- a/nifty/library/wiener_filter_energy.py +++ b/nifty/library/wiener_filter_energy.py @@ -27,7 +27,7 @@ class WienerFilterEnergy(Energy): self.R = R self.N = N self.S = S - self._curvature = WienerFilterCurvature(R, N, S, inverter=inverter) + self._curvature = WienerFilterCurvature(R, N, S, inverter) self._inverter = inverter if _j is None: _j = self.R.adjoint_times(self.N.inverse_times(d)) diff --git a/nifty/operators/__init__.py b/nifty/operators/__init__.py index ed25e1c63cd1eee1b2b698d64b5abd1348b0843f..b6d015c96410cb11847c0b6c3333181305543f9f 100644 --- a/nifty/operators/__init__.py +++ b/nifty/operators/__init__.py @@ -12,3 +12,7 @@ from .laplace_operator import LaplaceOperator from .smoothness_operator import SmoothnessOperator from .power_projection_operator import PowerProjectionOperator from .dof_projection_operator import DOFProjectionOperator +from .chain_operator import ChainOperator +from .sum_operator import SumOperator +from .inverse_operator import InverseOperator +from .adjoint_operator import AdjointOperator diff --git a/nifty/operators/adjoint_operator.py b/nifty/operators/adjoint_operator.py new file mode 100644 index 0000000000000000000000000000000000000000..6748c9d4de8d8f7f751c3483e122bff63f946d79 --- /dev/null +++ b/nifty/operators/adjoint_operator.py @@ -0,0 +1,40 @@ +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see <http://www.gnu.org/licenses/>. +# +# Copyright(C) 2013-2017 Max-Planck-Society +# +# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik +# and financially supported by the Studienstiftung des deutschen Volkes. + +from .linear_operator import LinearOperator + + +class AdjointOperator(LinearOperator): + def __init__(self, op): + super(AdjointOperator, self).__init__() + self._op = op + + @property + def domain(self): + return self._op.target + + @property + def target(self): + return self._op.domain + + @property + def capability(self): + return self._adjointCapability[self._op.capability] + + def apply(self, x, mode): + return self._op.apply(x, self._adjointMode[mode]) diff --git a/nifty/operators/chain_operator.py b/nifty/operators/chain_operator.py new file mode 100644 index 0000000000000000000000000000000000000000..b9967ee5317de5007e121788a73569035644ce7d --- /dev/null +++ b/nifty/operators/chain_operator.py @@ -0,0 +1,46 @@ +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see <http://www.gnu.org/licenses/>. +# +# Copyright(C) 2013-2017 Max-Planck-Society +# +# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik +# and financially supported by the Studienstiftung des deutschen Volkes. + +from .linear_operator import LinearOperator + + +class ChainOperator(LinearOperator): + def __init__(self, op1, op2): + super(ChainOperator, self).__init__() + if op2.target != op1.domain: + raise ValueError("domain mismatch") + self._op1 = op1 + self._op2 = op2 + + @property + def domain(self): + return self._op2.domain + + @property + def target(self): + return self._op1.target + + @property + def capability(self): + return self._op1.capability & self._op2.capability + + def apply(self, x, mode): + self._check_mode(mode) + if mode == self.TIMES or mode == self.ADJOINT_INVERSE_TIMES: + return self._op1.apply(self._op2.apply(x, mode), mode) + return self._op2.apply(self._op1.apply(x, mode), mode) diff --git a/nifty/operators/composed_operator.py b/nifty/operators/composed_operator.py index c77896140f788c3b02523ba625abbd16dbe069c5..cc6872382bc0d1a31e534e63b1d691731e0f939b 100644 --- a/nifty/operators/composed_operator.py +++ b/nifty/operators/composed_operator.py @@ -16,7 +16,6 @@ # NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik # and financially supported by the Studienstiftung des deutschen Volkes. -from builtins import range from .linear_operator import LinearOperator @@ -30,7 +29,6 @@ class ComposedOperator(LinearOperator): operators : tuple of NIFTy Operators The tuple of LinearOperators. - Attributes ---------- domain : DomainTuple @@ -44,6 +42,7 @@ class ComposedOperator(LinearOperator): self._operator_store = () old_op = None + self._capability = operators[0].capability for op in operators: if not isinstance(op, LinearOperator): raise TypeError("The elements of the operator list must be" @@ -51,7 +50,10 @@ class ComposedOperator(LinearOperator): if old_op is not None and op.domain != old_op.target: raise ValueError("incompatible domains") self._operator_store += (op,) + self._capability &= op.capability old_op = op + if self._capability == 0: + raise ValueError("composed operator does not support any mode") @property def domain(self): @@ -61,24 +63,16 @@ class ComposedOperator(LinearOperator): def target(self): return self._operator_store[-1].target - def _times(self, x): - return self._times_helper(x, func='times') - - def _adjoint_times(self, x): - return self._inverse_times_helper(x, func='adjoint_times') - - def _inverse_times(self, x): - return self._inverse_times_helper(x, func='inverse_times') - - def _adjoint_inverse_times(self, x): - return self._times_helper(x, func='adjoint_inverse_times') - - def _times_helper(self, x, func): - for op in self._operator_store: - x = getattr(op, func)(x) - return x - - def _inverse_times_helper(self, x, func): - for op in reversed(self._operator_store): - x = getattr(op, func)(x) + @property + def capability(self): + return self._capability + + def apply(self, x, mode): + self._check_mode(mode) + if mode == self.TIMES or mode == self.ADJOINT_INVERSE_TIMES: + for op in self._operator_store: + x = op.apply(x, mode) + else: + for op in reversed(self._operator_store): + x = op.apply(x, mode) return x diff --git a/nifty/operators/diagonal_operator.py b/nifty/operators/diagonal_operator.py index fe3fe444a1fde4d557977a0bf30f40eba75dbf32..e239e5bc5547ba9bb636d38c61a8d9c4005359c2 100644 --- a/nifty/operators/diagonal_operator.py +++ b/nifty/operators/diagonal_operator.py @@ -48,13 +48,6 @@ class DiagonalOperator(EndomorphicOperator): ---------- domain : DomainTuple The domain on which the Operator's input Field lives. - target : DomainTuple - The domain in which the outcome of the operator lives. As the Operator - is endomorphic this is the same as its domain. - unitary : boolean - Indicates whether the Operator is unitary or not. - self_adjoint : boolean - Indicates whether the operator is self-adjoint or not. NOTE: the fields given to __init__ and returned from .diagonal() are considered to be non-bare, i.e. during operator application, no additional @@ -114,20 +107,23 @@ class DiagonalOperator(EndomorphicOperator): else: self._ldiag = dobj.local_data(self._diagonal.val) - self._self_adjoint = None - self._unitary = None + def apply(self, x, mode): + self._check_input(x, mode) - def _times(self, x): - return Field(x.domain, val=x.val*self._ldiag) - - def _adjoint_times(self, x): - return Field(x.domain, val=x.val*self._ldiag.conj()) - - def _inverse_times(self, x): - return Field(x.domain, val=x.val/self._ldiag) - - def _adjoint_inverse_times(self, x): - return Field(x.domain, val=x.val/self._ldiag.conj()) + if mode == self.TIMES: + return Field(x.domain, val=x.val*self._ldiag) + elif mode == self.ADJOINT_TIMES: + if np.issubdtype(self._ldiag.dtype, np.floating): + return Field(x.domain, val=x.val*self._ldiag) + else: + return Field(x.domain, val=x.val*self._ldiag.conj()) + elif mode == self.INVERSE_TIMES: + return Field(x.domain, val=x.val/self._ldiag) + else: + if np.issubdtype(self._ldiag.dtype, np.floating): + return Field(x.domain, val=x.val/self._ldiag) + else: + return Field(x.domain, val=x.val/self._ldiag.conj()) def diagonal(self): """ Returns the diagonal of the Operator.""" @@ -138,16 +134,6 @@ class DiagonalOperator(EndomorphicOperator): return self._domain @property - def self_adjoint(self): - if self._self_adjoint is None: - if not np.issubdtype(self._diagonal.dtype, np.complexfloating): - self._self_adjoint = True - else: - self._self_adjoint = (self._diagonal.val.imag == 0).all() - return self._self_adjoint - - @property - def unitary(self): - if self._unitary is None: - self._unitary = (abs(self._diagonal.val) == 1.).all() - return self._unitary + def capability(self): + return (self.TIMES | self.ADJOINT_TIMES | + self.INVERSE_TIMES | self.ADJOINT_INVERSE_TIMES) diff --git a/nifty/operators/direct_smoothing_operator.py b/nifty/operators/direct_smoothing_operator.py index da382eb7f23ed8303fbe36c757cde878dae967fb..0091016f4764068ecdd000bd7d29a70974d4ef53 100644 --- a/nifty/operators/direct_smoothing_operator.py +++ b/nifty/operators/direct_smoothing_operator.py @@ -28,7 +28,8 @@ class DirectSmoothingOperator(EndomorphicOperator): self._ibegin, self._nval, self._wgt = self._precompute(distances) - def _times(self, x): + def apply(self, x, mode): + self._check_input(x, mode) if self._sigma == 0: return x.copy() @@ -39,8 +40,8 @@ class DirectSmoothingOperator(EndomorphicOperator): return self._domain @property - def self_adjoint(self): - return True + def capability(self): + return self.TIMES | self.ADJOINT_TIMES def _precompute(self, x): """ Does precomputations for Gaussian smoothing on a 1D irregular grid. diff --git a/nifty/operators/dof_projection_operator.py b/nifty/operators/dof_projection_operator.py index 10daa73335f419b7a6d4ee9b8e995cdd5ef36a91..884b929d4e0438665544d7e618f098c16cdf37ac 100644 --- a/nifty/operators/dof_projection_operator.py +++ b/nifty/operators/dof_projection_operator.py @@ -88,6 +88,12 @@ class DOFProjectionOperator(LinearOperator): oarr[()] = arr[(slice(None), self._dofdex, slice(None))] return res + def apply(self, x, mode): + self._check_input(x, mode) + if mode == self.TIMES: + return self._times(x) + return self._adjoint_times(x) + @property def domain(self): return self._domain @@ -95,3 +101,7 @@ class DOFProjectionOperator(LinearOperator): @property def target(self): return self._target + + @property + def capability(self): + return self.TIMES | self.ADJOINT_TIMES diff --git a/nifty/operators/endomorphic_operator.py b/nifty/operators/endomorphic_operator.py index 184297337b0376440d897119243a906b2cc2a82f..db0a3c8595fb88a28ba752db4045386c9e963e39 100644 --- a/nifty/operators/endomorphic_operator.py +++ b/nifty/operators/endomorphic_operator.py @@ -16,7 +16,6 @@ # NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik # and financially supported by the Studienstiftung des deutschen Volkes. -import abc from .linear_operator import LinearOperator @@ -34,33 +33,8 @@ class EndomorphicOperator(LinearOperator): target : DomainTuple The domain in which the outcome of the operator lives. As the Operator is endomorphic this is the same as its domain. - self_adjoint : boolean - Indicates whether the operator is self_adjoint or not. """ - def inverse_times(self, x): - if self.self_adjoint and self.unitary: - return self.times(x) - else: - return super(EndomorphicOperator, self).inverse_times(x) - - def adjoint_times(self, x): - if self.self_adjoint: - return self.times(x) - else: - return super(EndomorphicOperator, self).adjoint_times(x) - - def adjoint_inverse_times(self, x): - if self.self_adjoint: - return self.inverse_times(x) - else: - return super(EndomorphicOperator, self).adjoint_inverse_times(x) - @property def target(self): return self.domain - - @abc.abstractproperty - def self_adjoint(self): - """ States whether the Operator is self_adjoint or not.""" - raise NotImplementedError diff --git a/nifty/operators/fft_operator.py b/nifty/operators/fft_operator.py index 8369fe82681a265312dfc6eb47c60bb0da7d29a8..4ce98e3694c637255cd5041f00d14a5ef414ce9c 100644 --- a/nifty/operators/fft_operator.py +++ b/nifty/operators/fft_operator.py @@ -63,9 +63,6 @@ class FFTOperator(LinearOperator): target: Tuple of Spaces The domain of the data that is output by "times" and input by "adjoint_times". - unitary: bool - Returns True if the operator is unitary (currently only the case if - the domain and codomain are RGSpaces), else False. Raises ------ @@ -107,10 +104,8 @@ class FFTOperator(LinearOperator): res = self._trafo.transform(x) return res - def _times(self, x): - return self._times_helper(x) - - def _adjoint_times(self, x): + def apply(self, x, mode): + self._check_input(x, mode) return self._times_helper(x) @property @@ -122,5 +117,8 @@ class FFTOperator(LinearOperator): return self._target @property - def unitary(self): - return self._trafo.unitary + def capability(self): + res = self.TIMES | self.ADJOINT_TIMES + if self._trafo.unitary: + res |= self.INVERSE_TIMES | self.ADJOINT_INVERSE_TIMES + return res diff --git a/nifty/operators/fft_smoothing_operator.py b/nifty/operators/fft_smoothing_operator.py index 7752c9171ea30e15a5615bb5d8373e6ceadc43d3..e7e790cef6bcad159b231d9bb254c6e19c43adb9 100644 --- a/nifty/operators/fft_smoothing_operator.py +++ b/nifty/operators/fft_smoothing_operator.py @@ -22,7 +22,8 @@ class FFTSmoothingOperator(EndomorphicOperator): ddom[self._space] = codomain self._diag = DiagonalOperator(kernel, ddom, self._space) - def _times(self, x): + def apply(self, x, mode): + self._check_input(x, mode) if self._sigma == 0: return x.copy() @@ -33,5 +34,5 @@ class FFTSmoothingOperator(EndomorphicOperator): return self._FFT.domain @property - def self_adjoint(self): - return True + def capability(self): + return self.TIMES | self.ADJOINT_TIMES diff --git a/nifty/operators/inverse_operator.py b/nifty/operators/inverse_operator.py new file mode 100644 index 0000000000000000000000000000000000000000..6a119b5e137058aaee2925ce91b078c51150c192 --- /dev/null +++ b/nifty/operators/inverse_operator.py @@ -0,0 +1,40 @@ +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see <http://www.gnu.org/licenses/>. +# +# Copyright(C) 2013-2017 Max-Planck-Society +# +# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik +# and financially supported by the Studienstiftung des deutschen Volkes. + +from .linear_operator import LinearOperator + + +class InverseOperator(LinearOperator): + def __init__(self, op): + super(InverseOperator, self).__init__() + self._op = op + + @property + def domain(self): + return self._op.target + + @property + def target(self): + return self._op.domain + + @property + def capability(self): + return self._inverseCapability[self._op.capability] + + def apply(self, x, mode): + return self._op.apply(x, self._inverseMode[mode]) diff --git a/nifty/operators/inversion_enabler.py b/nifty/operators/inversion_enabler.py index 87bcec6e5240402d5058581e5a0bb0d80a54a7f0..7b5231a295dcfd69615655b6df199b6397e1fd6d 100644 --- a/nifty/operators/inversion_enabler.py +++ b/nifty/operators/inversion_enabler.py @@ -19,31 +19,41 @@ from ..minimization.quadratic_energy import QuadraticEnergy from ..minimization.iteration_controller import IterationController from ..field import Field, dobj +from .linear_operator import LinearOperator -class InversionEnabler(object): - - def __init__(self, inverter, preconditioner=None): +class InversionEnabler(LinearOperator): + def __init__(self, op, inverter, preconditioner=None): super(InversionEnabler, self).__init__() + self._op = op self._inverter = inverter self._preconditioner = preconditioner - def _operation(self, x, op, tdom): + @property + def domain(self): + return self._op.domain + + @property + def target(self): + return self._op.target + + @property + def capability(self): + return self._addInverse[self._op.capability] + + def apply(self, x, mode): + self._check_mode(mode) + if self._op.capability & mode: + return self._op.apply(x, mode) + + tdom = self._tgt(mode) x0 = Field.zeros(tdom, dtype=x.dtype) - energy = QuadraticEnergy(A=op, b=x, position=x0) + + def func(x): + return self._op.apply(x, self._inverseMode[mode]) + + energy = QuadraticEnergy(A=func, b=x, position=x0) r, stat = self._inverter(energy, preconditioner=self._preconditioner) if stat != IterationController.CONVERGED: dobj.mprint("Error detected during operator inversion") return r.position - - def _times(self, x): - return self._operation(x, self._inverse_times, self.target) - - def _adjoint_times(self, x): - return self._operation(x, self._adjoint_inverse_times, self.domain) - - def _inverse_times(self, x): - return self._operation(x, self._times, self.domain) - - def _adjoint_inverse_times(self, x): - return self._operation(x, self._adjoint_times, self.target) diff --git a/nifty/operators/laplace_operator.py b/nifty/operators/laplace_operator.py index 09bfa90fe9f25dafb0498446694b8513a793a009..1135884bfbbc734dad07fa5e55affe14012cdb48 100644 --- a/nifty/operators/laplace_operator.py +++ b/nifty/operators/laplace_operator.py @@ -71,8 +71,8 @@ class LaplaceOperator(EndomorphicOperator): return self._domain @property - def self_adjoint(self): - return False + def capability(self): + return self.TIMES | self.ADJOINT_TIMES @property def logarithmic(self): @@ -129,3 +129,9 @@ class LaplaceOperator(EndomorphicOperator): if dobj.distaxis(yf) != dobj.distaxis(x.val): ret = dobj.redistribute(ret, dist=dobj.distaxis(x.val)) return Field(self.domain, val=ret).weight(-1, spaces=self._space) + + def apply(self, x, mode): + self._check_input(x, mode) + if mode == self.TIMES: + return self._times(x) + return self._adjoint_times(x) diff --git a/nifty/operators/linear_operator.py b/nifty/operators/linear_operator.py index e26550de627b66255195f6d6d32c43bf5d5f44ed..a2c48dbe5d2dc9a935870462b454d9c3f9290d7b 100644 --- a/nifty/operators/linear_operator.py +++ b/nifty/operators/linear_operator.py @@ -16,7 +16,6 @@ # NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik # and financially supported by the Studienstiftung des deutschen Volkes. -from builtins import str import abc from ..utilities import NiftyMeta from ..field import Field @@ -25,21 +24,19 @@ from future.utils import with_metaclass class LinearOperator(with_metaclass( NiftyMeta, type('NewBase', (object,), {}))): - """NIFTY base class for linear operators. - The base NIFTY operator class is an abstract class from which - other specific operator subclasses are derived. + _validMode = (False, True, True, False, True, False, False, False, True) + _inverseMode = (0, 4, 8, 0, 1, 0, 0, 0, 2) + _inverseCapability = (0, 4, 8, 12, 1, 5, 9, 13, 2, 6, 10, 14, 3, 7, 11, 15) + _adjointMode = (0, 2, 1, 0, 8, 0, 0, 0, 4) + _adjointCapability = (0, 2, 1, 3, 8, 10, 9, 11, 4, 6, 5, 7, 12, 14, 13, 15) + _addInverse = (0, 5, 10, 15, 5, 5, 15, 15, 10, 15, 10, 15, 15, 15, 15, 15) + def _dom(self, mode): + return self.domain if (mode & 9) else self.target - Attributes - ---------- - domain : DomainTuple - The domain on which the Operator's input Field lives. - target : DomainTuple - The domain in which the Operators result lives. - unitary : boolean - Indicates whether the Operator is unitary or not. - """ + def _tgt(self, mode): + return self.domain if (mode & 6) else self.target def __init__(self): pass @@ -65,138 +62,87 @@ class LinearOperator(with_metaclass( raise NotImplementedError @property - def unitary(self): - """ - unitary : boolean - States whether the Operator is unitary or not. - Since the majority of operators will not be unitary, this property - returns False, unless it is overridden in a subclass. - """ - return False + def TIMES(self): + return 1 - def __call__(self, x): - return self.times(x) + @property + def ADJOINT_TIMES(self): + return 2 - def times(self, x): - """ Applies the Operator to a given Field. + @property + def INVERSE_TIMES(self): + return 4 - Parameters - ---------- - x : Field - The input Field, living on the Operator's domain. + @property + def ADJOINT_INVERSE_TIMES(self): + return 8 - Returns - ------- - out : Field - The processed Field living on the Operator's target domain. - """ - self._check_input_compatibility(x) - return self._times(x) + @property + def INVERSE_ADJOINT_TIMES(self): + return 8 - def inverse_times(self, x): - """Applies the inverse Operator to a given Field. + @property + def inverse(self): + from .inverse_operator import InverseOperator + return InverseOperator(self) - Parameters - ---------- - x : Field - The input Field, living on the Operator's target domain + @property + def adjoint(self): + from .adjoint_operator import AdjointOperator + return AdjointOperator(self) - Returns - ------- - out : Field - The processed Field living on the Operator's domain. - """ - self._check_input_compatibility(x, inverse=True) - try: - y = self._inverse_times(x) - except NotImplementedError: - if self.unitary: - y = self._adjoint_times(x) - else: - raise - return y + def __mul__(self, other): + from .chain_operator import ChainOperator + return ChainOperator(self, other) - def adjoint_times(self, x): - """Applies the adjoint-Operator to a given Field. + def __add__(self, other): + from .sum_operator import SumOperator + return SumOperator(self, other) - Parameters - ---------- - x : Field - The input Field, living on the Operator's target domain + def __sub__(self, other): + from .sum_operator import SumOperator + return SumOperator(self, other, neg=True) - Returns - ------- - out : Field - The processed Field living on the Operator's domain. - """ - if self.unitary: - return self.inverse_times(x) - - self._check_input_compatibility(x, inverse=True) - try: - y = self._adjoint_times(x) - except NotImplementedError: - if self.unitary: - y = self._inverse_times(x) - else: - raise - return y + def supports(self, ops): + return False - def adjoint_inverse_times(self, x): - """ Applies the adjoint-inverse Operator to a given Field. - - Parameters - ---------- - x : Field - The input Field, living on the Operator's domain. - - Returns - ------- - out : Field - The processed Field living on the Operator's target domain. - - Notes - ----- - If the operator has an `inverse` then the inverse adjoint is identical - to the adjoint inverse. We provide both names for convenience. - """ - self._check_input_compatibility(x) - try: - y = self._adjoint_inverse_times(x) - except NotImplementedError: - if self.unitary: - y = self._times(x) - else: - raise - return y + @abc.abstractproperty + def capability(self): + raise NotImplementedError - def inverse_adjoint_times(self, x): - """Same as adjoint_inverse_times()""" - return self.adjoint_inverse_times(x) + @abc.abstractmethod + def apply(self, x, mode): + raise NotImplementedError + + def __call__(self, x): + return self.apply(x, self.TIMES) - def _times(self, x): - raise NotImplementedError( - "no generic instance method 'times'.") + def times(self, x): + return self.apply(x, self.TIMES) + + def inverse_times(self, x): + return self.apply(x, self.INVERSE_TIMES) + + def adjoint_times(self, x): + return self.apply(x, self.ADJOINT_TIMES) - def _adjoint_times(self, x): - raise NotImplementedError( - "no generic instance method 'adjoint_times'.") + def adjoint_inverse_times(self, x): + return self.apply(x, self.ADJOINT_INVERSE_TIMES) - def _inverse_times(self, x): - raise NotImplementedError( - "no generic instance method 'inverse_times'.") + def inverse_adjoint_times(self, x): + return self.apply(x, self.ADJOINT_INVERSE_TIMES) - def _adjoint_inverse_times(self, x): - raise NotImplementedError( - "no generic instance method 'adjoint_inverse_times'.") + def _check_mode(self, mode): + if not self._validMode[mode]: + raise ValueError("invalid operator mode specified") + if mode & self.capability == 0: + raise ValueError("requested operator mode is not supported") - def _check_input_compatibility(self, x, inverse=False): + def _check_input(self, x, mode): if not isinstance(x, Field): raise ValueError("supplied object is not a `Field`.") - if x.domain != (self.target if inverse else self.domain): - raise ValueError("The operator's and and field's domains " - "don't match.") - - def __repr__(self): - return str(self.__class__) + self._check_mode(mode) + if x.domain != self._dom(mode): + raise ValueError("The operator's and and field's domains " + "don't match.") diff --git a/nifty/operators/power_projection_operator.py b/nifty/operators/power_projection_operator.py index fde49ddabcb358c0fd7d0dc5916313d9161ecbe5..fc62557773072949c9afcb20e51ae0a1a39802a7 100644 --- a/nifty/operators/power_projection_operator.py +++ b/nifty/operators/power_projection_operator.py @@ -16,9 +16,8 @@ # NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik # and financially supported by the Studienstiftung des deutschen Volkes. -import numpy as np from .dof_projection_operator import DOFProjectionOperator -from .. import Field, DomainTuple, dobj +from .. import DomainTuple from ..utilities import infer_space from ..spaces import PowerSpace diff --git a/nifty/operators/response_operator.py b/nifty/operators/response_operator.py index 04b4fc7aa97319978516962a8df788b5c4360601..a348fd17300e0451afa72bd5470bda573c99cdc8 100644 --- a/nifty/operators/response_operator.py +++ b/nifty/operators/response_operator.py @@ -76,6 +76,10 @@ class ResponseOperator(LinearOperator): def target(self): return self._target + @property + def capability(self): + return self.TIMES | self.ADJOINT_TIMES + def _times(self, x): res = self._composed_kernel.times(x) res = self._composed_exposure.times(res) @@ -88,3 +92,8 @@ class ResponseOperator(LinearOperator): res = self._composed_exposure.adjoint_times(res) res = res.weight(power=-1) return self._composed_kernel.adjoint_times(res) + + def apply(self, x, mode): + self._check_input(x, mode) + return self._times(x) if mode == self.TIMES else self._adjoint_times(x) + diff --git a/nifty/operators/smoothness_operator.py b/nifty/operators/smoothness_operator.py index 610a238e664c0a7956de56287b64da31f1f60da4..4308427310b3d83d7f5db5fada664cb6c93a6220 100644 --- a/nifty/operators/smoothness_operator.py +++ b/nifty/operators/smoothness_operator.py @@ -39,10 +39,12 @@ class SmoothnessOperator(EndomorphicOperator): return self._laplace._domain @property - def self_adjoint(self): - return False + def capability(self): + return self.TIMES + + def apply(self, x, mode): + self._check_input(x, mode) - def _times(self, x): if self._strength == 0.: return x.zeros_like(x) result = self._laplace.adjoint_times(self._laplace(x)) diff --git a/nifty/operators/sum_operator.py b/nifty/operators/sum_operator.py new file mode 100644 index 0000000000000000000000000000000000000000..5ba82197bac1a4714483950a37926f5c2e4b0f7d --- /dev/null +++ b/nifty/operators/sum_operator.py @@ -0,0 +1,48 @@ +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see <http://www.gnu.org/licenses/>. +# +# Copyright(C) 2013-2017 Max-Planck-Society +# +# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik +# and financially supported by the Studienstiftung des deutschen Volkes. + +from .linear_operator import LinearOperator + + +class SumOperator(LinearOperator): + def __init__(self, op1, op2, neg=False): + super(SumOperator, self).__init__() + if op1.domain != op2.domain or op1.target != op2.target: + raise ValueError("domain mismatch") + self._op1 = op1 + self._op2 = op2 + self._neg = bool(neg) + + @property + def domain(self): + return self._op1.domain + + @property + def target(self): + return self._op1.target + + @property + def capability(self): + return (self._op1.capability & self._op2.capability & + (self.TIMES | self.ADJOINT_TIMES)) + + def apply(self, x, mode): + self._check_mode(mode) + res1 = self._op1.apply(x, mode) + res2 = self._op2.apply(x, mode) + return res1 - res2 if self._neg else res1 + res2 diff --git a/test/test_operators/test_diagonal_operator.py b/test/test_operators/test_diagonal_operator.py index 1d2eac5429d4e1c7371d3eb6e11b42368f8f8296..a5ddaea7c439d0ed939f2014d9b8318f5ebc9604 100644 --- a/test/test_operators/test_diagonal_operator.py +++ b/test/test_operators/test_diagonal_operator.py @@ -16,10 +16,6 @@ class DiagonalOperator_Tests(unittest.TestCase): D = ift.DiagonalOperator(diag) if D.domain[0] != space: raise TypeError - if D.unitary: - raise TypeError - if not D.self_adjoint: - raise TypeError @expand(product(spaces)) def test_times_adjoint(self, space): diff --git a/test/test_operators/test_response_operator.py b/test/test_operators/test_response_operator.py index e2f39568159c0e4a9abe37d39ef741590ce41a63..bba4f9ddc6a233a78dadd368d44a927be127a9bd 100644 --- a/test/test_operators/test_response_operator.py +++ b/test/test_operators/test_response_operator.py @@ -14,8 +14,6 @@ class ResponseOperator_Tests(unittest.TestCase): exposure=[exposure]) if op.domain[0] != space: raise TypeError - if op.unitary: - raise ValueError @expand(product(spaces, [0., 5., 1.], [0., 1., .33])) def test_times_adjoint_times(self, space, sigma, exposure): diff --git a/test/test_operators/test_smoothing_operator.py b/test/test_operators/test_smoothing_operator.py index cbae4195bf470b00642da7cb400f25e71a2e20bb..6351d8f141a6037d479bb31d8d8f7676a5592b70 100644 --- a/test/test_operators/test_smoothing_operator.py +++ b/test/test_operators/test_smoothing_operator.py @@ -39,10 +39,6 @@ class SmoothingOperator_Tests(unittest.TestCase): op = ift.FFTSmoothingOperator(space, sigma=sigma) if op.domain[0] != space: raise TypeError - if op.unitary: - raise ValueError - if not op.self_adjoint: - raise ValueError @expand(product(spaces, [0., .5, 5.])) def test_adjoint_times(self, space, sigma):