Skip to content
GitLab
Projects
Groups
Snippets
/
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
Neel Shah
NIFTy
Commits
81d4481b
Commit
81d4481b
authored
Oct 09, 2017
by
Jakob Knollmueller
Browse files
some changes, nothing serious, not working
parent
b400978c
Changes
2
Hide whitespace changes
Inline
Side-by-side
demos/critical_filtering.py
View file @
81d4481b
...
...
@@ -10,20 +10,21 @@ from mpi4py import MPI
comm
=
MPI
.
COMM_WORLD
rank
=
comm
.
rank
np
.
random
.
seed
(
4
2
)
np
.
random
.
seed
(
4
3
)
def
plot_parameters
(
m
,
t
,
p
,
p_d
):
def
plot_parameters
(
m
,
t
,
p
,
p_sig
,
p_d
):
x
=
np
.
log
(
t
.
domain
[
0
].
kindex
)
m
=
fft
.
adjoint_times
(
m
)
m
=
m
.
val
.
get_full_data
().
real
t
=
t
.
val
.
get_full_data
().
real
p
=
p
.
val
.
get_full_data
().
real
pd_sig
=
p_sig
.
val
.
get_full_data
()
p_d
=
p_d
.
val
.
get_full_data
().
real
pl
.
plot
([
go
.
Heatmap
(
z
=
m
)],
filename
=
'map.html'
,
auto_open
=
False
)
pl
.
plot
([
go
.
Scatter
(
x
=
x
,
y
=
t
),
go
.
Scatter
(
x
=
x
,
y
=
p
),
go
.
Scatter
(
x
=
x
,
y
=
p_d
)],
filename
=
"t.html"
,
auto_open
=
False
)
go
.
Scatter
(
x
=
x
,
y
=
p_d
)
,
go
.
Scatter
(
x
=
x
,
y
=
pd_sig
)
],
filename
=
"t.html"
,
auto_open
=
False
)
class
AdjointFFTResponse
(
ift
.
LinearOperator
):
...
...
@@ -58,7 +59,8 @@ if __name__ == "__main__":
distribution_strategy
=
'not'
# Set up position space
s_space
=
ift
.
RGSpace
([
128
,
128
])
dist
=
1
/
128.
*
0.1
s_space
=
ift
.
RGSpace
([
128
,
128
],
distances
=
[
dist
,
dist
])
# s_space = ift.HPSpace(32)
# Define harmonic transformation and associated harmonic space
...
...
@@ -72,7 +74,8 @@ if __name__ == "__main__":
distribution_strategy
=
distribution_strategy
)
# Choose the prior correlation structure and defining correlation operator
p_spec
=
(
lambda
k
:
(.
5
/
(
k
+
1
)
**
3
))
# p_spec = (lambda k: (.5 / (k + 1) ** 3))
p_spec
=
(
lambda
k
:
1
)
S
=
ift
.
create_power_operator
(
h_space
,
power_spectrum
=
p_spec
,
distribution_strategy
=
distribution_strategy
)
...
...
@@ -123,7 +126,6 @@ if __name__ == "__main__":
IC3
=
ift
.
GradientNormController
(
iteration_limit
=
100
,
tol_abs_gradnorm
=
0.1
)
minimizer3
=
ift
.
SteepestDescent
(
IC3
)
# Set starting position
flat_power
=
ift
.
Field
(
p_space
,
val
=
1e-8
)
m0
=
flat_power
.
power_synthesize
(
real_signal
=
True
)
...
...
@@ -137,11 +139,11 @@ if __name__ == "__main__":
# Initialize non-linear Wiener Filter energy
map_energy
=
WienerFilterEnergy
(
position
=
m0
,
d
=
d
,
R
=
R
,
N
=
N
,
S
=
S0
)
# Solve the Wiener Filter analytically
D0
=
map_energy
.
curvature
m0
=
D0
.
inverse_times
(
j
)
#
D0 = map_energy.curvature
#
m0 = D0.inverse_times(j)
# Initialize power energy with updated parameters
power_energy
=
CriticalPowerEnergy
(
position
=
t0
,
m
=
m0
,
D
=
D0
,
smoothness_prior
=
1
0.
,
samples
=
3
)
power_energy
=
CriticalPowerEnergy
(
position
=
t0
,
m
=
sh
,
D
=
None
,
smoothness_prior
=
1
e-15
,
samples
=
3
)
(
power_energy
,
convergence
)
=
minimizer2
(
power_energy
)
...
...
@@ -150,5 +152,6 @@ if __name__ == "__main__":
# Plot current estimate
print
(
i
)
if
i
%
5
==
0
:
plot_parameters
(
m0
,
t0
,
ift
.
log
(
sp
),
data_power
)
if
i
%
1
==
0
:
plot_parameters
(
sh
,
t0
,
ift
.
log
(
sp
),
ift
.
log
(
sh
.
power_analyze
(
binbounds
=
p_space
.
binbounds
)),
data_power
)
print
ift
.
log
(
sh
.
power_analyze
(
binbounds
=
p_space
.
binbounds
)).
val
-
t0
.
val
nifty/library/critical_filter/critical_power_energy.py
View file @
81d4481b
...
...
@@ -73,10 +73,11 @@ class CriticalPowerEnergy(Energy):
self
.
_w
=
w
if
w
is
not
None
else
None
if
inverter
is
None
:
preconditioner
=
DiagonalOperator
(
self
.
_theta
.
domain
,
diagonal
=
self
.
_theta
.
weight
(
-
1
)
,
diagonal
=
self
.
_theta
,
copy
=
False
)
inverter
=
ConjugateGradient
(
preconditioner
=
preconditioner
)
self
.
_inverter
=
inverter
self
.
one
=
Field
(
self
.
position
.
domain
,
val
=
1.
)
@
property
def
inverter
(
self
):
...
...
@@ -94,16 +95,16 @@ class CriticalPowerEnergy(Energy):
@
property
@
memo
def
value
(
self
):
energy
=
self
.
_theta
.
sum
(
)
energy
+=
self
.
position
.
weight
(
-
1
).
vdot
(
self
.
_rho_prime
)
energy
=
self
.
one
.
vdot
(
self
.
_theta
)
energy
+=
self
.
position
.
vdot
(
self
.
one
/
2.
)
energy
+=
0.5
*
self
.
position
.
vdot
(
self
.
_Tt
)
return
energy
.
real
@
property
@
memo
def
gradient
(
self
):
gradient
=
-
self
.
_theta
.
weight
(
-
1
)
gradient
+=
(
self
.
_rho_prime
).
weight
(
-
1
)
gradient
=
-
self
.
_theta
gradient
+=
(
self
.
one
/
2.
)
gradient
+=
self
.
_Tt
gradient
.
val
=
gradient
.
val
.
real
return
gradient
...
...
@@ -111,7 +112,7 @@ class CriticalPowerEnergy(Energy):
@
property
@
memo
def
curvature
(
self
):
return
CriticalPowerCurvature
(
theta
=
self
.
_theta
.
weight
(
-
1
)
,
T
=
self
.
T
,
return
CriticalPowerCurvature
(
theta
=
self
.
_theta
,
T
=
self
.
T
,
inverter
=
self
.
inverter
)
# ---Added properties and methods---
...
...
@@ -142,7 +143,7 @@ class CriticalPowerEnergy(Energy):
w
=
self
.
m
.
power_analyze
(
binbounds
=
self
.
position
.
domain
[
0
].
binbounds
)
w
*=
self
.
rho
self
.
_w
=
w
self
.
_w
=
w
.
weight
(
-
1
)
return
self
.
_w
@
property
...
...
Write
Preview
Supports
Markdown
0%
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment