From 932b59d3147bf010523171df978f3084d3d51181 Mon Sep 17 00:00:00 2001 From: "Knollmueller, Jakob (kjako)" <jakob@knollmueller.de> Date: Mon, 15 May 2017 15:48:34 +0200 Subject: [PATCH] test --- demos/wiener_filter_hamiltonian.py | 25 +++++++++++++------ nifty/__init__.py | 4 ++- nifty/energies/__init__.py | 2 +- nifty/energies/wiener_filter_energy.py | 2 +- nifty/operators/__init__.py | 2 +- .../operators/curvature_operators/__init__.py | 2 +- 6 files changed, 25 insertions(+), 12 deletions(-) diff --git a/demos/wiener_filter_hamiltonian.py b/demos/wiener_filter_hamiltonian.py index 9b95fdb91..0d78d621f 100644 --- a/demos/wiener_filter_hamiltonian.py +++ b/demos/wiener_filter_hamiltonian.py @@ -12,17 +12,28 @@ np.random.seed(42) class AdjointFFTResponse(LinearOperator): def __init__(self, FFT, R, default_spaces=None): - super(ResponseOperator, self).__init__(default_spaces) + super(AdjointFFTResponse, self).__init__(default_spaces) self._domain = FFT.target - self.target = R.target + self._target = R.target self.R = R self.FFT = FFT - def _times(self, x): + def _times(self, x, spaces=None): return self.R(self.FFT.adjoint_times(x)) - def _adjoint_times(self, x): + def _adjoint_times(self, x, spaces=None): return self.FFT(self.R.adjoint_times(x)) + @property + def domain(self): + return self._domain + + @property + def target(self): + return self._target + + @property + def unitary(self): + return False @@ -79,11 +90,11 @@ if __name__ == "__main__": # callback=distance_measure, # max_history_length=3) - solution = energy.analytic_solution() - m0 = Field(s_space, val=1.) - energy = WienerFilterEnergy(position=m0, D=D, j=j) + m0 = Field(s_space, val=1.) + energy = WienerFilterEnergy(position=m0, R=R, N=N, S=S) + solution = energy.analytic_solution() (energy, convergence) = minimizer(energy) m = fft.adjoint_times(energy.position) diff --git a/nifty/__init__.py b/nifty/__init__.py index 2c1a96d2f..fb9012cf3 100644 --- a/nifty/__init__.py +++ b/nifty/__init__.py @@ -32,7 +32,7 @@ from config import dependency_injector,\ from d2o import distributed_data_object, d2o_librarian -from energies import * + from field import Field @@ -44,6 +44,8 @@ from nifty_utilities import * from field_types import * +from energies import * + from minimization import * from spaces import * diff --git a/nifty/energies/__init__.py b/nifty/energies/__init__.py index 380252558..0ebaa4403 100644 --- a/nifty/energies/__init__.py +++ b/nifty/energies/__init__.py @@ -19,4 +19,4 @@ from energy import Energy from line_energy import LineEnergy from memoization import memo -from wiener_filter_energy import WienerFilterEnergy \ No newline at end of file +from wiener_filter_energy import WienerFilterEnergy diff --git a/nifty/energies/wiener_filter_energy.py b/nifty/energies/wiener_filter_energy.py index ed2d0d57d..52714fe22 100644 --- a/nifty/energies/wiener_filter_energy.py +++ b/nifty/energies/wiener_filter_energy.py @@ -1,5 +1,5 @@ from .energy import Energy -from nifty.operators import WienerFilterCurvature +from nifty.operators.curvature_operators import WienerFilterCurvature class WienerFilterEnergy(Energy): """The Energy for the Wiener filter. diff --git a/nifty/operators/__init__.py b/nifty/operators/__init__.py index b64ee506a..6e066d944 100644 --- a/nifty/operators/__init__.py +++ b/nifty/operators/__init__.py @@ -40,4 +40,4 @@ from composed_operator import ComposedOperator from response_operator import ResponseOperator -from curvature_operators import WienerFilterCurvature +from curvature_operators import * diff --git a/nifty/operators/curvature_operators/__init__.py b/nifty/operators/curvature_operators/__init__.py index 002face2b..3af162be1 100644 --- a/nifty/operators/curvature_operators/__init__.py +++ b/nifty/operators/curvature_operators/__init__.py @@ -1 +1 @@ -from wiener_filter_curvature import WienerFilterCurvature \ No newline at end of file +from wiener_filter_curvature import WienerFilterCurvature -- GitLab