diff --git a/demos/wiener_filter_hamiltonian.py b/demos/wiener_filter_hamiltonian.py index 9b95fdb917aeb3587a150d08e5eaf2840171f3f3..0d78d621fd5593bcbcd47fc7808e412686bd0f89 100644 --- a/demos/wiener_filter_hamiltonian.py +++ b/demos/wiener_filter_hamiltonian.py @@ -12,17 +12,28 @@ np.random.seed(42) class AdjointFFTResponse(LinearOperator): def __init__(self, FFT, R, default_spaces=None): - super(ResponseOperator, self).__init__(default_spaces) + super(AdjointFFTResponse, self).__init__(default_spaces) self._domain = FFT.target - self.target = R.target + self._target = R.target self.R = R self.FFT = FFT - def _times(self, x): + def _times(self, x, spaces=None): return self.R(self.FFT.adjoint_times(x)) - def _adjoint_times(self, x): + def _adjoint_times(self, x, spaces=None): return self.FFT(self.R.adjoint_times(x)) + @property + def domain(self): + return self._domain + + @property + def target(self): + return self._target + + @property + def unitary(self): + return False @@ -79,11 +90,11 @@ if __name__ == "__main__": # callback=distance_measure, # max_history_length=3) - solution = energy.analytic_solution() - m0 = Field(s_space, val=1.) - energy = WienerFilterEnergy(position=m0, D=D, j=j) + m0 = Field(s_space, val=1.) + energy = WienerFilterEnergy(position=m0, R=R, N=N, S=S) + solution = energy.analytic_solution() (energy, convergence) = minimizer(energy) m = fft.adjoint_times(energy.position) diff --git a/nifty/__init__.py b/nifty/__init__.py index 2c1a96d2f8ef31c13bd83c1be971d081d77768e6..fb9012cf37b42dd51545ac636016927f445eb76d 100644 --- a/nifty/__init__.py +++ b/nifty/__init__.py @@ -32,7 +32,7 @@ from config import dependency_injector,\ from d2o import distributed_data_object, d2o_librarian -from energies import * + from field import Field @@ -44,6 +44,8 @@ from nifty_utilities import * from field_types import * +from energies import * + from minimization import * from spaces import * diff --git a/nifty/energies/__init__.py b/nifty/energies/__init__.py index 380252558bba9feb924b457164c74d1737d97964..0ebaa4403e6b3cce66435f50dfdf580b94f25ea7 100644 --- a/nifty/energies/__init__.py +++ b/nifty/energies/__init__.py @@ -19,4 +19,4 @@ from energy import Energy from line_energy import LineEnergy from memoization import memo -from wiener_filter_energy import WienerFilterEnergy \ No newline at end of file +from wiener_filter_energy import WienerFilterEnergy diff --git a/nifty/energies/wiener_filter_energy.py b/nifty/energies/wiener_filter_energy.py index ed2d0d57dc168866bea65923d8dd78056b727080..52714fe221a65f88676c9b664b9fb584a351e302 100644 --- a/nifty/energies/wiener_filter_energy.py +++ b/nifty/energies/wiener_filter_energy.py @@ -1,5 +1,5 @@ from .energy import Energy -from nifty.operators import WienerFilterCurvature +from nifty.operators.curvature_operators import WienerFilterCurvature class WienerFilterEnergy(Energy): """The Energy for the Wiener filter. diff --git a/nifty/operators/__init__.py b/nifty/operators/__init__.py index b64ee506ada2e52aa5c5397b4e0cc5f23f3c7f13..6e066d944672d363d047daebf0c360164a59cf58 100644 --- a/nifty/operators/__init__.py +++ b/nifty/operators/__init__.py @@ -40,4 +40,4 @@ from composed_operator import ComposedOperator from response_operator import ResponseOperator -from curvature_operators import WienerFilterCurvature +from curvature_operators import * diff --git a/nifty/operators/curvature_operators/__init__.py b/nifty/operators/curvature_operators/__init__.py index 002face2b7d298be081daa5808b5ad32aba8b7e3..3af162be136ce5853aac3de2c0f714d1b6ae9e8c 100644 --- a/nifty/operators/curvature_operators/__init__.py +++ b/nifty/operators/curvature_operators/__init__.py @@ -1 +1 @@ -from wiener_filter_curvature import WienerFilterCurvature \ No newline at end of file +from wiener_filter_curvature import WienerFilterCurvature