Commit 8f069aa7 authored by theos's avatar theos

Added kwargs to LinearOperator times methods.

Raised version to 3.0.1.
parent 5af4d1da
from nifty import *
#import plotly.offline as pl
#import plotly.graph_objs as go
import plotly.offline as pl
import plotly.graph_objs as go
from mpi4py import MPI
comm = MPI.COMM_WORLD
......@@ -12,11 +12,14 @@ if __name__ == "__main__":
distribution_strategy = 'fftw'
# Setting up the geometry
s_space = RGSpace([512, 512], dtype=np.float64)
fft = FFTOperator(s_space)
h_space = fft.target[0]
p_space = PowerSpace(h_space, distribution_strategy=distribution_strategy)
# Creating the mock data
pow_spec = (lambda k: 42 / (k + 1) ** 3)
S = create_power_operator(h_space, power_spectrum=pow_spec,
......@@ -37,6 +40,8 @@ if __name__ == "__main__":
mean=0)
d = R(ss) + n
# Wiener filter
j = R.adjoint_times(N.inverse_times(d))
D = PropagatorOperator(S=S, N=N, R=R)
......@@ -45,9 +50,8 @@ if __name__ == "__main__":
d_data = d.val.get_full_data().real
m_data = m.val.get_full_data().real
ss_data = ss.val.get_full_data().real
if rank == 0:
pl.plot([go.Heatmap(z=d_data)], filename='data.html')
pl.plot([go.Heatmap(z=m_data)], filename='map.html')
pl.plot([go.Heatmap(z=ss_data)], filename='map_orig.html')
# if rank == 0:
# pl.plot([go.Heatmap(z=d_data)], filename='data.html')
# pl.plot([go.Heatmap(z=m_data)], filename='map.html')
# pl.plot([go.Heatmap(z=ss_data)], filename='map_orig.html')
#
......@@ -71,25 +71,25 @@ class LinearOperator(Loggable, object):
def __call__(self, *args, **kwargs):
return self.times(*args, **kwargs)
def times(self, x, spaces=None, types=None):
def times(self, x, spaces=None, types=None, **kwargs):
spaces, types = self._check_input_compatibility(x, spaces, types)
if not self.implemented:
x = x.weight(spaces=spaces)
y = self._times(x, spaces, types)
y = self._times(x, spaces, types, **kwargs)
return y
def inverse_times(self, x, spaces=None, types=None):
def inverse_times(self, x, spaces=None, types=None, **kwargs):
spaces, types = self._check_input_compatibility(x, spaces, types,
inverse=True)
y = self._inverse_times(x, spaces, types)
y = self._inverse_times(x, spaces, types, **kwargs)
if not self.implemented:
y = y.weight(power=-1, spaces=spaces)
return y
def adjoint_times(self, x, spaces=None, types=None):
def adjoint_times(self, x, spaces=None, types=None, **kwargs):
if self.unitary:
return self.inverse_times(x, spaces, types)
......@@ -98,23 +98,23 @@ class LinearOperator(Loggable, object):
if not self.implemented:
x = x.weight(spaces=spaces)
y = self._adjoint_times(x, spaces, types)
y = self._adjoint_times(x, spaces, types, **kwargs)
return y
def adjoint_inverse_times(self, x, spaces=None, types=None):
def adjoint_inverse_times(self, x, spaces=None, types=None, **kwargs):
if self.unitary:
return self.times(x, spaces, types)
spaces, types = self._check_input_compatibility(x, spaces, types)
y = self._adjoint_inverse_times(x, spaces, types)
y = self._adjoint_inverse_times(x, spaces, types, **kwargs)
if not self.implemented:
y = y.weight(power=-1, spaces=spaces)
return y
def inverse_adjoint_times(self, x, spaces=None, types=None):
def inverse_adjoint_times(self, x, spaces=None, types=None, **kwargs):
if self.unitary:
return self.times(x, spaces, types)
return self.times(x, spaces, types, **kwargs)
spaces, types = self._check_input_compatibility(x, spaces, types)
......
......@@ -4,4 +4,4 @@
# 1) we don't load dependencies by storing it in __init__.py
# 2) we can import it in setup.py for the same reason
# 3) we can import it into your module module
__version__ = '3.0.0a3'
\ No newline at end of file
__version__ = '3.0.1'
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment