Skip to content
GitLab
Menu
Projects
Groups
Snippets
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
ift
NIFTy
Commits
4cb89177
Commit
4cb89177
authored
May 16, 2017
by
Jakob Knollmueller
Browse files
Wiener Filter advanced
parent
f8383a88
Changes
7
Show whitespace changes
Inline
Side-by-side
demos/wiener_filter_
hamiltonian
.py
→
demos/wiener_filter_
advanced
.py
View file @
4cb89177
...
@@ -41,27 +41,34 @@ if __name__ == "__main__":
...
@@ -41,27 +41,34 @@ if __name__ == "__main__":
distribution_strategy
=
'not'
distribution_strategy
=
'not'
# Set up spaces and fft transformation
# Set up position space
s_space
=
RGSpace
([
512
,
512
])
s_space
=
RGSpace
([
128
,
129
])
# s_space = HPSpace(32)
# Define harmonic transformation and associated harmonic space
fft
=
FFTOperator
(
s_space
)
fft
=
FFTOperator
(
s_space
)
h_space
=
fft
.
target
[
0
]
h_space
=
fft
.
target
[
0
]
# Setting up power space
p_space
=
PowerSpace
(
h_space
,
distribution_strategy
=
distribution_strategy
)
p_space
=
PowerSpace
(
h_space
,
distribution_strategy
=
distribution_strategy
)
#
create the field instances and power
operator
#
Choosing the prior correlation structure and defining correlation
operator
pow_spec
=
(
lambda
k
:
(
42
/
(
k
+
1
)
**
3
))
pow_spec
=
(
lambda
k
:
(
42
/
(
k
+
1
)
**
3
))
S
=
create_power_operator
(
h_space
,
power_spectrum
=
pow_spec
,
S
=
create_power_operator
(
h_space
,
power_spectrum
=
pow_spec
,
distribution_strategy
=
distribution_strategy
)
distribution_strategy
=
distribution_strategy
)
# Drawing a sample sh from the prior distribution in harmonic space
sp
=
Field
(
p_space
,
val
=
lambda
z
:
pow_spec
(
z
)
**
(
1.
/
2
),
sp
=
Field
(
p_space
,
val
=
lambda
z
:
pow_spec
(
z
)
**
(
1.
/
2
),
distribution_strategy
=
distribution_strategy
)
distribution_strategy
=
distribution_strategy
)
sh
=
sp
.
power_synthesize
(
real_signal
=
True
)
sh
=
sp
.
power_synthesize
(
real_signal
=
True
)
ss
=
fft
.
inverse_times
(
sh
)
# model the measurement process
Instrument
=
SmoothingOperator
(
s_space
,
sigma
=
0.01
)
# Choosing the measurement instrument
Instrument
=
SmoothingOperator
(
s_space
,
sigma
=
0.05
)
# Instrument = DiagonalOperator(s_space, diagonal=1.)
# Instrument = DiagonalOperator(s_space, diagonal=1.)
# Instrument._diagonal.val[200:400, 200:400] = 0
# Instrument._diagonal.val[200:400, 200:400] = 0
#Adding a harmonic transformation to the instrument
R
=
AdjointFFTResponse
(
fft
,
Instrument
)
R
=
AdjointFFTResponse
(
fft
,
Instrument
)
signal_to_noise
=
1
signal_to_noise
=
1
N
=
DiagonalOperator
(
s_space
,
diagonal
=
ss
.
var
()
/
signal_to_noise
,
bare
=
True
)
N
=
DiagonalOperator
(
s_space
,
diagonal
=
ss
.
var
()
/
signal_to_noise
,
bare
=
True
)
...
@@ -70,34 +77,49 @@ if __name__ == "__main__":
...
@@ -70,34 +77,49 @@ if __name__ == "__main__":
std
=
ss
.
std
()
/
np
.
sqrt
(
signal_to_noise
),
std
=
ss
.
std
()
/
np
.
sqrt
(
signal_to_noise
),
mean
=
0
)
mean
=
0
)
#
c
reate mock data
#
C
reat
ing th
e mock data
d
=
R
(
sh
)
+
n
d
=
R
(
sh
)
+
n
def
distance_measure
(
energy
,
iteration
):
# Choosing the minimization strategy
def
convergence_measure
(
energy
,
iteration
):
# returns current energy
x
=
energy
.
value
x
=
energy
.
value
print
(
x
,
iteration
)
print
(
x
,
iteration
)
# minimizer = SteepestDescent(convergence_tolerance=0,
# minimizer = SteepestDescent(convergence_tolerance=0,
# iteration_limit=50,
# iteration_limit=50,
# callback=
dista
nce_measure)
# callback=
converge
nce_measure)
minimizer
=
RelaxedNewton
(
convergence_tolerance
=
0
,
minimizer
=
RelaxedNewton
(
convergence_tolerance
=
0
,
iteration_limit
=
2
,
iteration_limit
=
10
,
callback
=
distance_measure
)
callback
=
convergence_measure
)
#
# minimizer = VL_BFGS(convergence_tolerance=0,
# minimizer = VL_BFGS(convergence_tolerance=0,
# iteration_limit=50,
# iteration_limit=500,
# callback=distance_measure,
# callback=convergence_measure,
# max_history_length=3)
# max_history_length=3)
#
#
m0
=
Field
(
s_space
,
val
=
1.
)
# Setting starting position
energy
=
WienerFilterEnergy
(
position
=
m0
,
R
=
R
,
N
=
N
,
S
=
S
)
m0
=
Field
(
h_space
,
val
=
1.
)
# Initializing the Wiener Filter energy
energy
=
WienerFilterEnergy
(
position
=
m0
,
d
=
d
,
R
=
R
,
N
=
N
,
S
=
S
)
# Solving the problem analytically
solution
=
energy
.
analytic_solution
()
solution
=
energy
.
analytic_solution
()
# Solving the problem with chosen minimization strategy
(
energy
,
convergence
)
=
minimizer
(
energy
)
(
energy
,
convergence
)
=
minimizer
(
energy
)
# Transforming fields to position space for plotting
ss
=
fft
.
adjoint_times
(
sh
)
m
=
fft
.
adjoint_times
(
energy
.
position
)
m
=
fft
.
adjoint_times
(
energy
.
position
)
m_wf
=
fft
.
adjoint_times
(
solution
.
position
)
# Plotting
d_data
=
d
.
val
.
get_full_data
().
real
d_data
=
d
.
val
.
get_full_data
().
real
if
rank
==
0
:
if
rank
==
0
:
...
@@ -112,20 +134,12 @@ if __name__ == "__main__":
...
@@ -112,20 +134,12 @@ if __name__ == "__main__":
if
rank
==
0
:
if
rank
==
0
:
pl
.
plot
([
go
.
Heatmap
(
z
=
sh_data
)],
filename
=
'sh.html'
)
pl
.
plot
([
go
.
Heatmap
(
z
=
sh_data
)],
filename
=
'sh.html'
)
j_data
=
j
.
val
.
get_full_data
().
real
if
rank
==
0
:
pl
.
plot
([
go
.
Heatmap
(
z
=
j_data
)],
filename
=
'j.html'
)
jabs_data
=
np
.
abs
(
j
.
val
.
get_full_data
())
jphase_data
=
np
.
angle
(
j
.
val
.
get_full_data
())
if
rank
==
0
:
pl
.
plot
([
go
.
Heatmap
(
z
=
jabs_data
)],
filename
=
'j_abs.html'
)
pl
.
plot
([
go
.
Heatmap
(
z
=
jphase_data
)],
filename
=
'j_phase.html'
)
m_data
=
m
.
val
.
get_full_data
().
real
m_data
=
m
.
val
.
get_full_data
().
real
if
rank
==
0
:
if
rank
==
0
:
pl
.
plot
([
go
.
Heatmap
(
z
=
m_data
)],
filename
=
'map.html'
)
pl
.
plot
([
go
.
Heatmap
(
z
=
m_data
)],
filename
=
'map.html'
)
# grad_data = grad.val.get_full_data().real
m_wf_data
=
m_wf
.
val
.
get_full_data
().
real
# if rank == 0:
if
rank
==
0
:
# pl.plot([go.Heatmap(z=grad_data)], filename='grad.html')
pl
.
plot
([
go
.
Heatmap
(
z
=
m_wf_data
)],
filename
=
'map_wf.html'
)
demos/wiener_filter.py
→
demos/wiener_filter
_easy
.py
View file @
4cb89177
File moved
demos/wiener_filter_harmonic - Kopie.py
deleted
100644 → 0
View file @
f8383a88
from
nifty
import
*
from
mpi4py
import
MPI
import
plotly.offline
as
py
import
plotly.graph_objs
as
go
comm
=
MPI
.
COMM_WORLD
rank
=
comm
.
rank
def
plot_maps
(
x
,
name
):
trace
=
[
None
]
*
len
(
x
)
keys
=
x
.
keys
()
field
=
x
[
keys
[
0
]]
domain
=
field
.
domain
[
0
]
shape
=
len
(
domain
.
shape
)
max_n
=
domain
.
shape
[
0
]
*
domain
.
distances
[
0
]
step
=
domain
.
distances
[
0
]
x_axis
=
np
.
arange
(
0
,
max_n
,
step
)
if
shape
==
1
:
for
ii
in
xrange
(
len
(
x
)):
trace
[
ii
]
=
go
.
Scatter
(
x
=
x_axis
,
y
=
x
[
keys
[
ii
]].
val
.
get_full_data
(),
name
=
keys
[
ii
])
fig
=
go
.
Figure
(
data
=
trace
)
py
.
plot
(
fig
,
filename
=
name
)
elif
shape
==
2
:
for
ii
in
xrange
(
len
(
x
)):
py
.
plot
([
go
.
Heatmap
(
z
=
x
[
keys
[
ii
]].
val
.
get_full_data
().
real
)],
filename
=
keys
[
ii
])
else
:
raise
TypeError
(
"Only 1D and 2D field plots are supported"
)
def
plot_power
(
x
,
name
):
layout
=
go
.
Layout
(
xaxis
=
dict
(
type
=
'log'
,
autorange
=
True
),
yaxis
=
dict
(
type
=
'log'
,
autorange
=
True
)
)
trace
=
[
None
]
*
len
(
x
)
keys
=
x
.
keys
()
field
=
x
[
keys
[
0
]]
domain
=
field
.
domain
[
0
]
x_axis
=
domain
.
kindex
for
ii
in
xrange
(
len
(
x
)):
trace
[
ii
]
=
go
.
Scatter
(
x
=
x_axis
,
y
=
x
[
keys
[
ii
]].
val
.
get_full_data
(),
name
=
keys
[
ii
])
fig
=
go
.
Figure
(
data
=
trace
,
layout
=
layout
)
py
.
plot
(
fig
,
filename
=
name
)
np
.
random
.
seed
(
42
)
if
__name__
==
"__main__"
:
distribution_strategy
=
'not'
# setting spaces
npix
=
np
.
array
([
500
])
# number of pixels
total_volume
=
1.
# total length
# setting signal parameters
lambda_s
=
.
05
# signal correlation length
sigma_s
=
10.
# signal variance
#setting response operator parameters
length_convolution
=
.
025
exposure
=
1.
# calculating parameters
k_0
=
4.
/
(
2
*
np
.
pi
*
lambda_s
)
a_s
=
sigma_s
**
2.
*
lambda_s
*
total_volume
# creation of spaces
# x1 = RGSpace([npix,npix], distances=total_volume / npix,
# zerocenter=False)
# k1 = RGRGTransformation.get_codomain(x1)
x1
=
HPSpace
(
32
)
k1
=
HPLMTransformation
.
get_codomain
(
x1
)
p1
=
PowerSpace
(
harmonic_partner
=
k1
,
logarithmic
=
False
)
# creating Power Operator with given spectrum
spec
=
(
lambda
k
:
a_s
/
(
1
+
(
k
/
k_0
)
**
2
)
**
2
)
p_field
=
Field
(
p1
,
val
=
spec
)
S_op
=
create_power_operator
(
k1
,
spec
)
# creating FFT-Operator and Response-Operator with Gaussian convolution
Fft_op
=
FFTOperator
(
domain
=
x1
,
target
=
k1
,
domain_dtype
=
np
.
float64
,
target_dtype
=
np
.
complex128
)
R_op
=
ResponseOperator
(
x1
,
sigma
=
[
length_convolution
],
exposure
=
[
exposure
])
# drawing a random field
sk
=
p_field
.
power_synthesize
(
real_signal
=
True
,
mean
=
0.
)
s
=
Fft_op
.
adjoint_times
(
sk
)
signal_to_noise
=
1
N_op
=
DiagonalOperator
(
R_op
.
target
,
diagonal
=
s
.
var
()
/
signal_to_noise
,
bare
=
True
)
n
=
Field
.
from_random
(
domain
=
R_op
.
target
,
random_type
=
'normal'
,
std
=
s
.
std
()
/
np
.
sqrt
(
signal_to_noise
),
mean
=
0.
)
d
=
R_op
(
s
)
+
n
# Wiener filter
j
=
Fft_op
.
times
(
R_op
.
adjoint_times
(
N_op
.
inverse_times
(
d
)))
D
=
HarmonicPropagatorOperator
(
S
=
S_op
,
N
=
N_op
,
R
=
R_op
)
mk
=
D
(
j
)
m
=
Fft_op
.
adjoint_times
(
mk
)
# z={}
# z["signal"] = s
# z["reconstructed_map"] = m
# z["data"] = d
# z["lambda"] = R_op(s)
# z["j"] = j
#
# plot_maps(z, "Wiener_filter.html")
demos/wiener_filter_harmonic.py
deleted
100644 → 0
View file @
f8383a88
from
nifty
import
*
from
mpi4py
import
MPI
import
plotly.offline
as
py
import
plotly.graph_objs
as
go
comm
=
MPI
.
COMM_WORLD
rank
=
comm
.
rank
def
plot_maps
(
x
,
name
):
trace
=
[
None
]
*
len
(
x
)
keys
=
x
.
keys
()
field
=
x
[
keys
[
0
]]
domain
=
field
.
domain
[
0
]
shape
=
len
(
domain
.
shape
)
max_n
=
domain
.
shape
[
0
]
*
domain
.
distances
[
0
]
step
=
domain
.
distances
[
0
]
x_axis
=
np
.
arange
(
0
,
max_n
,
step
)
if
shape
==
1
:
for
ii
in
xrange
(
len
(
x
)):
trace
[
ii
]
=
go
.
Scatter
(
x
=
x_axis
,
y
=
x
[
keys
[
ii
]].
val
.
get_full_data
(),
name
=
keys
[
ii
])
fig
=
go
.
Figure
(
data
=
trace
)
py
.
plot
(
fig
,
filename
=
name
)
elif
shape
==
2
:
for
ii
in
xrange
(
len
(
x
)):
py
.
plot
([
go
.
Heatmap
(
z
=
x
[
keys
[
ii
]].
val
.
get_full_data
())],
filename
=
keys
[
ii
])
else
:
raise
TypeError
(
"Only 1D and 2D field plots are supported"
)
def
plot_power
(
x
,
name
):
layout
=
go
.
Layout
(
xaxis
=
dict
(
type
=
'log'
,
autorange
=
True
),
yaxis
=
dict
(
type
=
'log'
,
autorange
=
True
)
)
trace
=
[
None
]
*
len
(
x
)
keys
=
x
.
keys
()
field
=
x
[
keys
[
0
]]
domain
=
field
.
domain
[
0
]
x_axis
=
domain
.
kindex
for
ii
in
xrange
(
len
(
x
)):
trace
[
ii
]
=
go
.
Scatter
(
x
=
x_axis
,
y
=
x
[
keys
[
ii
]].
val
.
get_full_data
(),
name
=
keys
[
ii
])
fig
=
go
.
Figure
(
data
=
trace
,
layout
=
layout
)
py
.
plot
(
fig
,
filename
=
name
)
np
.
random
.
seed
(
42
)
if
__name__
==
"__main__"
:
distribution_strategy
=
'not'
# setting spaces
npix
=
np
.
array
([
500
])
# number of pixels
total_volume
=
1.
# total length
# setting signal parameters
lambda_s
=
.
05
# signal correlation length
sigma_s
=
10.
# signal variance
#setting response operator parameters
length_convolution
=
.
025
exposure
=
1.
# calculating parameters
k_0
=
4.
/
(
2
*
np
.
pi
*
lambda_s
)
a_s
=
sigma_s
**
2.
*
lambda_s
*
total_volume
# creation of spaces
# x1 = RGSpace([npix,npix], distances=total_volume / npix,
# zerocenter=False)
# k1 = RGRGTransformation.get_codomain(x1)
x1
=
HPSpace
(
64
)
k1
=
HPLMTransformation
.
get_codomain
(
x1
)
p1
=
PowerSpace
(
harmonic_partner
=
k1
,
logarithmic
=
False
)
# creating Power Operator with given spectrum
spec
=
(
lambda
k
:
a_s
/
(
1
+
(
k
/
k_0
)
**
2
)
**
2
)
p_field
=
Field
(
p1
,
val
=
spec
)
S_op
=
create_power_operator
(
k1
,
spec
)
# creating FFT-Operator and Response-Operator with Gaussian convolution
# adjust dtype_target probperly
Fft_op
=
FFTOperator
(
domain
=
x1
,
target
=
k1
,
domain_dtype
=
np
.
float64
,
target_dtype
=
np
.
float64
)
R_op
=
ResponseOperator
(
x1
,
sigma
=
[
length_convolution
],
exposure
=
[
exposure
])
# drawing a random field
sk
=
p_field
.
power_synthesize
(
real_power
=
True
,
mean
=
0.
)
s
=
Fft_op
.
adjoint_times
(
sk
)
signal_to_noise
=
1
N_op
=
DiagonalOperator
(
R_op
.
target
,
diagonal
=
s
.
var
()
/
signal_to_noise
,
bare
=
True
)
n
=
Field
.
from_random
(
domain
=
R_op
.
target
,
random_type
=
'normal'
,
std
=
s
.
std
()
/
np
.
sqrt
(
signal_to_noise
),
mean
=
0
)
d
=
R_op
(
s
)
+
n
# Wiener filter
j
=
Fft_op
.
times
(
R_op
.
adjoint_times
(
N_op
.
inverse_times
(
d
)))
D
=
HarmonicPropagatorOperator
(
S
=
S_op
,
N
=
N_op
,
R
=
R_op
)
mk
=
D
(
j
)
m
=
Fft_op
.
adjoint_times
(
mk
)
# z={}
# z["signal"] = s
# z["reconstructed_map"] = m
# z["data"] = d
# z["lambda"] = R_op(s)
#
# plot_maps(z, "Wiener_filter.html")
demos/wiener_filter_unit.py
deleted
100644 → 0
View file @
f8383a88
from
nifty
import
*
from
mpi4py
import
MPI
import
plotly.offline
as
py
import
plotly.graph_objs
as
go
comm
=
MPI
.
COMM_WORLD
rank
=
comm
.
rank
def
plot_maps
(
x
,
name
):
trace
=
[
None
]
*
len
(
x
)
keys
=
x
.
keys
()
field
=
x
[
keys
[
0
]]
domain
=
field
.
domain
[
0
]
shape
=
len
(
domain
.
shape
)
max_n
=
domain
.
shape
[
0
]
*
domain
.
distances
[
0
]
step
=
domain
.
distances
[
0
]
x_axis
=
np
.
arange
(
0
,
max_n
,
step
)
if
shape
==
1
:
for
ii
in
xrange
(
len
(
x
)):
trace
[
ii
]
=
go
.
Scatter
(
x
=
x_axis
,
y
=
x
[
keys
[
ii
]].
val
.
get_full_data
(),
name
=
keys
[
ii
])
fig
=
go
.
Figure
(
data
=
trace
)
py
.
plot
(
fig
,
filename
=
name
)
elif
shape
==
2
:
for
ii
in
xrange
(
len
(
x
)):
py
.
plot
([
go
.
Heatmap
(
z
=
x
[
keys
[
ii
]].
val
.
get_full_data
())],
filename
=
keys
[
ii
])
else
:
raise
TypeError
(
"Only 1D and 2D field plots are supported"
)
def
plot_power
(
x
,
name
):
layout
=
go
.
Layout
(
xaxis
=
dict
(
type
=
'log'
,
autorange
=
True
),
yaxis
=
dict
(
type
=
'log'
,
autorange
=
True
)
)
trace
=
[
None
]
*
len
(
x
)
keys
=
x
.
keys
()
field
=
x
[
keys
[
0
]]
domain
=
field
.
domain
[
0
]
x_axis
=
domain
.
kindex
for
ii
in
xrange
(
len
(
x
)):
trace
[
ii
]
=
go
.
Scatter
(
x
=
x_axis
,
y
=
x
[
keys
[
ii
]].
val
.
get_full_data
(),
name
=
keys
[
ii
])
fig
=
go
.
Figure
(
data
=
trace
,
layout
=
layout
)
py
.
plot
(
fig
,
filename
=
name
)
np
.
random
.
seed
(
42
)
if
__name__
==
"__main__"
:
distribution_strategy
=
'not'
# setting spaces
npix
=
np
.
array
([
500
])
# number of pixels
total_volume
=
1.
# total length
# setting signal parameters
lambda_s
=
.
05
# signal correlation length
sigma_s
=
10.
# signal variance
#setting response operator parameters
length_convolution
=
.
025
exposure
=
1.
# calculating parameters
k_0
=
4.
/
(
2
*
np
.
pi
*
lambda_s
)
a_s
=
sigma_s
**
2.
*
lambda_s
*
total_volume
# creation of spaces
x1
=
RGSpace
(
npix
,
distances
=
total_volume
/
npix
,
zerocenter
=
False
)
k1
=
RGRGTransformation
.
get_codomain
(
x1
)
p1
=
PowerSpace
(
harmonic_domain
=
k1
,
log
=
False
)
# creating Power Operator with given spectrum
spec
=
(
lambda
k
:
a_s
/
(
1
+
(
k
/
k_0
)
**
2
)
**
2
)
p_field
=
Field
(
p1
,
val
=
spec
)
S_op
=
create_power_operator
(
k1
,
spec
)
# creating FFT-Operator and Response-Operator with Gaussian convolution
Fft_op
=
FFTOperator
(
domain
=
x1
,
target
=
k1
,
domain_dtype
=
np
.
float64
,
target_dtype
=
np
.
complex128
)
R_op
=
ResponseOperator
(
x1
,
sigma
=
[
length_convolution
],
exposure
=
[
exposure
])
# drawing a random field
sk
=
p_field
.
power_synthesize
(
real_signal
=
True
,
mean
=
0.
)
s
=
Fft_op
.
inverse_times
(
sk
)
signal_to_noise
=
1
N_op
=
DiagonalOperator
(
R_op
.
target
,
diagonal
=
s
.
var
()
/
signal_to_noise
,
bare
=
True
)
n
=
Field
.
from_random
(
domain
=
R_op
.
target
,
random_type
=
'normal'
,
std
=
s
.
std
()
/
np
.
sqrt
(
signal_to_noise
),
mean
=
0
)
d
=
R_op
(
s
)
+
n
# Wiener filter
j
=
R_op
.
adjoint_times
(
N_op
.
inverse_times
(
d
))
D
=
PropagatorOperator
(
S
=
S_op
,
N
=
N_op
,
R
=
R_op
)
m
=
D
(
j
)
z
=
{}
z
[
"signal"
]
=
s
z
[
"reconstructed_map"
]
=
m
z
[
"data"
]
=
d
z
[
"lambda"
]
=
R_op
(
s
)
plot_maps
(
z
,
"Wiener_filter.html"
)
nifty/energies/energy.py
View file @
4cb89177
...
@@ -30,7 +30,7 @@ class Energy(Loggable, object):
...
@@ -30,7 +30,7 @@ class Energy(Loggable, object):
position
=
position
.
copy
()
position
=
position
.
copy
()
except
AttributeError
:
except
AttributeError
:
pass
pass
self
.
position
=
position
self
.
_
position
=
position
def
at
(
self
,
position
):
def
at
(
self
,
position
):
return
self
.
__class__
(
position
)
return
self
.
__class__
(
position
)
...
...
nifty/library/energy_library/wiener_filter_energy.py
View file @
4cb89177
...
@@ -20,7 +20,7 @@ class WienerFilterEnergy(Energy):
...
@@ -20,7 +20,7 @@ class WienerFilterEnergy(Energy):
"""
"""
def
__init__
(
self
,
position
,
d
,
R
,
N
,
S
):
def
__init__
(
self
,
position
,
d
,
R
,
N
,
S
):
super
(
WienerFilterEnergy
,
self
).
__init__
(
position
)
super
(
WienerFilterEnergy
,
self
).
__init__
(
position
=
position
)
self
.
d
=
d
self
.
d
=
d
self
.
R
=
R
self
.
R
=
R
self
.
N
=
N
self
.
N
=
N
...
@@ -34,12 +34,13 @@ class WienerFilterEnergy(Energy):
...
@@ -34,12 +34,13 @@ class WienerFilterEnergy(Energy):
energy
=
0.5
*
self
.
position
.
dot
(
self
.
S
.
inverse_times
(
self
.
position
))
energy
=
0.5
*
self
.
position
.
dot
(
self
.
S
.
inverse_times
(
self
.
position
))
energy
+=
0.5
*
(
self
.
d
-
self
.
R
(
self
.
position
)).
dot
(
energy
+=
0.5
*
(
self
.
d
-
self
.
R
(
self
.
position
)).
dot
(