Skip to content
Snippets Groups Projects
Commit 8f069aa7 authored by theos's avatar theos
Browse files

Added kwargs to LinearOperator times methods.

Raised version to 3.0.1.
parent 5af4d1da
No related branches found
No related tags found
No related merge requests found
from nifty import * from nifty import *
#import plotly.offline as pl import plotly.offline as pl
#import plotly.graph_objs as go import plotly.graph_objs as go
from mpi4py import MPI from mpi4py import MPI
comm = MPI.COMM_WORLD comm = MPI.COMM_WORLD
...@@ -12,11 +12,14 @@ if __name__ == "__main__": ...@@ -12,11 +12,14 @@ if __name__ == "__main__":
distribution_strategy = 'fftw' distribution_strategy = 'fftw'
# Setting up the geometry
s_space = RGSpace([512, 512], dtype=np.float64) s_space = RGSpace([512, 512], dtype=np.float64)
fft = FFTOperator(s_space) fft = FFTOperator(s_space)
h_space = fft.target[0] h_space = fft.target[0]
p_space = PowerSpace(h_space, distribution_strategy=distribution_strategy) p_space = PowerSpace(h_space, distribution_strategy=distribution_strategy)
# Creating the mock data
pow_spec = (lambda k: 42 / (k + 1) ** 3) pow_spec = (lambda k: 42 / (k + 1) ** 3)
S = create_power_operator(h_space, power_spectrum=pow_spec, S = create_power_operator(h_space, power_spectrum=pow_spec,
...@@ -37,6 +40,8 @@ if __name__ == "__main__": ...@@ -37,6 +40,8 @@ if __name__ == "__main__":
mean=0) mean=0)
d = R(ss) + n d = R(ss) + n
# Wiener filter
j = R.adjoint_times(N.inverse_times(d)) j = R.adjoint_times(N.inverse_times(d))
D = PropagatorOperator(S=S, N=N, R=R) D = PropagatorOperator(S=S, N=N, R=R)
...@@ -45,9 +50,8 @@ if __name__ == "__main__": ...@@ -45,9 +50,8 @@ if __name__ == "__main__":
d_data = d.val.get_full_data().real d_data = d.val.get_full_data().real
m_data = m.val.get_full_data().real m_data = m.val.get_full_data().real
ss_data = ss.val.get_full_data().real ss_data = ss.val.get_full_data().real
if rank == 0:
pl.plot([go.Heatmap(z=d_data)], filename='data.html')
pl.plot([go.Heatmap(z=m_data)], filename='map.html')
pl.plot([go.Heatmap(z=ss_data)], filename='map_orig.html')
# if rank == 0:
# pl.plot([go.Heatmap(z=d_data)], filename='data.html')
# pl.plot([go.Heatmap(z=m_data)], filename='map.html')
# pl.plot([go.Heatmap(z=ss_data)], filename='map_orig.html')
#
...@@ -71,25 +71,25 @@ class LinearOperator(Loggable, object): ...@@ -71,25 +71,25 @@ class LinearOperator(Loggable, object):
def __call__(self, *args, **kwargs): def __call__(self, *args, **kwargs):
return self.times(*args, **kwargs) return self.times(*args, **kwargs)
def times(self, x, spaces=None, types=None): def times(self, x, spaces=None, types=None, **kwargs):
spaces, types = self._check_input_compatibility(x, spaces, types) spaces, types = self._check_input_compatibility(x, spaces, types)
if not self.implemented: if not self.implemented:
x = x.weight(spaces=spaces) x = x.weight(spaces=spaces)
y = self._times(x, spaces, types) y = self._times(x, spaces, types, **kwargs)
return y return y
def inverse_times(self, x, spaces=None, types=None): def inverse_times(self, x, spaces=None, types=None, **kwargs):
spaces, types = self._check_input_compatibility(x, spaces, types, spaces, types = self._check_input_compatibility(x, spaces, types,
inverse=True) inverse=True)
y = self._inverse_times(x, spaces, types) y = self._inverse_times(x, spaces, types, **kwargs)
if not self.implemented: if not self.implemented:
y = y.weight(power=-1, spaces=spaces) y = y.weight(power=-1, spaces=spaces)
return y return y
def adjoint_times(self, x, spaces=None, types=None): def adjoint_times(self, x, spaces=None, types=None, **kwargs):
if self.unitary: if self.unitary:
return self.inverse_times(x, spaces, types) return self.inverse_times(x, spaces, types)
...@@ -98,23 +98,23 @@ class LinearOperator(Loggable, object): ...@@ -98,23 +98,23 @@ class LinearOperator(Loggable, object):
if not self.implemented: if not self.implemented:
x = x.weight(spaces=spaces) x = x.weight(spaces=spaces)
y = self._adjoint_times(x, spaces, types) y = self._adjoint_times(x, spaces, types, **kwargs)
return y return y
def adjoint_inverse_times(self, x, spaces=None, types=None): def adjoint_inverse_times(self, x, spaces=None, types=None, **kwargs):
if self.unitary: if self.unitary:
return self.times(x, spaces, types) return self.times(x, spaces, types)
spaces, types = self._check_input_compatibility(x, spaces, types) spaces, types = self._check_input_compatibility(x, spaces, types)
y = self._adjoint_inverse_times(x, spaces, types) y = self._adjoint_inverse_times(x, spaces, types, **kwargs)
if not self.implemented: if not self.implemented:
y = y.weight(power=-1, spaces=spaces) y = y.weight(power=-1, spaces=spaces)
return y return y
def inverse_adjoint_times(self, x, spaces=None, types=None): def inverse_adjoint_times(self, x, spaces=None, types=None, **kwargs):
if self.unitary: if self.unitary:
return self.times(x, spaces, types) return self.times(x, spaces, types, **kwargs)
spaces, types = self._check_input_compatibility(x, spaces, types) spaces, types = self._check_input_compatibility(x, spaces, types)
......
...@@ -4,4 +4,4 @@ ...@@ -4,4 +4,4 @@
# 1) we don't load dependencies by storing it in __init__.py # 1) we don't load dependencies by storing it in __init__.py
# 2) we can import it in setup.py for the same reason # 2) we can import it in setup.py for the same reason
# 3) we can import it into your module module # 3) we can import it into your module module
__version__ = '3.0.0a3' __version__ = '3.0.1'
\ No newline at end of file \ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment