Skip to content
GitLab
Projects
Groups
Snippets
Help
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
N
NIFTy
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
10
Issues
10
List
Boards
Labels
Service Desk
Milestones
Merge Requests
8
Merge Requests
8
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Operations
Operations
Incidents
Environments
Packages & Registries
Packages & Registries
Container Registry
Analytics
Analytics
CI / CD
Repository
Value Stream
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
ift
NIFTy
Commits
25c0b11c
Commit
25c0b11c
authored
Jul 29, 2017
by
Theo Steininger
Browse files
Options
Browse Files
Download
Plain Diff
Merge branch 'master' into 'mpitests'
Master See merge request
!174
parents
1a0ebf7e
04af1dae
Pipeline
#15665
failed with stage
in 17 minutes and 2 seconds
Changes
47
Pipelines
1
Hide whitespace changes
Inline
Side-by-side
Showing
47 changed files
with
527 additions
and
356 deletions
+527
-356
demos/critical_filtering.py
demos/critical_filtering.py
+53
-51
nifty/energies/line_energy.py
nifty/energies/line_energy.py
+36
-37
nifty/field.py
nifty/field.py
+53
-42
nifty/library/critical_filter/critical_power_curvature.py
nifty/library/critical_filter/critical_power_curvature.py
+3
-2
nifty/library/critical_filter/critical_power_energy.py
nifty/library/critical_filter/critical_power_energy.py
+16
-2
nifty/library/wiener_filter/wiener_filter_curvature.py
nifty/library/wiener_filter/wiener_filter_curvature.py
+3
-2
nifty/library/wiener_filter/wiener_filter_energy.py
nifty/library/wiener_filter/wiener_filter_energy.py
+3
-2
nifty/minimization/descent_minimizer.py
nifty/minimization/descent_minimizer.py
+6
-4
nifty/minimization/line_searching/line_search.py
nifty/minimization/line_searching/line_search.py
+1
-1
nifty/minimization/line_searching/line_search_strong_wolfe.py
...y/minimization/line_searching/line_search_strong_wolfe.py
+46
-50
nifty/operators/composed_operator/composed_operator.py
nifty/operators/composed_operator/composed_operator.py
+4
-3
nifty/operators/diagonal_operator/diagonal_operator.py
nifty/operators/diagonal_operator/diagonal_operator.py
+1
-1
nifty/operators/fft_operator/transformations/rg_transforms.py
...y/operators/fft_operator/transformations/rg_transforms.py
+6
-26
nifty/operators/invertible_operator_mixin/invertible_operator_mixin.py
...rs/invertible_operator_mixin/invertible_operator_mixin.py
+29
-13
nifty/operators/laplace_operator/laplace_operator.py
nifty/operators/laplace_operator/laplace_operator.py
+7
-1
nifty/operators/linear_operator/linear_operator.py
nifty/operators/linear_operator/linear_operator.py
+5
-2
nifty/operators/response_operator/response_operator.py
nifty/operators/response_operator/response_operator.py
+7
-10
nifty/operators/smoothing_operator/fft_smoothing_operator.py
nifty/operators/smoothing_operator/fft_smoothing_operator.py
+1
-1
nifty/operators/smoothness_operator/smoothness_operator.py
nifty/operators/smoothness_operator/smoothness_operator.py
+5
-1
nifty/plotting/descriptors/axis.py
nifty/plotting/descriptors/axis.py
+15
-6
nifty/plotting/figures/figure_2D.py
nifty/plotting/figures/figure_2D.py
+12
-10
nifty/plotting/figures/figure_3D.py
nifty/plotting/figures/figure_3D.py
+10
-2
nifty/plotting/figures/figure_base.py
nifty/plotting/figures/figure_base.py
+1
-1
nifty/plotting/plots/heatmaps/glmollweide.py
nifty/plotting/plots/heatmaps/glmollweide.py
+14
-2
nifty/plotting/plots/heatmaps/heatmap.py
nifty/plotting/plots/heatmaps/heatmap.py
+28
-4
nifty/plotting/plots/heatmaps/hpmollweide.py
nifty/plotting/plots/heatmaps/hpmollweide.py
+14
-2
nifty/plotting/plots/scatter_plots/cartesian_2D.py
nifty/plotting/plots/scatter_plots/cartesian_2D.py
+4
-0
nifty/plotting/plots/scatter_plots/cartesian_3D.py
nifty/plotting/plots/scatter_plots/cartesian_3D.py
+4
-0
nifty/plotting/plots/scatter_plots/geo.py
nifty/plotting/plots/scatter_plots/geo.py
+4
-0
nifty/plotting/plots/scatter_plots/scatter_plot.py
nifty/plotting/plots/scatter_plots/scatter_plot.py
+10
-0
nifty/plotting/plotter/gl_plotter.py
nifty/plotting/plotter/gl_plotter.py
+2
-2
nifty/plotting/plotter/healpix_plotter.py
nifty/plotting/plotter/healpix_plotter.py
+2
-2
nifty/plotting/plotter/plotter_base.py
nifty/plotting/plotter/plotter_base.py
+21
-14
nifty/plotting/plotter/power_plotter.py
nifty/plotting/plotter/power_plotter.py
+3
-3
nifty/plotting/plotter/rg1d_plotter.py
nifty/plotting/plotter/rg1d_plotter.py
+3
-3
nifty/plotting/plotter/rg2d_plotter.py
nifty/plotting/plotter/rg2d_plotter.py
+2
-2
nifty/probing/mixin_classes/diagonal_prober_mixin.py
nifty/probing/mixin_classes/diagonal_prober_mixin.py
+8
-1
nifty/probing/mixin_classes/trace_prober_mixin.py
nifty/probing/mixin_classes/trace_prober_mixin.py
+9
-1
nifty/probing/prober/prober.py
nifty/probing/prober/prober.py
+3
-1
nifty/spaces/power_space/power_space.py
nifty/spaces/power_space/power_space.py
+11
-2
nifty/spaces/rg_space/rg_space.py
nifty/spaces/rg_space/rg_space.py
+20
-17
nifty/spaces/space/space.py
nifty/spaces/space/space.py
+0
-13
nifty/sugar.py
nifty/sugar.py
+38
-7
test/test_field.py
test/test_field.py
+2
-0
test/test_minimization/test_descent_minimizers.py
test/test_minimization/test_descent_minimizers.py
+2
-1
test/test_spaces/test_lm_space.py
test/test_spaces/test_lm_space.py
+0
-4
test/test_spaces/test_rg_space.py
test/test_spaces/test_rg_space.py
+0
-5
No files found.
demos/critical_filtering.py
View file @
25c0b11c
from
nifty
import
*
import
numpy
as
np
from
nifty.library.wiener_filter
import
WienerFilterEnergy
from
nifty
import
(
VL_BFGS
,
DiagonalOperator
,
FFTOperator
,
Field
,
LinearOperator
,
PowerSpace
,
RelaxedNewton
,
RGSpace
,
SteepestDescent
,
create_power_operator
,
exp
,
log
,
sqrt
)
from
nifty.library.critical_filter
import
CriticalPowerEnergy
from
nifty.library.critical_filter
import
CriticalPowerEnergy
import
plotly.offline
as
pl
from
nifty.library.wiener_filter
import
WienerFilterEnergy
import
plotly.graph_objs
as
go
import
plotly.graph_objs
as
go
import
plotly.offline
as
pl
from
mpi4py
import
MPI
from
mpi4py
import
MPI
comm
=
MPI
.
COMM_WORLD
comm
=
MPI
.
COMM_WORLD
rank
=
comm
.
rank
rank
=
comm
.
rank
np
.
random
.
seed
(
42
)
np
.
random
.
seed
(
42
)
def
plot_parameters
(
m
,
t
,
p
,
p_d
):
def
plot_parameters
(
m
,
t
,
p
,
p_d
):
x
=
log
(
t
.
domain
[
0
].
kindex
)
x
=
log
(
t
.
domain
[
0
].
kindex
)
m
=
fft
.
adjoint_times
(
m
)
m
=
fft
.
adjoint_times
(
m
)
...
@@ -20,7 +24,8 @@ def plot_parameters(m,t,p, p_d):
...
@@ -20,7 +24,8 @@ def plot_parameters(m,t,p, p_d):
p
=
p
.
val
.
get_full_data
().
real
p
=
p
.
val
.
get_full_data
().
real
p_d
=
p_d
.
val
.
get_full_data
().
real
p_d
=
p_d
.
val
.
get_full_data
().
real
pl
.
plot
([
go
.
Heatmap
(
z
=
m
)],
filename
=
'map.html'
)
pl
.
plot
([
go
.
Heatmap
(
z
=
m
)],
filename
=
'map.html'
)
pl
.
plot
([
go
.
Scatter
(
x
=
x
,
y
=
t
),
go
.
Scatter
(
x
=
x
,
y
=
p
),
go
.
Scatter
(
x
=
x
,
y
=
p_d
)],
filename
=
"t.html"
)
pl
.
plot
([
go
.
Scatter
(
x
=
x
,
y
=
t
),
go
.
Scatter
(
x
=
x
,
y
=
p
),
go
.
Scatter
(
x
=
x
,
y
=
p_d
)],
filename
=
"t.html"
)
class
AdjointFFTResponse
(
LinearOperator
):
class
AdjointFFTResponse
(
LinearOperator
):
...
@@ -36,6 +41,7 @@ class AdjointFFTResponse(LinearOperator):
...
@@ -36,6 +41,7 @@ class AdjointFFTResponse(LinearOperator):
def
_adjoint_times
(
self
,
x
,
spaces
=
None
):
def
_adjoint_times
(
self
,
x
,
spaces
=
None
):
return
self
.
FFT
(
self
.
R
.
adjoint_times
(
x
))
return
self
.
FFT
(
self
.
R
.
adjoint_times
(
x
))
@
property
@
property
def
domain
(
self
):
def
domain
(
self
):
return
self
.
_domain
return
self
.
_domain
...
@@ -48,41 +54,40 @@ class AdjointFFTResponse(LinearOperator):
...
@@ -48,41 +54,40 @@ class AdjointFFTResponse(LinearOperator):
def
unitary
(
self
):
def
unitary
(
self
):
return
False
return
False
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
distribution_strategy
=
'not'
distribution_strategy
=
'not'
# Set up position space
# Set up position space
s_space
=
RGSpace
([
128
,
128
])
s_space
=
RGSpace
([
128
,
128
])
# s_space = HPSpace(32)
# s_space = HPSpace(32)
# Define harmonic transformation and associated harmonic space
# Define harmonic transformation and associated harmonic space
fft
=
FFTOperator
(
s_space
)
fft
=
FFTOperator
(
s_space
)
h_space
=
fft
.
target
[
0
]
h_space
=
fft
.
target
[
0
]
# Set
ting
up power space
# Set up power space
p_space
=
PowerSpace
(
h_space
,
logarithmic
=
True
,
p_space
=
PowerSpace
(
h_space
,
logarithmic
=
True
,
distribution_strategy
=
distribution_strategy
)
distribution_strategy
=
distribution_strategy
)
# Choos
ing
the prior correlation structure and defining correlation operator
# Choos
e
the prior correlation structure and defining correlation operator
p_spec
=
(
lambda
k
:
(.
5
/
(
k
+
1
)
**
3
))
p_spec
=
(
lambda
k
:
(.
5
/
(
k
+
1
)
**
3
))
S
=
create_power_operator
(
h_space
,
power_spectrum
=
p_spec
,
S
=
create_power_operator
(
h_space
,
power_spectrum
=
p_spec
,
distribution_strategy
=
distribution_strategy
)
distribution_strategy
=
distribution_strategy
)
# Draw
ing
a sample sh from the prior distribution in harmonic space
# Draw a sample sh from the prior distribution in harmonic space
sp
=
Field
(
p_space
,
val
=
p_spec
,
sp
=
Field
(
p_space
,
val
=
p_spec
,
distribution_strategy
=
distribution_strategy
)
distribution_strategy
=
distribution_strategy
)
sh
=
sp
.
power_synthesize
(
real_signal
=
True
)
sh
=
sp
.
power_synthesize
(
real_signal
=
True
)
# Choose the measurement instrument
# Choosing the measurement instrument
# Instrument = SmoothingOperator(s_space, sigma=0.01)
# Instrument = SmoothingOperator(s_space, sigma=0.01)
Instrument
=
DiagonalOperator
(
s_space
,
diagonal
=
1.
)
Instrument
=
DiagonalOperator
(
s_space
,
diagonal
=
1.
)
# Instrument._diagonal.val[200:400, 200:400] = 0
# Instrument._diagonal.val[200:400, 200:400] = 0
#Instrument._diagonal.val[64:512-64, 64:512-64] = 0
# Instrument._diagonal.val[64:512-64, 64:512-64] = 0
#
Adding
a harmonic transformation to the instrument
#
Add
a harmonic transformation to the instrument
R
=
AdjointFFTResponse
(
fft
,
Instrument
)
R
=
AdjointFFTResponse
(
fft
,
Instrument
)
noise
=
1.
noise
=
1.
...
@@ -92,7 +97,7 @@ if __name__ == "__main__":
...
@@ -92,7 +97,7 @@ if __name__ == "__main__":
std
=
sqrt
(
noise
),
std
=
sqrt
(
noise
),
mean
=
0
)
mean
=
0
)
# Creat
ing th
e mock data
# Create mock data
d
=
R
(
sh
)
+
n
d
=
R
(
sh
)
+
n
# The information source
# The information source
...
@@ -103,52 +108,49 @@ if __name__ == "__main__":
...
@@ -103,52 +108,49 @@ if __name__ == "__main__":
if
rank
==
0
:
if
rank
==
0
:
pl
.
plot
([
go
.
Heatmap
(
z
=
d_data
)],
filename
=
'data.html'
)
pl
.
plot
([
go
.
Heatmap
(
z
=
d_data
)],
filename
=
'data.html'
)
# minimization strategy
# Minimization strategy
def
convergence_measure
(
a_energy
,
iteration
):
# returns current energy
def
convergence_measure
(
a_energy
,
iteration
):
# returns current energy
x
=
a_energy
.
value
x
=
a_energy
.
value
print
(
x
,
iteration
)
print
(
x
,
iteration
)
minimizer1
=
RelaxedNewton
(
convergence_tolerance
=
1e-4
,
minimizer1
=
RelaxedNewton
(
convergence_tolerance
=
1e-2
,
convergence_level
=
1
,
convergence_level
=
2
,
iteration_limit
=
5
,
iteration_limit
=
3
,
callback
=
convergence_measure
)
callback
=
convergence_measure
)
minimizer2
=
VL_BFGS
(
convergence_tolerance
=
1e-4
,
convergence_level
=
1
,
minimizer2
=
VL_BFGS
(
convergence_tolerance
=
0
,
iteration_limit
=
20
,
iteration_limit
=
7
,
callback
=
convergence_measure
,
callback
=
convergence_measure
,
max_history_length
=
20
)
max_history_length
=
3
)
minimizer3
=
SteepestDescent
(
convergence_tolerance
=
1e-4
,
iteration_limit
=
100
,
# Setting starting position
callback
=
convergence_measure
)
flat_power
=
Field
(
p_space
,
val
=
1e-8
)
# Set starting position
flat_power
=
Field
(
p_space
,
val
=
1e-8
)
m0
=
flat_power
.
power_synthesize
(
real_signal
=
True
)
m0
=
flat_power
.
power_synthesize
(
real_signal
=
True
)
t0
=
Field
(
p_space
,
val
=
log
(
1.
/
(
1
+
p_space
.
kindex
)
**
2
))
t0
=
Field
(
p_space
,
val
=
log
(
1.
/
(
1
+
p_space
.
kindex
)
**
2
))
for
i
in
range
(
500
):
for
i
in
range
(
500
):
S0
=
create_power_operator
(
h_space
,
power_spectrum
=
exp
(
t0
),
S0
=
create_power_operator
(
h_space
,
power_spectrum
=
exp
(
t0
),
distribution_strategy
=
distribution_strategy
)
distribution_strategy
=
distribution_strategy
)
# Initializ
ing the non
linear Wiener Filter energy
# Initializ
e non-
linear Wiener Filter energy
map_energy
=
WienerFilterEnergy
(
position
=
m0
,
d
=
d
,
R
=
R
,
N
=
N
,
S
=
S0
)
map_energy
=
WienerFilterEnergy
(
position
=
m0
,
d
=
d
,
R
=
R
,
N
=
N
,
S
=
S0
)
# Solv
ing
the Wiener Filter analytically
# Solv
e
the Wiener Filter analytically
D0
=
map_energy
.
curvature
D0
=
map_energy
.
curvature
m0
=
D0
.
inverse_times
(
j
)
m0
=
D0
.
inverse_times
(
j
)
# Initializing the power energy with updated parameters
# Initialize power energy with updated parameters
power_energy
=
CriticalPowerEnergy
(
position
=
t0
,
m
=
m0
,
D
=
D0
,
smoothness_prior
=
10.
,
samples
=
3
)
power_energy
=
CriticalPowerEnergy
(
position
=
t0
,
m
=
m0
,
D
=
D0
,
smoothness_prior
=
10.
,
samples
=
3
)
(
power_energy
,
convergence
)
=
minimizer1
(
power_energy
)
# Setting new power spectrum
t0
.
val
=
power_energy
.
position
.
val
.
real
# Plotting current estimate
(
power_energy
,
convergence
)
=
minimizer2
(
power_energy
)
print
i
if
i
%
50
==
0
:
plot_parameters
(
m0
,
t0
,
log
(
sp
),
data_power
)
# Set new power spectrum
t0
.
val
=
power_energy
.
position
.
val
.
real
# Plot current estimate
print
(
i
)
if
i
%
5
==
0
:
plot_parameters
(
m0
,
t0
,
log
(
sp
),
data_power
)
nifty/energies/line_energy.py
View file @
25c0b11c
...
@@ -16,10 +16,8 @@
...
@@ -16,10 +16,8 @@
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes.
# and financially supported by the Studienstiftung des deutschen Volkes.
from
.energy
import
Energy
class
LineEnergy
(
object
):
class
LineEnergy
(
Energy
):
""" Evaluates an underlying Energy along a certain line direction.
""" Evaluates an underlying Energy along a certain line direction.
Given an Energy class and a line direction, its position is parametrized by
Given an Energy class and a line direction, its position is parametrized by
...
@@ -27,34 +25,31 @@ class LineEnergy(Energy):
...
@@ -27,34 +25,31 @@ class LineEnergy(Energy):
Parameters
Parameters
----------
----------
position : float
line_position : float
The step length parameter along the given line direction.
Defines the full spatial position of this energy via
self.energy.position = zero_point + line_position*line_direction
energy : Energy
energy : Energy
The Energy object which will be evaluated along the given direction.
The Energy object which will be evaluated along the given direction.
line_direction : Field
line_direction : Field
Direction used for line evaluation.
Direction used for line evaluation.
Does not have to be normalized.
zero_point : Field
*optional*
offset : float
*optional*
Fixing the zero point of the line restriction. Used to memorize this
Indirectly defines the zero point of the line via the equation
position in new initializations. By the default the current posi
tion
energy.position = zero_point + offset*line_direc
tion
of the supplied `energy` instance is used (default : None
).
(default : 0.
).
Attributes
Attributes
----------
----------
position : float
line_
position : float
The position along the given line direction relative to the zero point.
The position along the given line direction relative to the zero point.
value : float
value : float
The value of the energy functional at given `position`.
The value of the energy functional at the given position
gradient : float
directional_derivative : float
The gradient of the underlying energy instance along the line direction
The directional derivative at the given position
projected on the line direction.
curvature : callable
A positive semi-definite operator or function describing the curvature
of the potential at given `position`.
line_direction : Field
line_direction : Field
Direction along which the movement is restricted. Does not have to be
Direction along which the movement is restricted. Does not have to be
normalized.
normalized.
energy : Energy
energy : Energy
The underlying Energy at the
`position` along the line direction.
The underlying Energy at the
given position
Raises
Raises
------
------
...
@@ -72,45 +67,49 @@ class LineEnergy(Energy):
...
@@ -72,45 +67,49 @@ class LineEnergy(Energy):
"""
"""
def
__init__
(
self
,
position
,
energy
,
line_direction
,
zero_point
=
None
):
def
__init__
(
self
,
line_position
,
energy
,
line_direction
,
offset
=
0.
):
super
(
LineEnergy
,
self
).
__init__
(
position
=
position
)
self
.
_line_position
=
float
(
line_position
)
self
.
line_direction
=
line_direction
self
.
_line_direction
=
line_direction
if
zero_point
is
None
:
zero_point
=
energy
.
position
self
.
_zero_point
=
zero_point
position_on_line
=
self
.
_zero_point
+
self
.
position
*
line_direction
pos
=
energy
.
position
\
self
.
energy
=
energy
.
at
(
position
=
position_on_line
)
+
(
self
.
_line_position
-
float
(
offset
))
*
self
.
_line_direction
self
.
energy
=
energy
.
at
(
position
=
pos
)
def
at
(
self
,
position
):
def
at
(
self
,
line_
position
):
""" Returns LineEnergy at new position, memorizing the zero point.
""" Returns LineEnergy at new position, memorizing the zero point.
Parameters
Parameters
----------
----------
position : float
line_
position : float
Parameter for the new position on the line direction.
Parameter for the new position on the line direction.
Returns
Returns
-------
-------
out : LineEnergy
LineEnergy object at new position with same zero point as `self`.
LineEnergy object at new position with same zero point as `self`.
"""
"""
return
self
.
__class__
(
position
,
return
self
.
__class__
(
line_
position
,
self
.
energy
,
self
.
energy
,
self
.
line_direction
,
self
.
line_direction
,
zero_point
=
self
.
_zero_point
)
offset
=
self
.
line_position
)
@
property
@
property
def
value
(
self
):
def
value
(
self
):
return
self
.
energy
.
value
return
self
.
energy
.
value
@
property
@
property
def
gradient
(
self
):
def
line_position
(
self
):
return
self
.
energy
.
gradient
.
vdot
(
self
.
line_direction
)
return
self
.
_line_position
@
property
def
line_direction
(
self
):
return
self
.
_line_direction
@
property
@
property
def
curvature
(
self
):
def
directional_derivative
(
self
):
return
self
.
energy
.
curvature
res
=
self
.
energy
.
gradient
.
vdot
(
self
.
line_direction
)
if
abs
(
res
.
imag
)
/
max
(
abs
(
res
.
real
),
1.
)
>
1e-12
:
print
"directional derivative has non-negligible "
\
"imaginary part:"
,
res
return
res
.
real
nifty/field.py
View file @
25c0b11c
...
@@ -112,7 +112,6 @@ class Field(Loggable, Versionable, object):
...
@@ -112,7 +112,6 @@ class Field(Loggable, Versionable, object):
def
__init__
(
self
,
domain
=
None
,
val
=
None
,
dtype
=
None
,
def
__init__
(
self
,
domain
=
None
,
val
=
None
,
dtype
=
None
,
distribution_strategy
=
None
,
copy
=
False
):
distribution_strategy
=
None
,
copy
=
False
):
self
.
domain
=
self
.
_parse_domain
(
domain
=
domain
,
val
=
val
)
self
.
domain
=
self
.
_parse_domain
(
domain
=
domain
,
val
=
val
)
self
.
domain_axes
=
self
.
_get_axes_tuple
(
self
.
domain
)
self
.
domain_axes
=
self
.
_get_axes_tuple
(
self
.
domain
)
...
@@ -128,6 +127,7 @@ class Field(Loggable, Versionable, object):
...
@@ -128,6 +127,7 @@ class Field(Loggable, Versionable, object):
else
:
else
:
self
.
set_val
(
new_val
=
val
,
copy
=
copy
)
self
.
set_val
(
new_val
=
val
,
copy
=
copy
)
def
_parse_domain
(
self
,
domain
,
val
=
None
):
def
_parse_domain
(
self
,
domain
,
val
=
None
):
if
domain
is
None
:
if
domain
is
None
:
if
isinstance
(
val
,
Field
):
if
isinstance
(
val
,
Field
):
...
@@ -466,7 +466,7 @@ class Field(Loggable, Versionable, object):
...
@@ -466,7 +466,7 @@ class Field(Loggable, Versionable, object):
return
result_obj
return
result_obj
def
power_synthesize
(
self
,
spaces
=
None
,
real_power
=
True
,
real_signal
=
True
,
def
power_synthesize
(
self
,
spaces
=
None
,
real_power
=
True
,
real_signal
=
True
,
mean
=
None
,
std
=
None
):
mean
=
None
,
std
=
None
,
distribution_strategy
=
None
):
""" Yields a sampled field with `self`**2 as its power spectrum.
""" Yields a sampled field with `self`**2 as its power spectrum.
This method draws a Gaussian random field in the harmonic partner
This method draws a Gaussian random field in the harmonic partner
...
@@ -541,13 +541,16 @@ class Field(Loggable, Versionable, object):
...
@@ -541,13 +541,16 @@ class Field(Loggable, Versionable, object):
else
:
else
:
result_list
=
[
None
,
None
]
result_list
=
[
None
,
None
]
if
distribution_strategy
is
None
:
distribution_strategy
=
gc
[
'default_distribution_strategy'
]
result_list
=
[
self
.
__class__
.
from_random
(
result_list
=
[
self
.
__class__
.
from_random
(
'normal'
,
'normal'
,
mean
=
mean
,
mean
=
mean
,
std
=
std
,
std
=
std
,
domain
=
result_domain
,
domain
=
result_domain
,
dtype
=
np
.
complex
,
dtype
=
np
.
complex
,
distribution_strategy
=
self
.
distribution_strategy
)
distribution_strategy
=
distribution_strategy
)
for
x
in
result_list
]
for
x
in
result_list
]
# from now on extract the values from the random fields for further
# from now on extract the values from the random fields for further
...
@@ -609,39 +612,47 @@ class Field(Loggable, Versionable, object):
...
@@ -609,39 +612,47 @@ class Field(Loggable, Versionable, object):
# correct variance
# correct variance
if
preserve_gaussian_variance
:
if
preserve_gaussian_variance
:
assert
issubclass
(
val
.
dtype
.
type
,
np
.
complexfloating
),
\
"complex input field is needed here"
h
*=
np
.
sqrt
(
2
)
h
*=
np
.
sqrt
(
2
)
a
*=
np
.
sqrt
(
2
)
a
*=
np
.
sqrt
(
2
)
if
not
issubclass
(
val
.
dtype
.
type
,
np
.
complexfloating
):
# The code below should not be needed in practice, since it would
# in principle one must not correct the variance for the fixed
# only ever be called when hermitianizing a purely real field.
# points of the hermitianization. However, for a complex field
# However it might be of educational use and keep us from forgetting
# the input field loses half of its power at its fixed points
# how these things are done ...
# in the `hermitian` part. Hence, here a factor of sqrt(2) is
# also necessary!
# if not issubclass(val.dtype.type, np.complexfloating):
# => The hermitianization can be done on a space level since
# # in principle one must not correct the variance for the fixed
# either nothing must be done (LMSpace) or ALL points need a
# # points of the hermitianization. However, for a complex field
# factor of sqrt(2)
# # the input field loses half of its power at its fixed points
# => use the preserve_gaussian_variance flag in the
# # in the `hermitian` part. Hence, here a factor of sqrt(2) is
# hermitian_decomposition method above.
# # also necessary!
# # => The hermitianization can be done on a space level since
# This code is for educational purposes:
# # either nothing must be done (LMSpace) or ALL points need a
fixed_points
=
[
domain
[
i
].
hermitian_fixed_points
()
# # factor of sqrt(2)
for
i
in
spaces
]
# # => use the preserve_gaussian_variance flag in the
fixed_points
=
[[
fp
]
if
fp
is
None
else
fp
# # hermitian_decomposition method above.
for
fp
in
fixed_points
]
#
# # This code is for educational purposes:
for
product_point
in
itertools
.
product
(
*
fixed_points
):
# fixed_points = [domain[i].hermitian_fixed_points()
slice_object
=
np
.
array
((
slice
(
None
),
)
*
len
(
val
.
shape
),
# for i in spaces]
dtype
=
np
.
object
)
# fixed_points = [[fp] if fp is None else fp
for
i
,
sp
in
enumerate
(
spaces
):
# for fp in fixed_points]
point_component
=
product_point
[
i
]
#
if
point_component
is
None
:
# for product_point in itertools.product(*fixed_points):
point_component
=
slice
(
None
)
# slice_object = np.array((slice(None), )*len(val.shape),
slice_object
[
list
(
domain_axes
[
sp
])]
=
point_component
# dtype=np.object)
# for i, sp in enumerate(spaces):
slice_object
=
tuple
(
slice_object
)
# point_component = product_point[i]
h
[
slice_object
]
/=
np
.
sqrt
(
2
)
# if point_component is None:
a
[
slice_object
]
/=
np
.
sqrt
(
2
)
# point_component = slice(None)
# slice_object[list(domain_axes[sp])] = point_component
#
# slice_object = tuple(slice_object)
# h[slice_object] /= np.sqrt(2)
# a[slice_object] /= np.sqrt(2)
return
(
h
,
a
)
return
(
h
,
a
)
def
_spec_to_rescaler
(
self
,
spec
,
result_list
,
power_space_index
):
def
_spec_to_rescaler
(
self
,
spec
,
result_list
,
power_space_index
):
...
@@ -657,7 +668,7 @@ class Field(Loggable, Versionable, object):
...
@@ -657,7 +668,7 @@ class Field(Loggable, Versionable, object):
result_list
[
0
].
domain_axes
[
power_space_index
])
result_list
[
0
].
domain_axes
[
power_space_index
])
if
pindex
.
distribution_strategy
is
not
local_distribution_strategy
:
if
pindex
.
distribution_strategy
is
not
local_distribution_strategy
:
self
.
logger
.
warn
(
raise
AttributeError
(
"The distribution_strategy of pindex does not fit the "
"The distribution_strategy of pindex does not fit the "
"slice_local distribution strategy of the synthesized field."
)
"slice_local distribution strategy of the synthesized field."
)
...
@@ -764,14 +775,14 @@ class Field(Loggable, Versionable, object):
...
@@ -764,14 +775,14 @@ class Field(Loggable, Versionable, object):
dim
dim
"""
"""
if
not
hasattr
(
self
,
'_shape'
):
shape_tuple
=
tuple
(
sp
.
shape
for
sp
in
self
.
domain
)
shape_tuple
=
tuple
(
sp
.
shape
for
sp
in
self
.
domain
)
try
:
try
:
global_shape
=
reduce
(
lambda
x
,
y
:
x
+
y
,
shape_tuple
)
global_shape
=
reduce
(
lambda
x
,
y
:
x
+
y
,
shape_tuple
)
except
TypeError
:
except
TypeError
:
global_shape
=
()
global_shape
=
()
self
.
_shape
=
global_shape
return
global
_shape
return
self
.
_shape
@
property
@
property
def
dim
(
self
):
def
dim
(
self
):
...
...
nifty/library/critical_filter/critical_power_curvature.py
View file @
25c0b11c
...
@@ -21,7 +21,7 @@ class CriticalPowerCurvature(InvertibleOperatorMixin, EndomorphicOperator):
...
@@ -21,7 +21,7 @@ class CriticalPowerCurvature(InvertibleOperatorMixin, EndomorphicOperator):
# ---Overwritten properties and methods---
# ---Overwritten properties and methods---
def
__init__
(
self
,
theta
,
T
,
inverter
=
None
,
preconditioner
=
None
):
def
__init__
(
self
,
theta
,
T
,
inverter
=
None
,
preconditioner
=
None
,
**
kwargs
):
self
.
theta
=
DiagonalOperator
(
theta
.
domain
,
diagonal
=
theta
)