Skip to content
GitLab
Projects
Groups
Snippets
Help
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
N
NIFTy
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
13
Issues
13
List
Boards
Labels
Service Desk
Milestones
Merge Requests
13
Merge Requests
13
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Operations
Operations
Incidents
Environments
Packages & Registries
Packages & Registries
Container Registry
Analytics
Analytics
CI / CD
Repository
Value Stream
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
ift
NIFTy
Commits
7b98c3bb
Commit
7b98c3bb
authored
Jul 16, 2017
by
Martin Reinecke
Browse files
Options
Browse Files
Download
Plain Diff
Merge branch 'line_search' of gitlab.mpcdf.mpg.de:ift/NIFTy into line_search
parents
b5f2473f
41309d1f
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
48 additions
and
50 deletions
+48
-50
demos/critical_filtering.py
demos/critical_filtering.py
+48
-50
No files found.
demos/critical_filtering.py
View file @
7b98c3bb
from
nifty
import
*
import
numpy
as
np
from
nifty.library.wiener_filter
import
WienerFilterEnergy
from
nifty
import
(
VL_BFGS
,
DiagonalOperator
,
FFTOperator
,
Field
,
LinearOperator
,
PowerSpace
,
RelaxedNewton
,
RGSpace
,
SteepestDescent
,
create_power_operator
,
exp
,
log
,
sqrt
)
from
nifty.library.critical_filter
import
CriticalPowerEnergy
from
nifty.library.critical_filter
import
CriticalPowerEnergy
import
plotly.offline
as
pl
from
nifty.library.wiener_filter
import
WienerFilterEnergy
import
plotly.graph_objs
as
go
import
plotly.graph_objs
as
go
import
plotly.offline
as
pl
from
mpi4py
import
MPI
from
mpi4py
import
MPI
comm
=
MPI
.
COMM_WORLD
comm
=
MPI
.
COMM_WORLD
rank
=
comm
.
rank
rank
=
comm
.
rank
np
.
random
.
seed
(
42
)
np
.
random
.
seed
(
42
)
def
plot_parameters
(
m
,
t
,
p
,
p_d
):
def
plot_parameters
(
m
,
t
,
p
,
p_d
):
x
=
log
(
t
.
domain
[
0
].
kindex
)
x
=
log
(
t
.
domain
[
0
].
kindex
)
m
=
fft
.
adjoint_times
(
m
)
m
=
fft
.
adjoint_times
(
m
)
...
@@ -20,7 +24,8 @@ def plot_parameters(m,t,p, p_d):
...
@@ -20,7 +24,8 @@ def plot_parameters(m,t,p, p_d):
p
=
p
.
val
.
get_full_data
().
real
p
=
p
.
val
.
get_full_data
().
real
p_d
=
p_d
.
val
.
get_full_data
().
real
p_d
=
p_d
.
val
.
get_full_data
().
real
pl
.
plot
([
go
.
Heatmap
(
z
=
m
)],
filename
=
'map.html'
)
pl
.
plot
([
go
.
Heatmap
(
z
=
m
)],
filename
=
'map.html'
)
pl
.
plot
([
go
.
Scatter
(
x
=
x
,
y
=
t
),
go
.
Scatter
(
x
=
x
,
y
=
p
),
go
.
Scatter
(
x
=
x
,
y
=
p_d
)],
filename
=
"t.html"
)
pl
.
plot
([
go
.
Scatter
(
x
=
x
,
y
=
t
),
go
.
Scatter
(
x
=
x
,
y
=
p
),
go
.
Scatter
(
x
=
x
,
y
=
p_d
)],
filename
=
"t.html"
)
class
AdjointFFTResponse
(
LinearOperator
):
class
AdjointFFTResponse
(
LinearOperator
):
...
@@ -36,6 +41,7 @@ class AdjointFFTResponse(LinearOperator):
...
@@ -36,6 +41,7 @@ class AdjointFFTResponse(LinearOperator):
def
_adjoint_times
(
self
,
x
,
spaces
=
None
):
def
_adjoint_times
(
self
,
x
,
spaces
=
None
):
return
self
.
FFT
(
self
.
R
.
adjoint_times
(
x
))
return
self
.
FFT
(
self
.
R
.
adjoint_times
(
x
))
@
property
@
property
def
domain
(
self
):
def
domain
(
self
):
return
self
.
_domain
return
self
.
_domain
...
@@ -48,41 +54,40 @@ class AdjointFFTResponse(LinearOperator):
...
@@ -48,41 +54,40 @@ class AdjointFFTResponse(LinearOperator):
def
unitary
(
self
):
def
unitary
(
self
):
return
False
return
False
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
distribution_strategy
=
'not'
distribution_strategy
=
'not'
# Set up position space
# Set up position space
s_space
=
RGSpace
([
128
,
128
])
s_space
=
RGSpace
([
128
,
128
])
# s_space = HPSpace(32)
# s_space = HPSpace(32)
# Define harmonic transformation and associated harmonic space
# Define harmonic transformation and associated harmonic space
fft
=
FFTOperator
(
s_space
)
fft
=
FFTOperator
(
s_space
)
h_space
=
fft
.
target
[
0
]
h_space
=
fft
.
target
[
0
]
# Set
ting
up power space
# Set up power space
p_space
=
PowerSpace
(
h_space
,
logarithmic
=
True
,
p_space
=
PowerSpace
(
h_space
,
logarithmic
=
True
,
distribution_strategy
=
distribution_strategy
)
distribution_strategy
=
distribution_strategy
)
# Choos
ing
the prior correlation structure and defining correlation operator
# Choos
e
the prior correlation structure and defining correlation operator
p_spec
=
(
lambda
k
:
(.
5
/
(
k
+
1
)
**
3
))
p_spec
=
(
lambda
k
:
(.
5
/
(
k
+
1
)
**
3
))
S
=
create_power_operator
(
h_space
,
power_spectrum
=
p_spec
,
S
=
create_power_operator
(
h_space
,
power_spectrum
=
p_spec
,
distribution_strategy
=
distribution_strategy
)
distribution_strategy
=
distribution_strategy
)
# Draw
ing
a sample sh from the prior distribution in harmonic space
# Draw a sample sh from the prior distribution in harmonic space
sp
=
Field
(
p_space
,
val
=
p_spec
,
sp
=
Field
(
p_space
,
val
=
p_spec
,
distribution_strategy
=
distribution_strategy
)
distribution_strategy
=
distribution_strategy
)
sh
=
sp
.
power_synthesize
(
real_signal
=
True
)
sh
=
sp
.
power_synthesize
(
real_signal
=
True
)
# Choose the measurement instrument
# Choosing the measurement instrument
# Instrument = SmoothingOperator(s_space, sigma=0.01)
# Instrument = SmoothingOperator(s_space, sigma=0.01)
Instrument
=
DiagonalOperator
(
s_space
,
diagonal
=
1.
)
Instrument
=
DiagonalOperator
(
s_space
,
diagonal
=
1.
)
# Instrument._diagonal.val[200:400, 200:400] = 0
# Instrument._diagonal.val[200:400, 200:400] = 0
#Instrument._diagonal.val[64:512-64, 64:512-64] = 0
#
Instrument._diagonal.val[64:512-64, 64:512-64] = 0
# Add a harmonic transformation to the instrument
#Adding a harmonic transformation to the instrument
R
=
AdjointFFTResponse
(
fft
,
Instrument
)
R
=
AdjointFFTResponse
(
fft
,
Instrument
)
noise
=
1.
noise
=
1.
...
@@ -92,7 +97,7 @@ if __name__ == "__main__":
...
@@ -92,7 +97,7 @@ if __name__ == "__main__":
std
=
sqrt
(
noise
),
std
=
sqrt
(
noise
),
mean
=
0
)
mean
=
0
)
# Creat
ing th
e mock data
# Create mock data
d
=
R
(
sh
)
+
n
d
=
R
(
sh
)
+
n
# The information source
# The information source
...
@@ -103,56 +108,49 @@ if __name__ == "__main__":
...
@@ -103,56 +108,49 @@ if __name__ == "__main__":
if
rank
==
0
:
if
rank
==
0
:
pl
.
plot
([
go
.
Heatmap
(
z
=
d_data
)],
filename
=
'data.html'
)
pl
.
plot
([
go
.
Heatmap
(
z
=
d_data
)],
filename
=
'data.html'
)
# minimization strategy
# Minimization strategy
def
convergence_measure
(
a_energy
,
iteration
):
# returns current energy
def
convergence_measure
(
a_energy
,
iteration
):
# returns current energy
x
=
a_energy
.
value
x
=
a_energy
.
value
print
(
x
,
iteration
)
print
(
x
,
iteration
)
minimizer1
=
RelaxedNewton
(
convergence_tolerance
=
1e-8
,
minimizer1
=
RelaxedNewton
(
convergence_tolerance
=
1e-8
,
convergence_level
=
1
,
convergence_level
=
1
,
iteration_limit
=
5
,
iteration_limit
=
5
,
callback
=
convergence_measure
)
callback
=
convergence_measure
)
minimizer2
=
VL_BFGS
(
convergence_tolerance
=
1e-8
,
minimizer2
=
VL_BFGS
(
convergence_tolerance
=
1e-8
,
convergence_level
=
1
,
convergence_level
=
1
,
iteration_limit
=
1000
,
iteration_limit
=
1000
,
callback
=
convergence_measure
,
callback
=
convergence_measure
,
max_history_length
=
20
)
max_history_length
=
20
)
minimizer3
=
SteepestDescent
(
convergence_tolerance
=
1e-8
,
minimizer3
=
SteepestDescent
(
convergence_tolerance
=
1e-8
,
iteration_limit
=
500
,
iteration_limit
=
500
,
callback
=
convergence_measure
)
callback
=
convergence_measure
)
# Set
ting
starting position
# Set starting position
flat_power
=
Field
(
p_space
,
val
=
1e-8
)
flat_power
=
Field
(
p_space
,
val
=
1e-8
)
m0
=
flat_power
.
power_synthesize
(
real_signal
=
True
)
m0
=
flat_power
.
power_synthesize
(
real_signal
=
True
)
t0
=
Field
(
p_space
,
val
=
log
(
1.
/
(
1
+
p_space
.
kindex
)
**
2
))
t0
=
Field
(
p_space
,
val
=
log
(
1.
/
(
1
+
p_space
.
kindex
)
**
2
))
for
i
in
range
(
50
):
for
i
in
range
(
500
):
S0
=
create_power_operator
(
h_space
,
power_spectrum
=
exp
(
t0
),
S0
=
create_power_operator
(
h_space
,
power_spectrum
=
exp
(
t0
),
distribution_strategy
=
distribution_strategy
)
distribution_strategy
=
distribution_strategy
)
# Initializ
ing the non
linear Wiener Filter energy
# Initializ
e non-
linear Wiener Filter energy
map_energy
=
WienerFilterEnergy
(
position
=
m0
,
d
=
d
,
R
=
R
,
N
=
N
,
S
=
S0
)
map_energy
=
WienerFilterEnergy
(
position
=
m0
,
d
=
d
,
R
=
R
,
N
=
N
,
S
=
S0
)
# Solv
ing
the Wiener Filter analytically
# Solv
e
the Wiener Filter analytically
D0
=
map_energy
.
curvature
D0
=
map_energy
.
curvature
m0
=
D0
.
inverse_times
(
j
)
m0
=
D0
.
inverse_times
(
j
)
# Initializing the power energy with updated parameters
# Initialize power energy with updated parameters
power_energy
=
CriticalPowerEnergy
(
position
=
t0
,
m
=
m0
,
D
=
D0
,
smoothness_prior
=
10.
,
samples
=
3
)
power_energy
=
CriticalPowerEnergy
(
position
=
t0
,
m
=
m0
,
D
=
D0
,
smoothness_prior
=
10.
,
samples
=
3
)
(
power_energy
,
convergence
)
=
minimizer2
(
power_energy
)
(
power_energy
,
convergence
)
=
minimizer2
(
power_energy
)
# Set new power spectrum
t0
.
val
=
power_energy
.
position
.
val
.
real
# Setting new power spectrum
# Plot current estimate
t0
.
val
=
power_energy
.
position
.
val
.
real
print
(
i
)
if
i
%
5
==
0
:
# Plotting current estimate
plot_parameters
(
m0
,
t0
,
log
(
sp
),
data_power
)
print
i
if
i
%
50
==
0
:
plot_parameters
(
m0
,
t0
,
log
(
sp
),
data_power
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment