Skip to content
GitLab
Projects
Groups
Snippets
Help
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
N
NIFTy
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
10
Issues
10
List
Boards
Labels
Service Desk
Milestones
Merge Requests
9
Merge Requests
9
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Operations
Operations
Incidents
Environments
Packages & Registries
Packages & Registries
Container Registry
Analytics
Analytics
CI / CD
Repository
Value Stream
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
ift
NIFTy
Commits
25c0b11c
Commit
25c0b11c
authored
Jul 29, 2017
by
Theo Steininger
Browse files
Options
Browse Files
Download
Plain Diff
Merge branch 'master' into 'mpitests'
Master See merge request
!174
parents
1a0ebf7e
04af1dae
Pipeline
#15665
failed with stage
in 17 minutes and 2 seconds
Changes
47
Pipelines
1
Show whitespace changes
Inline
Side-by-side
Showing
47 changed files
with
527 additions
and
356 deletions
+527
-356
demos/critical_filtering.py
demos/critical_filtering.py
+53
-51
nifty/energies/line_energy.py
nifty/energies/line_energy.py
+36
-37
nifty/field.py
nifty/field.py
+53
-42
nifty/library/critical_filter/critical_power_curvature.py
nifty/library/critical_filter/critical_power_curvature.py
+3
-2
nifty/library/critical_filter/critical_power_energy.py
nifty/library/critical_filter/critical_power_energy.py
+16
-2
nifty/library/wiener_filter/wiener_filter_curvature.py
nifty/library/wiener_filter/wiener_filter_curvature.py
+3
-2
nifty/library/wiener_filter/wiener_filter_energy.py
nifty/library/wiener_filter/wiener_filter_energy.py
+3
-2
nifty/minimization/descent_minimizer.py
nifty/minimization/descent_minimizer.py
+6
-4
nifty/minimization/line_searching/line_search.py
nifty/minimization/line_searching/line_search.py
+1
-1
nifty/minimization/line_searching/line_search_strong_wolfe.py
...y/minimization/line_searching/line_search_strong_wolfe.py
+46
-50
nifty/operators/composed_operator/composed_operator.py
nifty/operators/composed_operator/composed_operator.py
+4
-3
nifty/operators/diagonal_operator/diagonal_operator.py
nifty/operators/diagonal_operator/diagonal_operator.py
+1
-1
nifty/operators/fft_operator/transformations/rg_transforms.py
...y/operators/fft_operator/transformations/rg_transforms.py
+6
-26
nifty/operators/invertible_operator_mixin/invertible_operator_mixin.py
...rs/invertible_operator_mixin/invertible_operator_mixin.py
+29
-13
nifty/operators/laplace_operator/laplace_operator.py
nifty/operators/laplace_operator/laplace_operator.py
+7
-1
nifty/operators/linear_operator/linear_operator.py
nifty/operators/linear_operator/linear_operator.py
+5
-2
nifty/operators/response_operator/response_operator.py
nifty/operators/response_operator/response_operator.py
+7
-10
nifty/operators/smoothing_operator/fft_smoothing_operator.py
nifty/operators/smoothing_operator/fft_smoothing_operator.py
+1
-1
nifty/operators/smoothness_operator/smoothness_operator.py
nifty/operators/smoothness_operator/smoothness_operator.py
+5
-1
nifty/plotting/descriptors/axis.py
nifty/plotting/descriptors/axis.py
+15
-6
nifty/plotting/figures/figure_2D.py
nifty/plotting/figures/figure_2D.py
+12
-10
nifty/plotting/figures/figure_3D.py
nifty/plotting/figures/figure_3D.py
+10
-2
nifty/plotting/figures/figure_base.py
nifty/plotting/figures/figure_base.py
+1
-1
nifty/plotting/plots/heatmaps/glmollweide.py
nifty/plotting/plots/heatmaps/glmollweide.py
+14
-2
nifty/plotting/plots/heatmaps/heatmap.py
nifty/plotting/plots/heatmaps/heatmap.py
+28
-4
nifty/plotting/plots/heatmaps/hpmollweide.py
nifty/plotting/plots/heatmaps/hpmollweide.py
+14
-2
nifty/plotting/plots/scatter_plots/cartesian_2D.py
nifty/plotting/plots/scatter_plots/cartesian_2D.py
+4
-0
nifty/plotting/plots/scatter_plots/cartesian_3D.py
nifty/plotting/plots/scatter_plots/cartesian_3D.py
+4
-0
nifty/plotting/plots/scatter_plots/geo.py
nifty/plotting/plots/scatter_plots/geo.py
+4
-0
nifty/plotting/plots/scatter_plots/scatter_plot.py
nifty/plotting/plots/scatter_plots/scatter_plot.py
+10
-0
nifty/plotting/plotter/gl_plotter.py
nifty/plotting/plotter/gl_plotter.py
+2
-2
nifty/plotting/plotter/healpix_plotter.py
nifty/plotting/plotter/healpix_plotter.py
+2
-2
nifty/plotting/plotter/plotter_base.py
nifty/plotting/plotter/plotter_base.py
+21
-14
nifty/plotting/plotter/power_plotter.py
nifty/plotting/plotter/power_plotter.py
+3
-3
nifty/plotting/plotter/rg1d_plotter.py
nifty/plotting/plotter/rg1d_plotter.py
+3
-3
nifty/plotting/plotter/rg2d_plotter.py
nifty/plotting/plotter/rg2d_plotter.py
+2
-2
nifty/probing/mixin_classes/diagonal_prober_mixin.py
nifty/probing/mixin_classes/diagonal_prober_mixin.py
+8
-1
nifty/probing/mixin_classes/trace_prober_mixin.py
nifty/probing/mixin_classes/trace_prober_mixin.py
+9
-1
nifty/probing/prober/prober.py
nifty/probing/prober/prober.py
+3
-1
nifty/spaces/power_space/power_space.py
nifty/spaces/power_space/power_space.py
+11
-2
nifty/spaces/rg_space/rg_space.py
nifty/spaces/rg_space/rg_space.py
+20
-17
nifty/spaces/space/space.py
nifty/spaces/space/space.py
+0
-13
nifty/sugar.py
nifty/sugar.py
+38
-7
test/test_field.py
test/test_field.py
+2
-0
test/test_minimization/test_descent_minimizers.py
test/test_minimization/test_descent_minimizers.py
+2
-1
test/test_spaces/test_lm_space.py
test/test_spaces/test_lm_space.py
+0
-4
test/test_spaces/test_rg_space.py
test/test_spaces/test_rg_space.py
+0
-5
No files found.
demos/critical_filtering.py
View file @
25c0b11c
from
nifty
import
*
from
nifty.library.wiener_filter
import
WienerFilterEnergy
import
numpy
as
np
from
nifty
import
(
VL_BFGS
,
DiagonalOperator
,
FFTOperator
,
Field
,
LinearOperator
,
PowerSpace
,
RelaxedNewton
,
RGSpace
,
SteepestDescent
,
create_power_operator
,
exp
,
log
,
sqrt
)
from
nifty.library.critical_filter
import
CriticalPowerEnergy
import
plotly.offline
as
pl
import
plotly.graph_objs
as
go
from
nifty.library.wiener_filter
import
WienerFilterEnergy
import
plotly.graph_objs
as
go
import
plotly.offline
as
pl
from
mpi4py
import
MPI
comm
=
MPI
.
COMM_WORLD
rank
=
comm
.
rank
np
.
random
.
seed
(
42
)
def
plot_parameters
(
m
,
t
,
p
,
p_d
):
def
plot_parameters
(
m
,
t
,
p
,
p_d
):
x
=
log
(
t
.
domain
[
0
].
kindex
)
m
=
fft
.
adjoint_times
(
m
)
...
...
@@ -20,7 +24,8 @@ def plot_parameters(m,t,p, p_d):
p
=
p
.
val
.
get_full_data
().
real
p_d
=
p_d
.
val
.
get_full_data
().
real
pl
.
plot
([
go
.
Heatmap
(
z
=
m
)],
filename
=
'map.html'
)
pl
.
plot
([
go
.
Scatter
(
x
=
x
,
y
=
t
),
go
.
Scatter
(
x
=
x
,
y
=
p
),
go
.
Scatter
(
x
=
x
,
y
=
p_d
)],
filename
=
"t.html"
)
pl
.
plot
([
go
.
Scatter
(
x
=
x
,
y
=
t
),
go
.
Scatter
(
x
=
x
,
y
=
p
),
go
.
Scatter
(
x
=
x
,
y
=
p_d
)],
filename
=
"t.html"
)
class
AdjointFFTResponse
(
LinearOperator
):
...
...
@@ -36,6 +41,7 @@ class AdjointFFTResponse(LinearOperator):
def
_adjoint_times
(
self
,
x
,
spaces
=
None
):
return
self
.
FFT
(
self
.
R
.
adjoint_times
(
x
))
@
property
def
domain
(
self
):
return
self
.
_domain
...
...
@@ -48,41 +54,40 @@ class AdjointFFTResponse(LinearOperator):
def
unitary
(
self
):
return
False
if
__name__
==
"__main__"
:
distribution_strategy
=
'not'
# Set up position space
s_space
=
RGSpace
([
128
,
128
])
s_space
=
RGSpace
([
128
,
128
])
# s_space = HPSpace(32)
# Define harmonic transformation and associated harmonic space
fft
=
FFTOperator
(
s_space
)
h_space
=
fft
.
target
[
0
]
# Set
ting
up power space
# Set up power space
p_space
=
PowerSpace
(
h_space
,
logarithmic
=
True
,
distribution_strategy
=
distribution_strategy
)
# Choos
ing
the prior correlation structure and defining correlation operator
# Choos
e
the prior correlation structure and defining correlation operator
p_spec
=
(
lambda
k
:
(.
5
/
(
k
+
1
)
**
3
))
S
=
create_power_operator
(
h_space
,
power_spectrum
=
p_spec
,
distribution_strategy
=
distribution_strategy
)
# Draw
ing
a sample sh from the prior distribution in harmonic space
# Draw a sample sh from the prior distribution in harmonic space
sp
=
Field
(
p_space
,
val
=
p_spec
,
distribution_strategy
=
distribution_strategy
)
sh
=
sp
.
power_synthesize
(
real_signal
=
True
)
# Choosing the measurement instrument
# Choose the measurement instrument
# Instrument = SmoothingOperator(s_space, sigma=0.01)
Instrument
=
DiagonalOperator
(
s_space
,
diagonal
=
1.
)
# Instrument._diagonal.val[200:400, 200:400] = 0
#Instrument._diagonal.val[64:512-64, 64:512-64] = 0
#
Instrument._diagonal.val[64:512-64, 64:512-64] = 0
#Adding a harmonic transformation to the instrument
# Add a harmonic transformation to the instrument
R
=
AdjointFFTResponse
(
fft
,
Instrument
)
noise
=
1.
...
...
@@ -92,7 +97,7 @@ if __name__ == "__main__":
std
=
sqrt
(
noise
),
mean
=
0
)
# Creat
ing th
e mock data
# Create mock data
d
=
R
(
sh
)
+
n
# The information source
...
...
@@ -103,52 +108,49 @@ if __name__ == "__main__":
if
rank
==
0
:
pl
.
plot
([
go
.
Heatmap
(
z
=
d_data
)],
filename
=
'data.html'
)
# minimization strategy
# Minimization strategy
def
convergence_measure
(
a_energy
,
iteration
):
# returns current energy
x
=
a_energy
.
value
print
(
x
,
iteration
)
print
(
x
,
iteration
)
minimizer1
=
RelaxedNewton
(
convergence_tolerance
=
1e-
2
,
convergence_level
=
2
,
iteration_limit
=
3
,
minimizer1
=
RelaxedNewton
(
convergence_tolerance
=
1e-
4
,
convergence_level
=
1
,
iteration_limit
=
5
,
callback
=
convergence_measure
)
minimizer2
=
VL_BFGS
(
convergence_tolerance
=
0
,
iteration_limit
=
7
,
minimizer2
=
VL_BFGS
(
convergence_tolerance
=
1e-4
,
convergence_level
=
1
,
iteration_limit
=
20
,
callback
=
convergence_measure
,
max_history_length
=
3
)
max_history_length
=
20
)
minimizer3
=
SteepestDescent
(
convergence_tolerance
=
1e-4
,
iteration_limit
=
100
,
callback
=
convergence_measure
)
# Set
ting
starting position
flat_power
=
Field
(
p_space
,
val
=
1e-8
)
# Set starting position
flat_power
=
Field
(
p_space
,
val
=
1e-8
)
m0
=
flat_power
.
power_synthesize
(
real_signal
=
True
)
t0
=
Field
(
p_space
,
val
=
log
(
1.
/
(
1
+
p_space
.
kindex
)
**
2
))
for
i
in
range
(
500
):
S0
=
create_power_operator
(
h_space
,
power_spectrum
=
exp
(
t0
),
distribution_strategy
=
distribution_strategy
)
# Initializ
ing the non
linear Wiener Filter energy
# Initializ
e non-
linear Wiener Filter energy
map_energy
=
WienerFilterEnergy
(
position
=
m0
,
d
=
d
,
R
=
R
,
N
=
N
,
S
=
S0
)
# Solv
ing
the Wiener Filter analytically
# Solv
e
the Wiener Filter analytically
D0
=
map_energy
.
curvature
m0
=
D0
.
inverse_times
(
j
)
# Initializing the power energy with updated parameters
power_energy
=
CriticalPowerEnergy
(
position
=
t0
,
m
=
m0
,
D
=
D0
,
smoothness_prior
=
10.
,
samples
=
3
)
(
power_energy
,
convergence
)
=
minimizer1
(
power_energy
)
# Initialize power energy with updated parameters
power_energy
=
CriticalPowerEnergy
(
position
=
t0
,
m
=
m0
,
D
=
D0
,
smoothness_prior
=
10.
,
samples
=
3
)
(
power_energy
,
convergence
)
=
minimizer2
(
power_energy
)
# Set
ting
new power spectrum
# Set new power spectrum
t0
.
val
=
power_energy
.
position
.
val
.
real
# Plotting current estimate
print
i
if
i
%
50
==
0
:
plot_parameters
(
m0
,
t0
,
log
(
sp
),
data_power
)
# Plot current estimate
print
(
i
)
if
i
%
5
==
0
:
plot_parameters
(
m0
,
t0
,
log
(
sp
),
data_power
)
nifty/energies/line_energy.py
View file @
25c0b11c
...
...
@@ -16,10 +16,8 @@
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes.
from
.energy
import
Energy
class
LineEnergy
(
Energy
):
class
LineEnergy
(
object
):
""" Evaluates an underlying Energy along a certain line direction.
Given an Energy class and a line direction, its position is parametrized by
...
...
@@ -27,34 +25,31 @@ class LineEnergy(Energy):
Parameters
----------
position : float
The step length parameter along the given line direction.
line_position : float
Defines the full spatial position of this energy via
self.energy.position = zero_point + line_position*line_direction
energy : Energy
The Energy object which will be evaluated along the given direction.
line_direction : Field
Direction used for line evaluation.
zero_point : Field
*optional*
Fixing the zero point of the line restriction. Used to memorize this
position in new initializations. By the default the current posi
tion
of the supplied `energy` instance is used (default : None
).
Direction used for line evaluation.
Does not have to be normalized.
offset : float
*optional*
Indirectly defines the zero point of the line via the equation
energy.position = zero_point + offset*line_direc
tion
(default : 0.
).
Attributes
----------
position : float
line_
position : float
The position along the given line direction relative to the zero point.
value : float
The value of the energy functional at given `position`.
gradient : float
The gradient of the underlying energy instance along the line direction
projected on the line direction.
curvature : callable
A positive semi-definite operator or function describing the curvature
of the potential at given `position`.
The value of the energy functional at the given position
directional_derivative : float
The directional derivative at the given position
line_direction : Field
Direction along which the movement is restricted. Does not have to be
normalized.
energy : Energy
The underlying Energy at the
`position` along the line direction.
The underlying Energy at the
given position
Raises
------
...
...
@@ -72,45 +67,49 @@ class LineEnergy(Energy):
"""
def
__init__
(
self
,
position
,
energy
,
line_direction
,
zero_point
=
None
):
super
(
LineEnergy
,
self
).
__init__
(
position
=
position
)
self
.
line_direction
=
line_direction
if
zero_point
is
None
:
zero_point
=
energy
.
position
self
.
_zero_point
=
zero_point
def
__init__
(
self
,
line_position
,
energy
,
line_direction
,
offset
=
0.
):
self
.
_line_position
=
float
(
line_position
)
self
.
_line_direction
=
line_direction
position_on_line
=
self
.
_zero_point
+
self
.
position
*
line_direction
self
.
energy
=
energy
.
at
(
position
=
position_on_line
)
pos
=
energy
.
position
\
+
(
self
.
_line_position
-
float
(
offset
))
*
self
.
_line_direction
self
.
energy
=
energy
.
at
(
position
=
pos
)
def
at
(
self
,
position
):
def
at
(
self
,
line_
position
):
""" Returns LineEnergy at new position, memorizing the zero point.
Parameters
----------
position : float
line_
position : float
Parameter for the new position on the line direction.
Returns
-------
out : LineEnergy
LineEnergy object at new position with same zero point as `self`.
"""
return
self
.
__class__
(
position
,
return
self
.
__class__
(
line_
position
,
self
.
energy
,
self
.
line_direction
,
zero_point
=
self
.
_zero_point
)
offset
=
self
.
line_position
)
@
property
def
value
(
self
):
return
self
.
energy
.
value
@
property
def
gradient
(
self
):
return
self
.
energy
.
gradient
.
vdot
(
self
.
line_direction
)
def
line_position
(
self
):
return
self
.
_line_position
@
property
def
line_direction
(
self
):
return
self
.
_line_direction
@
property
def
curvature
(
self
):
return
self
.
energy
.
curvature
def
directional_derivative
(
self
):
res
=
self
.
energy
.
gradient
.
vdot
(
self
.
line_direction
)
if
abs
(
res
.
imag
)
/
max
(
abs
(
res
.
real
),
1.
)
>
1e-12
:
print
"directional derivative has non-negligible "
\
"imaginary part:"
,
res
return
res
.
real
nifty/field.py
View file @
25c0b11c
...
...
@@ -112,7 +112,6 @@ class Field(Loggable, Versionable, object):
def
__init__
(
self
,
domain
=
None
,
val
=
None
,
dtype
=
None
,
distribution_strategy
=
None
,
copy
=
False
):
self
.
domain
=
self
.
_parse_domain
(
domain
=
domain
,
val
=
val
)
self
.
domain_axes
=
self
.
_get_axes_tuple
(
self
.
domain
)
...
...
@@ -128,6 +127,7 @@ class Field(Loggable, Versionable, object):
else
:
self
.
set_val
(
new_val
=
val
,
copy
=
copy
)
def
_parse_domain
(
self
,
domain
,
val
=
None
):
if
domain
is
None
:
if
isinstance
(
val
,
Field
):
...
...
@@ -466,7 +466,7 @@ class Field(Loggable, Versionable, object):
return
result_obj
def
power_synthesize
(
self
,
spaces
=
None
,
real_power
=
True
,
real_signal
=
True
,
mean
=
None
,
std
=
None
):
mean
=
None
,
std
=
None
,
distribution_strategy
=
None
):
""" Yields a sampled field with `self`**2 as its power spectrum.
This method draws a Gaussian random field in the harmonic partner
...
...
@@ -541,13 +541,16 @@ class Field(Loggable, Versionable, object):
else
:
result_list
=
[
None
,
None
]
if
distribution_strategy
is
None
:
distribution_strategy
=
gc
[
'default_distribution_strategy'
]
result_list
=
[
self
.
__class__
.
from_random
(
'normal'
,
mean
=
mean
,
std
=
std
,
domain
=
result_domain
,
dtype
=
np
.
complex
,
distribution_strategy
=
self
.
distribution_strategy
)
distribution_strategy
=
distribution_strategy
)
for
x
in
result_list
]
# from now on extract the values from the random fields for further
...
...
@@ -609,39 +612,47 @@ class Field(Loggable, Versionable, object):
# correct variance
if
preserve_gaussian_variance
:
assert
issubclass
(
val
.
dtype
.
type
,
np
.
complexfloating
),
\
"complex input field is needed here"
h
*=
np
.
sqrt
(
2
)
a
*=
np
.
sqrt
(
2
)
if
not
issubclass
(
val
.
dtype
.
type
,
np
.
complexfloating
):
# in principle one must not correct the variance for the fixed
# points of the hermitianization. However, for a complex field
# the input field loses half of its power at its fixed points
# in the `hermitian` part. Hence, here a factor of sqrt(2) is
# also necessary!
# => The hermitianization can be done on a space level since
# either nothing must be done (LMSpace) or ALL points need a
# factor of sqrt(2)
# => use the preserve_gaussian_variance flag in the
# hermitian_decomposition method above.
# This code is for educational purposes:
fixed_points
=
[
domain
[
i
].
hermitian_fixed_points
()
for
i
in
spaces
]
fixed_points
=
[[
fp
]
if
fp
is
None
else
fp
for
fp
in
fixed_points
]
for
product_point
in
itertools
.
product
(
*
fixed_points
):
slice_object
=
np
.
array
((
slice
(
None
),
)
*
len
(
val
.
shape
),
dtype
=
np
.
object
)
for
i
,
sp
in
enumerate
(
spaces
):
point_component
=
product_point
[
i
]
if
point_component
is
None
:
point_component
=
slice
(
None
)
slice_object
[
list
(
domain_axes
[
sp
])]
=
point_component
slice_object
=
tuple
(
slice_object
)
h
[
slice_object
]
/=
np
.
sqrt
(
2
)
a
[
slice_object
]
/=
np
.
sqrt
(
2
)
# The code below should not be needed in practice, since it would
# only ever be called when hermitianizing a purely real field.
# However it might be of educational use and keep us from forgetting
# how these things are done ...
# if not issubclass(val.dtype.type, np.complexfloating):
# # in principle one must not correct the variance for the fixed
# # points of the hermitianization. However, for a complex field
# # the input field loses half of its power at its fixed points
# # in the `hermitian` part. Hence, here a factor of sqrt(2) is
# # also necessary!
# # => The hermitianization can be done on a space level since
# # either nothing must be done (LMSpace) or ALL points need a
# # factor of sqrt(2)
# # => use the preserve_gaussian_variance flag in the
# # hermitian_decomposition method above.
#
# # This code is for educational purposes:
# fixed_points = [domain[i].hermitian_fixed_points()
# for i in spaces]
# fixed_points = [[fp] if fp is None else fp
# for fp in fixed_points]
#
# for product_point in itertools.product(*fixed_points):
# slice_object = np.array((slice(None), )*len(val.shape),
# dtype=np.object)
# for i, sp in enumerate(spaces):
# point_component = product_point[i]
# if point_component is None:
# point_component = slice(None)
# slice_object[list(domain_axes[sp])] = point_component
#
# slice_object = tuple(slice_object)
# h[slice_object] /= np.sqrt(2)
# a[slice_object] /= np.sqrt(2)
return
(
h
,
a
)
def
_spec_to_rescaler
(
self
,
spec
,
result_list
,
power_space_index
):
...
...
@@ -657,7 +668,7 @@ class Field(Loggable, Versionable, object):
result_list
[
0
].
domain_axes
[
power_space_index
])
if
pindex
.
distribution_strategy
is
not
local_distribution_strategy
:
self
.
logger
.
warn
(
raise
AttributeError
(
"The distribution_strategy of pindex does not fit the "
"slice_local distribution strategy of the synthesized field."
)
...
...
@@ -764,14 +775,14 @@ class Field(Loggable, Versionable, object):
dim
"""
if
not
hasattr
(
self
,
'_shape'
):
shape_tuple
=
tuple
(
sp
.
shape
for
sp
in
self
.
domain
)
try
:
global_shape
=
reduce
(
lambda
x
,
y
:
x
+
y
,
shape_tuple
)
except
TypeError
:
global_shape
=
()
return
global
_shape
self
.
_shape
=
global_shape
return
self
.
_shape
@
property
def
dim
(
self
):
...
...
nifty/library/critical_filter/critical_power_curvature.py
View file @
25c0b11c
...
...
@@ -21,7 +21,7 @@ class CriticalPowerCurvature(InvertibleOperatorMixin, EndomorphicOperator):
# ---Overwritten properties and methods---
def
__init__
(
self
,
theta
,
T
,
inverter
=
None
,
preconditioner
=
None
):
def
__init__
(
self
,
theta
,
T
,
inverter
=
None
,
preconditioner
=
None
,
**
kwargs
):
self
.
theta
=
DiagonalOperator
(
theta
.
domain
,
diagonal
=
theta
)
self
.
T
=
T
...
...
@@ -30,7 +30,8 @@ class CriticalPowerCurvature(InvertibleOperatorMixin, EndomorphicOperator):
self
.
_domain
=
self
.
theta
.
domain
super
(
CriticalPowerCurvature
,
self
).
__init__
(
inverter
=
inverter
,
preconditioner
=
preconditioner
)
preconditioner
=
preconditioner
,
**
kwargs
)
def
_times
(
self
,
x
,
spaces
):
return
self
.
T
(
x
)
+
self
.
theta
(
x
)
...
...
nifty/library/critical_filter/critical_power_energy.py
View file @
25c0b11c
...
...
@@ -53,24 +53,28 @@ class CriticalPowerEnergy(Energy):
default : None
"""
# ---Overwritten properties and methods---
def
__init__
(
self
,
position
,
m
,
D
=
None
,
alpha
=
1.0
,
q
=
0.
,
smoothness_prior
=
0.
,
logarithmic
=
True
,
samples
=
3
,
w
=
None
):
super
(
CriticalPowerEnergy
,
self
).
__init__
(
position
=
position
)
self
.
m
=
m
self
.
D
=
D
self
.
samples
=
samples
self
.
smoothness_prior
=
np
.
float
(
smoothness_prior
)
self
.
alpha
=
Field
(
self
.
position
.
domain
,
val
=
alpha
)
self
.
q
=
Field
(
self
.
position
.
domain
,
val
=
q
)
self
.
T
=
SmoothnessOperator
(
domain
=
self
.
position
.
domain
[
0
],
strength
=
s
elf
.
s
moothness_prior
,
strength
=
smoothness_prior
,
logarithmic
=
logarithmic
)
self
.
rho
=
self
.
position
.
domain
[
0
].
rho
self
.
_w
=
w
if
w
is
not
None
else
None
# ---Mandatory properties and methods---
def
at
(
self
,
position
):
return
self
.
__class__
(
position
,
self
.
m
,
D
=
self
.
D
,
alpha
=
self
.
alpha
,
q
=
self
.
q
,
smoothness_prior
=
self
.
smoothness_prior
,
logarithmic
=
self
.
logarithmic
,
w
=
self
.
w
,
samples
=
self
.
samples
)
@
property
...
...
@@ -94,6 +98,16 @@ class CriticalPowerEnergy(Energy):
T
=
self
.
T
)
return
curvature
# ---Added properties and methods---
@
property
def
logarithmic
(
self
):
return
self
.
T
.
logarithmic
@
property
def
smoothness_prior
(
self
):
return
self
.
T
.
strength
@
property
def
w
(
self
):
if
self
.
_w
is
None
:
...
...
nifty/library/wiener_filter/wiener_filter_curvature.py
View file @
25c0b11c
...
...
@@ -22,7 +22,7 @@ class WienerFilterCurvature(InvertibleOperatorMixin, EndomorphicOperator):
"""
def
__init__
(
self
,
R
,
N
,
S
,
inverter
=
None
,
preconditioner
=
None
):
def
__init__
(
self
,
R
,
N
,
S
,
inverter
=
None
,
preconditioner
=
None
,
**
kwargs
):
self
.
R
=
R
self
.
N
=
N
...
...
@@ -32,7 +32,8 @@ class WienerFilterCurvature(InvertibleOperatorMixin, EndomorphicOperator):
self
.
_domain
=
self
.
S
.
domain
super
(
WienerFilterCurvature
,
self
).
__init__
(
inverter
=
inverter
,
preconditioner
=
preconditioner
)
preconditioner
=
preconditioner
,
**
kwargs
)
@
property
def
domain
(
self
):
...
...
nifty/library/wiener_filter/wiener_filter_energy.py
View file @
25c0b11c
...
...
@@ -23,7 +23,7 @@ class WienerFilterEnergy(Energy):
The prior signal covariance in harmonic space.
"""
def
__init__
(
self
,
position
,
d
,
R
,
N
,
S
,
inverter
=
None
):
def
__init__
(
self
,
position
,
d
,
R
,
N
,
S
):
super
(
WienerFilterEnergy
,
self
).
__init__
(
position
=
position
)
self
.
d
=
d
self
.
R
=
R
...
...
@@ -32,7 +32,7 @@ class WienerFilterEnergy(Energy):
def
at
(
self
,
position
):
return
self
.
__class__
(
position
=
position
,
d
=
self
.
d
,
R
=
self
.
R
,
N
=
self
.
N
,
S
=
self
.
S
,
inverter
=
self
.
inverter
)
S
=
self
.
S
)
@
property
@
memo
...
...
@@ -49,6 +49,7 @@ class WienerFilterEnergy(Energy):
def
curvature
(
self
):
return
WienerFilterCurvature
(
R
=
self
.
R
,
N
=
self
.
N
,
S
=
self
.
S
)
@
property
@
memo
def
_Dx
(
self
):
return
self
.
curvature
(
self
.
position
)
...
...
nifty/minimization/descent_minimizer.py
View file @
25c0b11c
...
...
@@ -137,7 +137,7 @@ class DescentMinimizer(Loggable, object):
# compute the the gradient for the current location
gradient
=
energy
.
gradient
gradient_norm
=
gradient
.
vdot
(
gradient
)
gradient_norm
=
gradient
.
norm
(
)
# check if position is at a flat point
if
gradient_norm
==
0
:
...
...
@@ -147,7 +147,6 @@ class DescentMinimizer(Loggable, object):
# current position is encoded in energy object
descent_direction
=
self
.
get_descent_direction
(
energy
)
# compute the step length, which minimizes energy.value along the
# search direction
step_length
,
f_k
,
new_energy
=
\
...
...
@@ -155,8 +154,12 @@ class DescentMinimizer(Loggable, object):
energy
=
energy
,
pk
=
descent_direction
,
f_k_minus_1
=
f_k_minus_1
)
if
f_k_minus_1
is
None
:
delta
=
1e30
else
:
delta
=
abs
(
f_k
-
f_k_minus_1
)
/
max
(
abs
(
f_k
),
abs
(
f_k_minus_1
),
1.
)
f_k_minus_1
=
energy
.
value
tx1
=
energy
.
position
-
new_energy
.
position
# check if new energy value is bigger than old energy value
if
(
new_energy
.
value
-
energy
.
value
)
>
0
:
self
.
logger
.
info
(
"Line search algorithm returned a new energy "
...
...
@@ -165,7 +168,6 @@ class DescentMinimizer(Loggable, object):
energy
=
new_energy
# check convergence
delta
=
abs
(
gradient
).
max
()
*
(
step_length
/
np
.
sqrt
(
gradient_norm
))