Skip to content
GitLab
Menu
Projects
Groups
Snippets
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
ift
NIFTy
Commits
de074086
Commit
de074086
authored
Nov 02, 2016
by
theos
Browse files
Improved interface of Prober class.
Fixed a few small bugs.
parent
1c53be26
Changes
19
Show whitespace changes
Inline
Side-by-side
demos/wiener_filter.py
View file @
de074086
from
nifty
import
*
import
plotly.offline
as
pl
import
plotly.graph_objs
as
go
#
import plotly.offline as pl
#
import plotly.graph_objs as go
from
mpi4py
import
MPI
comm
=
MPI
.
COMM_WORLD
...
...
@@ -12,7 +12,7 @@ if __name__ == "__main__":
distribution_strategy
=
'fftw'
s_space
=
RGSpace
([
512
,
512
],
dtype
=
np
.
complex128
)
s_space
=
RGSpace
([
512
,
512
],
dtype
=
np
.
float64
)
fft
=
FFTOperator
(
s_space
)
h_space
=
fft
.
target
[
0
]
p_space
=
PowerSpace
(
h_space
,
distribution_strategy
=
distribution_strategy
)
...
...
@@ -46,8 +46,8 @@ if __name__ == "__main__":
m_data
=
m
.
val
.
get_full_data
().
real
ss_data
=
ss
.
val
.
get_full_data
().
real
if
rank
==
0
:
pl
.
plot
([
go
.
Heatmap
(
z
=
d_data
)],
filename
=
'data.html'
)
pl
.
plot
([
go
.
Heatmap
(
z
=
m_data
)],
filename
=
'map.html'
)
pl
.
plot
([
go
.
Heatmap
(
z
=
ss_data
)],
filename
=
'map_orig.html'
)
#
if rank == 0:
#
pl.plot([go.Heatmap(z=d_data)], filename='data.html')
#
pl.plot([go.Heatmap(z=m_data)], filename='map.html')
#
pl.plot([go.Heatmap(z=ss_data)], filename='map_orig.html')
#
nifty/field.py
View file @
de074086
...
...
@@ -17,7 +17,7 @@ from nifty.random import Random
from
keepers
import
Loggable
class
Field
(
object
,
Loggable
):
class
Field
(
Loggable
,
object
):
# ---Initialization methods---
def
__init__
(
self
,
domain
=
None
,
val
=
None
,
dtype
=
None
,
field_type
=
None
,
...
...
@@ -430,7 +430,7 @@ class Field(object, Loggable):
if
copy
:
new_val
=
new_val
.
copy
()
self
.
_val
=
new_val
return
self
.
_val
return
self
def
get_val
(
self
,
copy
=
False
):
if
copy
:
...
...
nifty/minimization/conjugate_gradient.py
View file @
de074086
...
...
@@ -7,7 +7,7 @@ import numpy as np
from
keepers
import
Loggable
class
ConjugateGradient
(
object
,
Loggable
):
class
ConjugateGradient
(
Loggable
,
object
):
def
__init__
(
self
,
convergence_tolerance
=
1E-4
,
convergence_level
=
3
,
iteration_limit
=
None
,
reset_count
=
None
,
preconditioner
=
None
,
callback
=
None
):
...
...
nifty/minimization/line_searching/line_search.py
View file @
de074086
...
...
@@ -5,7 +5,7 @@ from keepers import Loggable
from
nifty
import
LineEnergy
class
LineSearch
(
object
,
Loggable
):
class
LineSearch
(
Loggable
,
object
):
"""
Class for finding a step size.
"""
...
...
nifty/minimization/quasi_newton_minimizer.py
View file @
de074086
...
...
@@ -9,7 +9,7 @@ from keepers import Loggable
from
.line_searching
import
LineSearchStrongWolfe
class
QuasiNewtonMinimizer
(
object
,
Loggable
):
class
QuasiNewtonMinimizer
(
Loggable
,
object
):
__metaclass__
=
abc
.
ABCMeta
def
__init__
(
self
,
line_searcher
=
LineSearchStrongWolfe
(),
callback
=
None
,
...
...
nifty/operators/endomorphic_operator/endomorphic_operator.py
View file @
de074086
...
...
@@ -60,27 +60,3 @@ class EndomorphicOperator(LinearOperator):
@
abc
.
abstractproperty
def
symmetric
(
self
):
raise
NotImplementedError
def
trace
(
self
):
pass
def
inverse_trace
(
self
):
pass
def
diagonal
(
self
):
pass
def
inverse_diagonal
(
self
):
pass
def
determinant
(
self
):
pass
def
inverse_determinant
(
self
):
pass
def
log_determinant
(
self
):
pass
def
trace_log
(
self
):
pass
nifty/operators/fft_operator/fft_operator.py
View file @
de074086
...
...
@@ -95,7 +95,7 @@ class FFTOperator(LinearOperator):
result_domain
[
spaces
[
0
]]
=
self
.
target
[
0
]
result_field
=
x
.
copy_empty
(
domain
=
result_domain
)
result_field
.
set_val
(
new_val
=
new_val
)
result_field
.
set_val
(
new_val
=
new_val
,
copy
=
False
)
return
result_field
...
...
@@ -118,7 +118,7 @@ class FFTOperator(LinearOperator):
result_domain
[
spaces
[
0
]]
=
self
.
domain
[
0
]
result_field
=
x
.
copy_empty
(
domain
=
result_domain
)
result_field
.
set_val
(
new_val
=
new_val
)
result_field
.
set_val
(
new_val
=
new_val
,
copy
=
False
)
return
result_field
...
...
nifty/operators/fft_operator/transformations/rg_transforms.py
View file @
de074086
...
...
@@ -10,7 +10,7 @@ from keepers import Loggable
pyfftw
=
gdi
.
get
(
'pyfftw'
)
class
Transform
(
object
,
Loggable
):
class
Transform
(
Loggable
,
object
):
"""
A generic fft object without any implementation.
"""
...
...
nifty/operators/fft_operator/transformations/transformation.py
View file @
de074086
...
...
@@ -4,7 +4,7 @@ import abc
from
keepers
import
Loggable
class
Transformation
(
object
,
Loggable
):
class
Transformation
(
Loggable
,
object
):
"""
A generic transformation which defines a static check_codomain
method for all transforms.
...
...
nifty/operators/linear_operator/linear_operator.py
View file @
de074086
...
...
@@ -9,7 +9,7 @@ from nifty.field_types import FieldType
import
nifty.nifty_utilities
as
utilities
class
LinearOperator
(
object
,
Loggable
):
class
LinearOperator
(
Loggable
,
object
):
__metaclass__
=
abc
.
ABCMeta
def
__init__
(
self
):
...
...
nifty/operators/propagator_operator/propagator_operator.py
View file @
de074086
# -*- coding: utf-8 -*-
import
numpy
as
np
from
nifty.minimization
import
ConjugateGradient
from
nifty.nifty_utilities
import
get_default_codomain
from
nifty.field
import
Field
from
nifty.operators
import
EndomorphicOperator
,
\
FFTOperator
...
...
@@ -45,10 +43,8 @@ class PropagatorOperator(EndomorphicOperator):
self
.
_domain
=
N
.
domain
self
.
_likelihood_times
=
lambda
z
:
N
.
inverse_times
(
z
)
fft_S
=
FFTOperator
(
S
.
domain
,
target
=
self
.
_domain
)
self
.
_S_times
=
lambda
z
:
fft_S
(
S
(
fft_S
.
inverse_times
(
z
)))
self
.
_S_inverse_times
=
lambda
z
:
fft_S
(
S
.
inverse_times
(
fft_S
.
inverse_times
(
z
)))
self
.
_S
=
S
self
.
_fft_S
=
FFTOperator
(
self
.
_domain
,
target
=
self
.
_S
.
domain
)
if
preconditioner
is
None
:
preconditioner
=
self
.
_S_times
...
...
@@ -61,8 +57,6 @@ class PropagatorOperator(EndomorphicOperator):
self
.
inverter
=
ConjugateGradient
(
preconditioner
=
self
.
preconditioner
)
self
.
x0
=
None
# ---Mandatory properties and methods---
@
property
...
...
@@ -87,18 +81,44 @@ class PropagatorOperator(EndomorphicOperator):
# ---Added properties and methods---
def
_times
(
self
,
x
,
spaces
,
types
):
if
self
.
x0
is
None
:
x0
=
Field
(
self
.
domain
,
val
=
0.
,
dtype
=
np
.
complex128
)
else
:
x0
=
self
.
x0
def
_S_times
(
self
,
x
,
spaces
=
None
,
types
=
None
):
transformed_x
=
self
.
_fft_S
(
x
,
spaces
=
spaces
,
types
=
types
)
y
=
self
.
_S
(
transformed_x
,
spaces
=
spaces
,
types
=
types
)
transformed_y
=
self
.
_fft_S
.
inverse_times
(
y
,
spaces
=
spaces
,
types
=
types
)
result
=
x
.
copy_empty
()
result
.
set_val
(
transformed_y
,
copy
=
False
)
return
result
def
_S_inverse_times
(
self
,
x
,
spaces
=
None
,
types
=
None
):
transformed_x
=
self
.
_fft_S
(
x
,
spaces
=
spaces
,
types
=
types
)
y
=
self
.
_S
.
inverse_times
(
transformed_x
,
spaces
=
spaces
,
types
=
types
)
transformed_y
=
self
.
_fft_S
.
inverse_times
(
y
,
spaces
=
spaces
,
types
=
types
)
result
=
x
.
copy_empty
()
result
.
set_val
(
transformed_y
,
copy
=
False
)
return
result
def
_times
(
self
,
x
,
spaces
,
types
,
x0
=
None
):
if
x0
is
None
:
x0
=
Field
(
self
.
domain
,
val
=
0.
,
dtype
=
x
.
dtype
)
(
result
,
convergence
)
=
self
.
inverter
(
A
=
self
.
inverse_times
,
b
=
x
,
x0
=
x0
)
self
.
x0
=
result
return
result
def
_inverse_times
(
self
,
x
,
spaces
,
types
):
result
=
self
.
_S_inverse_times
(
x
)
result
+=
self
.
_likelihood_times
(
x
)
pre_result
=
self
.
_S_inverse_times
(
x
,
spaces
,
types
)
pre_result
+=
self
.
_likelihood_times
(
x
)
result
=
x
.
copy_empty
()
result
.
set_val
(
pre_result
,
copy
=
False
)
return
result
nifty/operators/smoothing_operator/smoothing_operator.py
View file @
de074086
...
...
@@ -54,7 +54,7 @@ class SmoothingOperator(EndomorphicOperator):
@
property
def
symmetric
(
self
):
return
Fals
e
return
Tru
e
@
property
def
unitary
(
self
):
...
...
@@ -138,7 +138,10 @@ class SmoothingOperator(EndomorphicOperator):
transformed_x
.
val
.
set_local_data
(
local_transformed_x
,
copy
=
False
)
result
=
Transformator
.
inverse_times
(
transformed_x
,
spaces
=
spaces
)
smoothed_x
=
Transformator
.
inverse_times
(
transformed_x
,
spaces
=
spaces
)
result
=
x
.
copy_empty
()
result
.
set_val
(
smoothed_x
,
copy
=
False
)
return
result
...
...
nifty/probing/diagonal_prober.py
View file @
de074086
# -*- coding: utf-8 -*-
from
nifty.operators
import
EndomorphicOperato
r
from
prober
import
Probe
r
from
operator_prober
import
OperatorProber
__all__
=
[
'DiagonalProber'
,
'InverseDiagonalProber'
,
'AdjointDiagonalProber'
,
'AdjointInverseDiagonalProber'
]
class
DiagonalTypeProber
(
OperatorProber
):
class
DiagonalProber
(
Prober
):
# ---Mandatory properties and methods---
# --- ->Mandatory from OperatorProber---
@
property
def
valid_operator_class
(
self
):
return
EndomorphicOperator
# --- ->Mandatory from Prober---
def
finish_probe
(
self
,
probe
,
pre_result
):
return
probe
[
1
].
conjugate
()
*
pre_result
class
DiagonalProber
(
DiagonalTypeProber
):
# ---Mandatory properties and methods---
# --- ->Mandatory from OperatorProber---
@
property
def
is_inverse
(
self
):
return
False
# --- ->Mandatory from Prober---
def
evaluate_probe
(
self
,
probe
):
""" processes a probe """
return
self
.
operator
.
times
(
probe
[
1
])
class
InverseDiagonalProber
(
DiagonalTypeProber
):
# ---Mandatory properties and methods---
# --- ->Mandatory from OperatorProber---
@
property
def
is_inverse
(
self
):
return
True
# --- ->Mandatory from Prober---
def
evaluate_probe
(
self
,
probe
):
""" processes a probe """
return
self
.
operator
.
inverse_times
(
probe
[
1
])
class
AdjointDiagonalProber
(
DiagonalTypeProber
):
# ---Mandatory properties and methods---
# --- ->Mandatory from OperatorProber---
@
property
def
is_inverse
(
self
):
return
False
# --- ->Mandatory from Prober---
def
evaluate_probe
(
self
,
probe
):
""" processes a probe """
return
self
.
operator
.
adjoint_times
(
probe
[
1
])
class
AdjointInverseDiagonalProber
(
DiagonalTypeProber
):
# ---Mandatory properties and methods---
# --- ->Mandatory from OperatorProber---
@
property
def
is_inverse
(
self
):
return
True
# --- ->Mandatory from Prober---
def
evaluate_probe
(
self
,
probe
):
""" processes a probe """
return
self
.
operator
.
adjoint_inverse_times
(
probe
[
1
])
nifty/probing/operator_prober.py
deleted
100644 → 0
View file @
1c53be26
# -*- coding: utf-8 -*-
import
abc
from
prober
import
Prober
class
OperatorProber
(
Prober
):
# ---Overwritten properties and methods---
def
__init__
(
self
,
operator
,
probe_count
=
8
,
random_type
=
'pm1'
,
distribution_strategy
=
None
,
compute_variance
=
False
):
super
(
OperatorProber
,
self
).
__init__
(
probe_count
=
probe_count
,
random_type
=
random_type
,
compute_variance
=
compute_variance
)
if
not
isinstance
(
operator
,
self
.
valid_operator_class
):
raise
TypeError
(
"Operator must be an instance of %s"
%
str
(
self
.
valid_operator_class
))
self
.
_operator
=
operator
# ---Mandatory properties and methods---
@
property
def
domain
(
self
):
if
self
.
is_inverse
:
return
self
.
operator
.
target
else
:
return
self
.
operator
.
domain
@
property
def
field_type
(
self
):
if
self
.
is_inverse
:
return
self
.
operator
.
field_type_target
else
:
return
self
.
operator
.
field_type
@
property
def
distribution_strategy
(
self
):
return
self
.
operator
.
distribution_strategy
# ---Added properties and methods---
@
abc
.
abstractproperty
def
is_inverse
(
self
):
raise
NotImplementedError
@
abc
.
abstractproperty
def
valid_operator_class
(
self
):
raise
NotImplementedError
@
property
def
operator
(
self
):
return
self
.
_operator
nifty/probing/prober.py
View file @
de074086
...
...
@@ -4,6 +4,8 @@ import abc
import
numpy
as
np
from
nifty.field_types
import
FieldType
from
nifty.spaces
import
Space
from
nifty.field
import
Field
from
d2o
import
STRATEGIES
as
DISTRIBUTION_STRATEGIES
...
...
@@ -12,28 +14,67 @@ from d2o import STRATEGIES as DISTRIBUTION_STRATEGIES
class
Prober
(
object
):
__metaclass__
=
abc
.
ABCMeta
def
__init__
(
self
,
probe_count
=
8
,
random_type
=
'pm1'
,
compute_variance
=
False
):
def
__init__
(
self
,
domain
=
None
,
field_type
=
None
,
distribution_strategy
=
None
,
probe_count
=
8
,
random_type
=
'pm1'
,
compute_variance
=
False
):
self
.
domain
=
domain
self
.
field_type
=
field_type
self
.
distribution_strategy
=
distribution_strategy
self
.
probe_count
=
probe_count
self
.
random_type
=
random_type
self
.
compute_variance
=
bool
(
compute_variance
)
def
_parse_domain
(
self
,
domain
):
if
domain
is
None
:
domain
=
()
elif
isinstance
(
domain
,
Space
):
domain
=
(
domain
,)
elif
not
isinstance
(
domain
,
tuple
):
domain
=
tuple
(
domain
)
for
d
in
domain
:
if
not
isinstance
(
d
,
Space
):
raise
TypeError
(
"Given object contains something that is not a "
"nifty.space."
)
return
domain
def
_parse_field_type
(
self
,
field_type
):
if
field_type
is
None
:
field_type
=
()
elif
isinstance
(
field_type
,
FieldType
):
field_type
=
(
field_type
,)
elif
not
isinstance
(
field_type
,
tuple
):
field_type
=
tuple
(
field_type
)
for
ft
in
field_type
:
if
not
isinstance
(
ft
,
FieldType
):
raise
TypeError
(
"Given object is not a nifty.FieldType."
)
return
field_type
# ---Properties---
@
abc
.
abstract
property
@
property
def
domain
(
self
):
raise
NotImplementedError
return
self
.
_domain
@
domain
.
setter
def
domain
(
self
,
domain
):
self
.
_domain
=
self
.
_parse_domain
(
domain
)
@
abc
.
abstract
property
@
property
def
field_type
(
self
):
raise
NotImplementedError
return
self
.
_field_type
@
field_type
.
setter
def
field_type
(
self
,
field_type
):
self
.
_field_type
=
self
.
_parse_field_type
(
field_type
)
@
abc
.
abstract
property
@
property
def
distribution_strategy
(
self
):
r
aise
NotImplementedError
r
eturn
self
.
_distribution_strategy
@
distribution_strategy
.
setter
def
distribution_strategy
(
self
,
distribution_strategy
):
...
...
@@ -65,14 +106,14 @@ class Prober(object):
# ---Probing methods---
def
probing_run
(
self
):
def
probing_run
(
self
,
callee
):
""" controls the generation, evaluation and finalization of probes """
sum_of_probes
=
0
sum_of_squares
=
0
for
index
in
xrange
(
self
.
probe_count
):
current_probe
=
self
.
get_probe
(
index
)
pre_result
=
self
.
process_probe
(
current_probe
,
index
)
pre_result
=
self
.
process_probe
(
callee
,
current_probe
,
index
)
result
=
self
.
finish_probe
(
current_probe
,
pre_result
)
sum_of_probes
+=
result
...
...
@@ -95,13 +136,13 @@ class Prober(object):
uid
=
np
.
random
.
randint
(
1e18
)
return
(
uid
,
f
)
def
process_probe
(
self
,
probe
,
index
):
return
self
.
evaluate_probe
(
probe
)
def
process_probe
(
self
,
callee
,
probe
,
index
):
""" layer of abstraction for potential result-caching/recycling """
return
self
.
evaluate_probe
(
callee
,
probe
[
1
])
@
abc
.
abstractmethod
def
evaluate_probe
(
self
,
probe
):
def
evaluate_probe
(
self
,
callee
,
probe
,
**
kwargs
):
""" processes a probe """
r
aise
NotImplementedError
r
eturn
callee
(
probe
,
**
kwargs
)
@
abc
.
abstractmethod
def
finish_probe
(
self
,
probe
,
pre_result
):
...
...
@@ -119,5 +160,5 @@ class Prober(object):
return
(
mean_of_probes
,
variance
)
def
__call__
(
self
):
return
self
.
probe
(
)
def
__call__
(
self
,
callee
):
return
self
.
prob
ing_run
(
calle
e
)
nifty/probing/trace_prober.py
View file @
de074086
# -*- coding: utf-8 -*-
from
nifty.operators
import
EndomorphicOperato
r
from
prober
import
Probe
r
from
operator_prober
import
OperatorProber
__all__
=
[
'TraceProber'
,
'InverseTraceProber'
,
'AdjointTraceProber'
,
'AdjointInverseTraceProber'
]
class
TraceTypeProber
(
OperatorProber
):
class
TraceProber
(
Prober
):
# ---Mandatory properties and methods---
# --- ->Mandatory from OperatorProber---
@
property
def
valid_operator_class
(
self
):
return
EndomorphicOperator
# --- ->Mandatory from Prober---
def
finish_probe
(
self
,
probe
,
pre_result
):
return
probe
[
1
].
conjugate
().
weight
(
power
=-
1
).
dot
(
pre_result
)
class
TraceProber
(
TraceTypeProber
):
# ---Mandatory properties and methods---
# --- ->Mandatory from OperatorProber---
@
property
def
is_inverse
(
self
):
return
False
# --- ->Mandatory from Prober---
def
evaluate_probe
(
self
,
probe
):
""" processes a probe """
return
self
.
operator
.
times
(
probe
[
1
])
class
InverseTraceProber
(
TraceTypeProber
):
# ---Mandatory properties and methods---
# --- ->Mandatory from OperatorProber---
@
property
def
is_inverse
(
self
):
return
True
# --- ->Mandatory from Prober---
def
evaluate_probe
(
self
,
probe
):
""" processes a probe """
return
self
.
operator
.
inverse_times
(
probe
[
1
])
class
AdjointTraceProber
(
TraceTypeProber
):
# ---Mandatory properties and methods---
# --- ->Mandatory from OperatorProber---
@
property
def
is_inverse
(
self
):
return
False
# --- ->Mandatory from Prober---
def
evaluate_probe
(
self
,
probe
):
""" processes a probe """