Skip to content
GitLab
Projects
Groups
Snippets
/
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
ift
NIFTy
Commits
8f069aa7
Commit
8f069aa7
authored
Nov 16, 2016
by
theos
Browse files
Added kwargs to LinearOperator times methods.
Raised version to 3.0.1.
parent
5af4d1da
Changes
3
Hide whitespace changes
Inline
Side-by-side
demos/wiener_filter.py
View file @
8f069aa7
from
nifty
import
*
#
import plotly.offline as pl
#
import plotly.graph_objs as go
import
plotly.offline
as
pl
import
plotly.graph_objs
as
go
from
mpi4py
import
MPI
comm
=
MPI
.
COMM_WORLD
...
...
@@ -12,11 +12,14 @@ if __name__ == "__main__":
distribution_strategy
=
'fftw'
# Setting up the geometry
s_space
=
RGSpace
([
512
,
512
],
dtype
=
np
.
float64
)
fft
=
FFTOperator
(
s_space
)
h_space
=
fft
.
target
[
0
]
p_space
=
PowerSpace
(
h_space
,
distribution_strategy
=
distribution_strategy
)
# Creating the mock data
pow_spec
=
(
lambda
k
:
42
/
(
k
+
1
)
**
3
)
S
=
create_power_operator
(
h_space
,
power_spectrum
=
pow_spec
,
...
...
@@ -37,6 +40,8 @@ if __name__ == "__main__":
mean
=
0
)
d
=
R
(
ss
)
+
n
# Wiener filter
j
=
R
.
adjoint_times
(
N
.
inverse_times
(
d
))
D
=
PropagatorOperator
(
S
=
S
,
N
=
N
,
R
=
R
)
...
...
@@ -45,9 +50,8 @@ if __name__ == "__main__":
d_data
=
d
.
val
.
get_full_data
().
real
m_data
=
m
.
val
.
get_full_data
().
real
ss_data
=
ss
.
val
.
get_full_data
().
real
if
rank
==
0
:
pl
.
plot
([
go
.
Heatmap
(
z
=
d_data
)],
filename
=
'data.html'
)
pl
.
plot
([
go
.
Heatmap
(
z
=
m_data
)],
filename
=
'map.html'
)
pl
.
plot
([
go
.
Heatmap
(
z
=
ss_data
)],
filename
=
'map_orig.html'
)
# if rank == 0:
# pl.plot([go.Heatmap(z=d_data)], filename='data.html')
# pl.plot([go.Heatmap(z=m_data)], filename='map.html')
# pl.plot([go.Heatmap(z=ss_data)], filename='map_orig.html')
#
nifty/operators/linear_operator/linear_operator.py
View file @
8f069aa7
...
...
@@ -71,25 +71,25 @@ class LinearOperator(Loggable, object):
def
__call__
(
self
,
*
args
,
**
kwargs
):
return
self
.
times
(
*
args
,
**
kwargs
)
def
times
(
self
,
x
,
spaces
=
None
,
types
=
None
):
def
times
(
self
,
x
,
spaces
=
None
,
types
=
None
,
**
kwargs
):
spaces
,
types
=
self
.
_check_input_compatibility
(
x
,
spaces
,
types
)
if
not
self
.
implemented
:
x
=
x
.
weight
(
spaces
=
spaces
)
y
=
self
.
_times
(
x
,
spaces
,
types
)
y
=
self
.
_times
(
x
,
spaces
,
types
,
**
kwargs
)
return
y
def
inverse_times
(
self
,
x
,
spaces
=
None
,
types
=
None
):
def
inverse_times
(
self
,
x
,
spaces
=
None
,
types
=
None
,
**
kwargs
):
spaces
,
types
=
self
.
_check_input_compatibility
(
x
,
spaces
,
types
,
inverse
=
True
)
y
=
self
.
_inverse_times
(
x
,
spaces
,
types
)
y
=
self
.
_inverse_times
(
x
,
spaces
,
types
,
**
kwargs
)
if
not
self
.
implemented
:
y
=
y
.
weight
(
power
=-
1
,
spaces
=
spaces
)
return
y
def
adjoint_times
(
self
,
x
,
spaces
=
None
,
types
=
None
):
def
adjoint_times
(
self
,
x
,
spaces
=
None
,
types
=
None
,
**
kwargs
):
if
self
.
unitary
:
return
self
.
inverse_times
(
x
,
spaces
,
types
)
...
...
@@ -98,23 +98,23 @@ class LinearOperator(Loggable, object):
if
not
self
.
implemented
:
x
=
x
.
weight
(
spaces
=
spaces
)
y
=
self
.
_adjoint_times
(
x
,
spaces
,
types
)
y
=
self
.
_adjoint_times
(
x
,
spaces
,
types
,
**
kwargs
)
return
y
def
adjoint_inverse_times
(
self
,
x
,
spaces
=
None
,
types
=
None
):
def
adjoint_inverse_times
(
self
,
x
,
spaces
=
None
,
types
=
None
,
**
kwargs
):
if
self
.
unitary
:
return
self
.
times
(
x
,
spaces
,
types
)
spaces
,
types
=
self
.
_check_input_compatibility
(
x
,
spaces
,
types
)
y
=
self
.
_adjoint_inverse_times
(
x
,
spaces
,
types
)
y
=
self
.
_adjoint_inverse_times
(
x
,
spaces
,
types
,
**
kwargs
)
if
not
self
.
implemented
:
y
=
y
.
weight
(
power
=-
1
,
spaces
=
spaces
)
return
y
def
inverse_adjoint_times
(
self
,
x
,
spaces
=
None
,
types
=
None
):
def
inverse_adjoint_times
(
self
,
x
,
spaces
=
None
,
types
=
None
,
**
kwargs
):
if
self
.
unitary
:
return
self
.
times
(
x
,
spaces
,
types
)
return
self
.
times
(
x
,
spaces
,
types
,
**
kwargs
)
spaces
,
types
=
self
.
_check_input_compatibility
(
x
,
spaces
,
types
)
...
...
nifty/version.py
View file @
8f069aa7
...
...
@@ -4,4 +4,4 @@
# 1) we don't load dependencies by storing it in __init__.py
# 2) we can import it in setup.py for the same reason
# 3) we can import it into your module module
__version__
=
'3.0.0a3'
\ No newline at end of file
__version__
=
'3.0.1'
\ No newline at end of file
Write
Preview
Supports
Markdown
0%
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment