Commit 932b59d3 authored by Jakob Knollmueller's avatar Jakob Knollmueller
Browse files

test

parent b35351ab
Pipeline #12453 failed with stage
in 4 minutes and 32 seconds
...@@ -12,17 +12,28 @@ np.random.seed(42) ...@@ -12,17 +12,28 @@ np.random.seed(42)
class AdjointFFTResponse(LinearOperator): class AdjointFFTResponse(LinearOperator):
def __init__(self, FFT, R, default_spaces=None): def __init__(self, FFT, R, default_spaces=None):
super(ResponseOperator, self).__init__(default_spaces) super(AdjointFFTResponse, self).__init__(default_spaces)
self._domain = FFT.target self._domain = FFT.target
self.target = R.target self._target = R.target
self.R = R self.R = R
self.FFT = FFT self.FFT = FFT
def _times(self, x): def _times(self, x, spaces=None):
return self.R(self.FFT.adjoint_times(x)) return self.R(self.FFT.adjoint_times(x))
def _adjoint_times(self, x): def _adjoint_times(self, x, spaces=None):
return self.FFT(self.R.adjoint_times(x)) return self.FFT(self.R.adjoint_times(x))
@property
def domain(self):
return self._domain
@property
def target(self):
return self._target
@property
def unitary(self):
return False
...@@ -79,11 +90,11 @@ if __name__ == "__main__": ...@@ -79,11 +90,11 @@ if __name__ == "__main__":
# callback=distance_measure, # callback=distance_measure,
# max_history_length=3) # max_history_length=3)
solution = energy.analytic_solution()
m0 = Field(s_space, val=1.)
energy = WienerFilterEnergy(position=m0, D=D, j=j) m0 = Field(s_space, val=1.)
energy = WienerFilterEnergy(position=m0, R=R, N=N, S=S)
solution = energy.analytic_solution()
(energy, convergence) = minimizer(energy) (energy, convergence) = minimizer(energy)
m = fft.adjoint_times(energy.position) m = fft.adjoint_times(energy.position)
......
...@@ -32,7 +32,7 @@ from config import dependency_injector,\ ...@@ -32,7 +32,7 @@ from config import dependency_injector,\
from d2o import distributed_data_object, d2o_librarian from d2o import distributed_data_object, d2o_librarian
from energies import *
from field import Field from field import Field
...@@ -44,6 +44,8 @@ from nifty_utilities import * ...@@ -44,6 +44,8 @@ from nifty_utilities import *
from field_types import * from field_types import *
from energies import *
from minimization import * from minimization import *
from spaces import * from spaces import *
......
...@@ -19,4 +19,4 @@ ...@@ -19,4 +19,4 @@
from energy import Energy from energy import Energy
from line_energy import LineEnergy from line_energy import LineEnergy
from memoization import memo from memoization import memo
from wiener_filter_energy import WienerFilterEnergy from wiener_filter_energy import WienerFilterEnergy
\ No newline at end of file
from .energy import Energy from .energy import Energy
from nifty.operators import WienerFilterCurvature from nifty.operators.curvature_operators import WienerFilterCurvature
class WienerFilterEnergy(Energy): class WienerFilterEnergy(Energy):
"""The Energy for the Wiener filter. """The Energy for the Wiener filter.
......
...@@ -40,4 +40,4 @@ from composed_operator import ComposedOperator ...@@ -40,4 +40,4 @@ from composed_operator import ComposedOperator
from response_operator import ResponseOperator from response_operator import ResponseOperator
from curvature_operators import WienerFilterCurvature from curvature_operators import *
from wiener_filter_curvature import WienerFilterCurvature from wiener_filter_curvature import WienerFilterCurvature
\ No newline at end of file
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment