Commit 8f069aa7 authored by theos's avatar theos

Added kwargs to LinearOperator times methods.

Raised version to 3.0.1.
parent 5af4d1da
from nifty import * from nifty import *
#import plotly.offline as pl import plotly.offline as pl
#import plotly.graph_objs as go import plotly.graph_objs as go
from mpi4py import MPI from mpi4py import MPI
comm = MPI.COMM_WORLD comm = MPI.COMM_WORLD
...@@ -12,11 +12,14 @@ if __name__ == "__main__": ...@@ -12,11 +12,14 @@ if __name__ == "__main__":
distribution_strategy = 'fftw' distribution_strategy = 'fftw'
# Setting up the geometry
s_space = RGSpace([512, 512], dtype=np.float64) s_space = RGSpace([512, 512], dtype=np.float64)
fft = FFTOperator(s_space) fft = FFTOperator(s_space)
h_space = fft.target[0] h_space = fft.target[0]
p_space = PowerSpace(h_space, distribution_strategy=distribution_strategy) p_space = PowerSpace(h_space, distribution_strategy=distribution_strategy)
# Creating the mock data
pow_spec = (lambda k: 42 / (k + 1) ** 3) pow_spec = (lambda k: 42 / (k + 1) ** 3)
S = create_power_operator(h_space, power_spectrum=pow_spec, S = create_power_operator(h_space, power_spectrum=pow_spec,
...@@ -37,6 +40,8 @@ if __name__ == "__main__": ...@@ -37,6 +40,8 @@ if __name__ == "__main__":
mean=0) mean=0)
d = R(ss) + n d = R(ss) + n
# Wiener filter
j = R.adjoint_times(N.inverse_times(d)) j = R.adjoint_times(N.inverse_times(d))
D = PropagatorOperator(S=S, N=N, R=R) D = PropagatorOperator(S=S, N=N, R=R)
...@@ -45,9 +50,8 @@ if __name__ == "__main__": ...@@ -45,9 +50,8 @@ if __name__ == "__main__":
d_data = d.val.get_full_data().real d_data = d.val.get_full_data().real
m_data = m.val.get_full_data().real m_data = m.val.get_full_data().real
ss_data = ss.val.get_full_data().real ss_data = ss.val.get_full_data().real
if rank == 0:
pl.plot([go.Heatmap(z=d_data)], filename='data.html')
pl.plot([go.Heatmap(z=m_data)], filename='map.html')
pl.plot([go.Heatmap(z=ss_data)], filename='map_orig.html')
# if rank == 0:
# pl.plot([go.Heatmap(z=d_data)], filename='data.html')
# pl.plot([go.Heatmap(z=m_data)], filename='map.html')
# pl.plot([go.Heatmap(z=ss_data)], filename='map_orig.html')
#
...@@ -71,25 +71,25 @@ class LinearOperator(Loggable, object): ...@@ -71,25 +71,25 @@ class LinearOperator(Loggable, object):
def __call__(self, *args, **kwargs): def __call__(self, *args, **kwargs):
return self.times(*args, **kwargs) return self.times(*args, **kwargs)
def times(self, x, spaces=None, types=None): def times(self, x, spaces=None, types=None, **kwargs):
spaces, types = self._check_input_compatibility(x, spaces, types) spaces, types = self._check_input_compatibility(x, spaces, types)
if not self.implemented: if not self.implemented:
x = x.weight(spaces=spaces) x = x.weight(spaces=spaces)
y = self._times(x, spaces, types) y = self._times(x, spaces, types, **kwargs)
return y return y
def inverse_times(self, x, spaces=None, types=None): def inverse_times(self, x, spaces=None, types=None, **kwargs):
spaces, types = self._check_input_compatibility(x, spaces, types, spaces, types = self._check_input_compatibility(x, spaces, types,
inverse=True) inverse=True)
y = self._inverse_times(x, spaces, types) y = self._inverse_times(x, spaces, types, **kwargs)
if not self.implemented: if not self.implemented:
y = y.weight(power=-1, spaces=spaces) y = y.weight(power=-1, spaces=spaces)
return y return y
def adjoint_times(self, x, spaces=None, types=None): def adjoint_times(self, x, spaces=None, types=None, **kwargs):
if self.unitary: if self.unitary:
return self.inverse_times(x, spaces, types) return self.inverse_times(x, spaces, types)
...@@ -98,23 +98,23 @@ class LinearOperator(Loggable, object): ...@@ -98,23 +98,23 @@ class LinearOperator(Loggable, object):
if not self.implemented: if not self.implemented:
x = x.weight(spaces=spaces) x = x.weight(spaces=spaces)
y = self._adjoint_times(x, spaces, types) y = self._adjoint_times(x, spaces, types, **kwargs)
return y return y
def adjoint_inverse_times(self, x, spaces=None, types=None): def adjoint_inverse_times(self, x, spaces=None, types=None, **kwargs):
if self.unitary: if self.unitary:
return self.times(x, spaces, types) return self.times(x, spaces, types)
spaces, types = self._check_input_compatibility(x, spaces, types) spaces, types = self._check_input_compatibility(x, spaces, types)
y = self._adjoint_inverse_times(x, spaces, types) y = self._adjoint_inverse_times(x, spaces, types, **kwargs)
if not self.implemented: if not self.implemented:
y = y.weight(power=-1, spaces=spaces) y = y.weight(power=-1, spaces=spaces)
return y return y
def inverse_adjoint_times(self, x, spaces=None, types=None): def inverse_adjoint_times(self, x, spaces=None, types=None, **kwargs):
if self.unitary: if self.unitary:
return self.times(x, spaces, types) return self.times(x, spaces, types, **kwargs)
spaces, types = self._check_input_compatibility(x, spaces, types) spaces, types = self._check_input_compatibility(x, spaces, types)
......
...@@ -4,4 +4,4 @@ ...@@ -4,4 +4,4 @@
# 1) we don't load dependencies by storing it in __init__.py # 1) we don't load dependencies by storing it in __init__.py
# 2) we can import it in setup.py for the same reason # 2) we can import it in setup.py for the same reason
# 3) we can import it into your module module # 3) we can import it into your module module
__version__ = '3.0.0a3' __version__ = '3.0.1'
\ No newline at end of file \ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment