ift / NIFTy / Commits

Commit d5faac2e, authored Dec 11, 2015 by Ultima

    Started implementing different minimizers.

parent c2fec153
Changes: 9 files
demos/demo_wf2.py

@@ -50,10 +50,11 @@ from __future__ import division
 from nifty import *                                  # version 0.8.0
+from nifty.operators.nifty_minimization import steepest_descent_new

 # some signal space; e.g., a two-dimensional regular grid
-x_space = rg_space([128, 128])                       # define signal space
+x_space = rg_space([256, 256])                       # define signal space

 k_space = x_space.get_codomain()                     # get conjugate space

@@ -76,15 +77,29 @@ j = R.adjoint_times(N.inverse_times(d)) # define inform
 D = propagator_operator(S=S, N=N, R=R)               # define information propagator

+def energy(x):
+    DIx = D.inverse_times(x)
+    H = 0.5 * DIx.dot(x) - j.dot(x)
+    return H
+
+def gradient(x):
+    DIx = D.inverse_times(x)
+    g = DIx - j
+    return g
+
 def eggs(x):
     """
     Calculation of the information Hamiltonian and its gradient.
     """
-    DIx = D.inverse_times(x)
-    H = 0.5 * DIx.dot(x) - j.dot(x)  # compute information Hamiltonian
-    g = DIx - j                      # compute its gradient
-    return H, g
+#    DIx = D.inverse_times(x)
+#    H = 0.5 * DIx.dot(x) - j.dot(x) # compute information Hamiltonian
+#    g = DIx - j # compute its gradient
+#    return H, g
+    return energy(x), gradient(x)

 m = field(x_space, codomain=k_space)                 # reconstruct map

@@ -92,6 +107,8 @@ m = field(x_space, codomain=k_space) # reconstruct
 #with PyCallGraph(output=graphviz, config=config):
 m, convergence = steepest_descent(eggs=eggs, note=True)(m, tol=1E-3, clevel=3)

+m = field(x_space, codomain=k_space)
+m, convergence = steepest_descent_new(energy, gradient, note=True)(m, tol=1E-3, clevel=3)

 #s.plot(title="signal")                              # plot signal
 #d_ = field(x_space, val=d.val, target=k_space)
 #d_.plot(title="data", vmin=s.min(), vmax=s.max())   # plot data
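The quantity minimized in this demo is the information Hamiltonian of the Wiener-filter posterior, H(x) = 0.5 * x.dot(D.inverse_times(x)) - j.dot(x), whose gradient is D.inverse_times(x) - j; the new energy/gradient pair simply splits what eggs used to return together. Below is a minimal, self-contained sketch of that split interface in plain NumPy (a small SPD matrix stands in for the demo's propagator D and a random vector for the source j; this is not NIFTy code), including a finite-difference check of the gradient:

    import numpy as np

    rng = np.random.default_rng(0)

    # Stand-ins for the demo's propagator D and information source j:
    # D_inv plays the role of D.inverse_times (symmetric positive definite).
    n = 16
    A = rng.normal(size=(n, n))
    D_inv = A @ A.T + n * np.eye(n)
    j = rng.normal(size=n)

    def energy(x):
        DIx = D_inv @ x
        return 0.5 * x @ DIx - j @ x        # information Hamiltonian H(x)

    def gradient(x):
        return D_inv @ x - j                # grad H(x)

    # Finite-difference check of one gradient component.
    x = rng.normal(size=n)
    eps = 1e-6
    e0 = np.zeros(n); e0[0] = 1.0
    fd = (energy(x + eps * e0) - energy(x - eps * e0)) / (2 * eps)
    assert np.isclose(fd, gradient(x)[0], rtol=1e-4)

    # The minimum of H is the Wiener-filter mean m = D j:
    m = np.linalg.solve(D_inv, j)
    assert np.allclose(gradient(m), 0.0, atol=1e-8)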
nifty_core.py

@@ -2468,6 +2468,9 @@ class field(object):
         return np.sum(result, axis=axis)

+    def vdot(self, *args, **kwargs):
+        return self.dot(*args, **kwargs)
+
     def outer_dot(self, x=1, axis=None):
         # Use the fact that self.val is a numpy array of dtype np.object
nifty_mpi_data.py

@@ -163,6 +163,8 @@ class distributed_data_object(object):
                 new_copy.__dict__[key] = value
             else:
                 new_copy.__dict__[key] = np.empty_like(value)
+
+        new_copy.index = d2o_librarian.register(new_copy)
         return new_copy

     def copy(self, dtype=None, distribution_strategy=None, **kwargs):

@@ -503,7 +505,7 @@ class distributed_data_object(object):
 #        local_vdot_list = self.distributor._allgather(local_vdot)
 #        global_vdot = np.result_type(self.dtype,
 #                                     other.dtype).type(np.sum(local_vdot_list))
-        return global_vdot
+        return global_vdot[0]

     def __getitem__(self, key):
         return self.get_data(key)

@@ -743,13 +745,19 @@ class distributed_data_object(object):
         local_counts = np.bincount(self.get_local_data().flatten(),
                                    weights=local_weights,
                                    minlength=minlength)
         if self.distribution_strategy == 'not':
             return local_counts
         else:
-            counts = np.empty_like(local_counts)
-            self.distributor._Allreduce_sum(local_counts, counts)
+#            list_of_counts = self.distributor._allgather(local_counts)
+#            counts = np.sum(list_of_counts, axis=0)
+#            self.distributor._Allreduce_sum(local_counts, counts)
+            # Potentially faster, but buggy. <- If np.binbount yields
+            # inconsistent datatypes because of empty arrays on certain nodes,
+            # the Allreduce produces non-sense results.
+            list_of_counts = self.distributor._allgather(local_counts)
+            counts = np.sum(list_of_counts, axis=0)
             return counts

     def where(self):

@@ -1764,9 +1772,7 @@ class _slicing_distributor(distributor):
         # Check which case we got:
         (found, found_boolean) = _infer_key_type(key)
         comm = self.comm
         if local_keys is False:
             return self._collect_data_primitive(data, key, found,
                                                 found_boolean, **kwargs)

@@ -1788,7 +1794,6 @@ class _slicing_distributor(distributor):
         else:
             index_list = comm.allgather(key.index)
             key_list = map(lambda z: d2o_librarian[z], index_list)
             i = 0
             for temp_key in key_list:
                 # build the locally fed d2o

@@ -1844,7 +1849,6 @@ class _slicing_distributor(distributor):
         if list_key == []:
             raise ValueError(about._errors.cstring(
                 "ERROR: key == [] is an unsupported key!"))
         local_list_key = self._advanced_index_decycler(list_key)
         local_result = data[local_list_key]
         global_result = distributed_data_object(

@@ -1922,8 +1926,8 @@ class _slicing_distributor(distributor):
 #        for i in xrange(len(result) - 1)):
 #            raise ValueError(about._errors.cstring(
 #                "ERROR: The first dimemnsion of list_key must be sorted!"))
-            result = [result]
+        result = [result]
         for ii in xrange(1, len(from_list_key)):
             current = from_list_key[ii]
             if np.isscalar(current):

@@ -2174,10 +2178,11 @@ class _slicing_distributor(distributor):
             # If the distributor is not exactly the same, check if the
             # geometry matches if it is a slicing distributor
             # -> comm and local shapes
-            elif isinstance(data_object.distributor, _slicing_distributor):
-                if (self.comm is data_object.distributor.comm) and \
-                        np.all(self.all_local_slices ==
-                               data_object.distributor.all_local_slices):
+            elif (isinstance(data_object.distributor,
+                             _slicing_distributor) and
+                  (self.comm is data_object.distributor.comm) and
+                  (np.all(self.all_local_slices ==
+                          data_object.distributor.all_local_slices))):
                 extracted_data = data_object.data
             else:

@@ -2925,6 +2930,9 @@ class d2o_iter(object):
         else:
             raise StopIteration()

     def initialize_current_local_data(self):
         raise NotImplementedError

     def update_current_local_data(self):
         raise NotImplementedError
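The comment added to bincount points at a genuine MPI pitfall: the buffer-based Allreduce assumes every rank contributes an array of identical dtype and shape, which np.bincount does not guarantee when some ranks hold empty local arrays. A minimal mpi4py sketch of the gather-then-sum pattern the new code switches to (illustrative names, not the d2o/distributor API):

    import numpy as np
    from mpi4py import MPI

    comm = MPI.COMM_WORLD

    # Each rank bins only its local chunk of data (dummy data here;
    # rank 0 deliberately holds an empty array).
    local_data = np.arange(comm.rank)
    local_counts = np.bincount(local_data)

    # Gather every rank's counts as Python objects and combine them locally.
    # Unlike comm.Allreduce, this tolerates per-rank differences in the
    # length and dtype of local_counts.
    list_of_counts = comm.allgather(local_counts)
    length = max(len(c) for c in list_of_counts)
    counts = np.zeros(length, dtype=np.int64)
    for c in list_of_counts:
        counts[:len(c)] += c

    print(comm.rank, counts)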
nifty_simple_math.py

@@ -26,12 +26,27 @@ import numpy as np
 from keepers import about

+def vdot(x, y):
+    try:
+        return x.vdot(y)
+    except AttributeError:
+        pass
+    try:
+        return y.vdot(x)
+    except AttributeError:
+        pass
+    return np.vdot(x, y)
+
+
 def _math_helper(x, function):
     try:
         return x.apply_scalar_function(function)
     except(AttributeError):
         return function(np.array(x))


 def cos(x):
     """
     Returns the cos of a given object.

@@ -60,6 +75,7 @@ def cos(x):
     """
     return _math_helper(x, np.cos)
+
 def sin(x):
     """
     Returns the sine of a given object.

@@ -89,6 +105,7 @@ def sin(x):
     """
     return _math_helper(x, np.sin)
+
 def cosh(x):
     """
     Returns the hyperbolic cosine of a given object.

@@ -118,6 +135,7 @@ def cosh(x):
     """
     return _math_helper(x, np.cosh)
+
 def sinh(x):
     """
     Returns the hyperbolic sine of a given object.

@@ -147,6 +165,7 @@ def sinh(x):
     """
     return _math_helper(x, np.sinh)
+
 def tan(x):
     """
     Returns the tangent of a given object.

@@ -176,6 +195,7 @@ def tan(x):
     """
     return _math_helper(x, np.tan)
+
 def tanh(x):
     """
     Returns the hyperbolic tangent of a given object.

@@ -322,6 +342,7 @@ def arcsinh(x):
     """
     return _math_helper(x, np.arcsinh)
+
 def arctan(x):
     """
     Returns the arctan of a given object.

@@ -350,6 +371,7 @@ def arctan(x):
     """
     return _math_helper(x, np.arctan)
+
 def arctanh(x):
     """
     Returns the hyperbolic arc tangent of a given object.

@@ -378,6 +400,7 @@ def arctanh(x):
     """
     return _math_helper(x, np.arctanh)
+
 def sqrt(x):
     """
     Returns the square root of a given object.

@@ -402,6 +425,7 @@ def sqrt(x):
     """
     return _math_helper(x, np.sqrt)
+
 def exp(x):
     """
     Returns the exponential of a given object.

@@ -430,7 +454,8 @@ def exp(x):
     """
     return _math_helper(x, np.exp)
-def log(x,base=None):
+
+def log(x, base=None):
     """
     Returns the logarithm with respect to a specified base.

@@ -462,11 +487,12 @@ def log(x,base=None):
         return _math_helper(x, np.log)
     base = np.array(base)
-    if(np.all(base > 0)):
+    if np.all(base > 0):
         return _math_helper(x, np.log) / np.log(base)
     else:
         raise ValueError(about._errors.cstring("ERROR: invalid input basis."))

+
 def conjugate(x):
     """
     Computes the complex conjugate of a given object.

@@ -482,9 +508,3 @@ def conjugate(x):
         The complex conjugated object.
     """
     return _math_helper(x, np.conjugate)
-
-##---------------------------------
\ No newline at end of file
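The new module-level vdot is a duck-typing wrapper: it prefers the left operand's own vdot method, falls back to the right operand's, and only then to np.vdot for plain arrays. A small usage sketch; the Probe class below is a made-up stand-in for a NIFTy field, used only to show the dispatch order:

    import numpy as np

    def vdot(x, y):
        # same dispatch order as the new nifty_simple_math.vdot
        try:
            return x.vdot(y)
        except AttributeError:
            pass
        try:
            return y.vdot(x)
        except AttributeError:
            pass
        return np.vdot(x, y)

    class Probe(object):
        """Made-up stand-in for an object carrying its own vdot method."""
        def __init__(self, val):
            self.val = np.asarray(val)
        def vdot(self, other):
            other = other.val if isinstance(other, Probe) else np.asarray(other)
            return np.vdot(self.val, other)

    a = Probe([1.0, 2.0, 3.0])
    b = np.array([4.0, 5.0, 6.0])

    print(vdot(a, b))   # uses a.vdot(b)                            -> 32.0
    print(vdot(b, a))   # ndarrays have no .vdot, falls back to a   -> 32.0
    print(vdot(b, b))   # plain arrays, uses np.vdot                -> 77.0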
nifty_utilities.py

@@ -71,7 +71,7 @@ def _hermitianize_inverter(x):
     return y


-def direct_dot(x, y):
+def direct_vdot(x, y):
     # the input could be fields. Try to extract the data
     try:
         x = x.get_val()
operators/nifty_los.py

@@ -42,6 +42,9 @@ class los_response(operator):
                                          starts, ends, sigmas_low, sigmas_up,
                                          zero_point)
+        self._local_shape = self._init_local_shape()
+        self._set_extractor_d2o()
+
         self.local_weights_and_indices = self._compute_weights_and_indices()
         self.number_of_los = len(self.sigmas_low)

@@ -212,7 +215,7 @@ class los_response(operator):
                 "ERROR: The space's datamodel is not supported:" +
                 str(self.domain.datamodel)))

-    def _get_local_shape(self):
+    def _init_local_shape(self):
         if self.domain.datamodel == 'np':
             return self.domain.get_shape()
         elif self.domain.datamodel in STRATEGIES['not']:

@@ -225,6 +228,9 @@ class los_response(operator):
                                            skip_parsing=True)
             return dummy_d2o.distributor.local_shape

+    def _get_local_shape(self):
+        return self._local_shape
+
     def _compute_weights_and_indices(self):
         # compute the local pixel coordinates for the starts and ends
         localized_pixel_starts = self._convert_physical_to_indices(self.starts)

@@ -258,11 +264,7 @@ class los_response(operator):
         return local_indices_and_weights_list

     def _multiply(self, input_field):
-        # extract the local data array from the input field
-        try:
-            local_input_data = input_field.val.data
-        except AttributeError:
-            local_input_data = input_field.val
+        local_input_data = self._multiply_preprocessing(input_field)
         local_result = np.zeros(self.number_of_los,
                                 dtype=self.target.dtype)

@@ -272,19 +274,33 @@ class los_response(operator):
                 local_result[los_index] += \
                     np.sum(local_input_data[indices] * weights)

-        if self.domain.datamodel == 'np':
-            global_result = local_result
-        elif self.domain.datamodel is STRATEGIES['not']:
-            global_result = local_result
-        if self.domain.datamodel in STRATEGIES['slicing']:
-            global_result = np.empty_like(local_result)
-            self.domain.comm.Allreduce(local_result, global_result,
-                                       op=MPI.SUM)
+        global_result = self._multiply_postprocessing(local_result)
         result_field = field(self.target, val=global_result,
                              codomain=self.cotarget)
         return result_field

+    def _multiply_preprocessing(self, input_field):
+        if self.domain.datamodel == 'np':
+            local_input_data = input_field.val
+        elif self.domain.datamodel in STRATEGIES['not']:
+            local_input_data = input_field.val.data
+        elif self.domain.datamodel in STRATEGIES['slicing']:
+            extractor = self._extractor_d2o.distributor.extract_local_data
+            local_input_data = extractor(input_field.val)
+        return local_input_data
+
+    def _multiply_postprocessing(self, local_result):
+        if self.domain.datamodel == 'np':
+            global_result = local_result
+        elif self.domain.datamodel in STRATEGIES['not']:
+            global_result = local_result
+        elif self.domain.datamodel in STRATEGIES['slicing']:
+            global_result = np.empty_like(local_result)
+            self.domain.comm.Allreduce(local_result, global_result,
+                                       op=MPI.SUM)
+        return global_result
+
     def _adjoint_multiply(self, input_field):
         # get the full data as np.ndarray from the input field
         try:

@@ -321,14 +337,53 @@ class los_response(operator):
         return result_field

+    def _improve_slicing(self):
+        if self.domain.datamodel not in STRATEGIES['slicing']:
+            raise ValueError(about._errors.cstring(
+                "ERROR: distribution strategy of domain is not a " +
+                "slicing one."))
+        comm = self.domain.comm
+        local_weight = np.sum([len(los[2]) for los in
+                               self.local_weights_and_indices])
+        local_length = self._get_local_shape()[0]
+        weights = comm.allgather(local_weight)
+        lengths = comm.allgather(local_length)
+        optimized_lengths = self._length_equilibrator(lengths, weights)
+        new_local_shape = list(self._local_shape)
+        new_local_shape[0] = optimized_lengths[comm.rank]
+        self._local_shape = tuple(new_local_shape)
+        self._set_extractor_d2o()
+        self.local_weights_and_indices = self._compute_weights_and_indices()
+
+    def _length_equilibrator(self, lengths, weights):
+        lengths = np.array(lengths, dtype=np.float)
+        weights = np.array(weights, dtype=np.float)
+        number_of_nodes = len(lengths)
+        cs_lengths = np.append(0, np.cumsum(lengths))
+        cs_weights = np.append(0, np.cumsum(weights))
+        total_weight = cs_weights[-1]
+        equiweights = np.linspace(0, total_weight, number_of_nodes + 1)
+        equiweight_distances = np.interp(equiweights, cs_weights, cs_lengths)
+        equiweight_lengths = np.diff(np.floor(equiweight_distances))
+        return equiweight_lengths
+
+    def _set_extractor_d2o(self):
+        if self.domain.datamodel in STRATEGIES['slicing']:
+            temp_d2o = self.domain.cast()
+            extractor = temp_d2o.copy_empty(
+                local_shape=self._local_shape,
+                distribution_strategy='freeform')
+            self._extractor_d2o = extractor
+        else:
+            self._extractor_d2o = None
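The new _length_equilibrator is a small load-balancing helper: given the current per-node slab lengths and the per-node workload (weights), it interpolates cumulative length against cumulative weight so that each node ends up with roughly the same share of the total work. The same arithmetic, pulled out of the class and run on made-up numbers:

    import numpy as np

    def length_equilibrator(lengths, weights):
        # same arithmetic as los_response._length_equilibrator
        lengths = np.array(lengths, dtype=float)
        weights = np.array(weights, dtype=float)
        number_of_nodes = len(lengths)
        cs_lengths = np.append(0, np.cumsum(lengths))   # cumulative rows
        cs_weights = np.append(0, np.cumsum(weights))   # cumulative work
        total_weight = cs_weights[-1]
        # target: give every node an equal share of the total work ...
        equiweights = np.linspace(0, total_weight, number_of_nodes + 1)
        # ... and map those equal shares back onto row positions
        equiweight_distances = np.interp(equiweights, cs_weights, cs_lengths)
        return np.diff(np.floor(equiweight_distances))

    # Four nodes with 64 rows each, but almost all of the work on node 0:
    print(length_equilibrator([64, 64, 64, 64], [900, 50, 30, 20]))
    # -> [ 17.  18.  18. 203.]  (the heavily loaded node keeps few rows,
    #     the lightly loaded last node absorbs most of them)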
operators/nifty_minimization.py

(This diff is collapsed in this view.)
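The minimizer implementation itself is not visible here, but the demo above fixes its call convention: steepest_descent_new(energy, gradient, note=True)(m, tol=1E-3, clevel=3), i.e. energy and gradient are passed as separate callables instead of the combined eggs function. Purely as an illustration of that interface, and explicitly not the NIFTy code hidden in this collapsed diff, a generic steepest-descent driver with a backtracking line search could look like:

    import numpy as np

    class steepest_descent_sketch(object):
        """Illustrative interface stand-in; NOT the minimizer in this commit."""
        def __init__(self, energy, gradient, note=False):
            self.energy = energy
            self.gradient = gradient
            self.note = note

        def __call__(self, x, tol=1E-4, clevel=1, limii=100000):
            convergence = 0
            for ii in range(limii):
                g = self.gradient(x)
                if np.max(np.abs(g)) < tol:
                    convergence += 1        # count consecutive quiet steps
                    if convergence >= clevel:
                        break
                else:
                    convergence = 0
                # backtracking line search along the negative gradient
                alpha, e0 = 1.0, self.energy(x)
                while self.energy(x - alpha * g) > e0 and alpha > 1E-12:
                    alpha *= 0.5
                x = x - alpha * g
                if self.note:
                    print("iteration %d: energy = %e" % (ii, e0))
            return x, convergence

    # usage mirrors the demo:
    #   m, convergence = steepest_descent_sketch(energy, gradient,
    #                                            note=True)(m, tol=1E-3, clevel=3)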
operators/nifty_probing.py

@@ -24,7 +24,7 @@ from __future__ import division
 from nifty.keepers import about
 from nifty.nifty_core import space, \
                              field
-from nifty.nifty_utilities import direct_dot
+from nifty.nifty_utilities import direct_vdot

@@ -468,7 +468,7 @@ class trace_prober(_specialized_prober):
                                            **kwargs)

     def _probing_function(self, probe):
-        return direct_dot(probe.conjugate(), self.operator.times(probe))
+        return direct_vdot(probe.conjugate(), self.operator.times(probe))


 class inverse_trace_prober(_specialized_prober):

@@ -478,7 +478,7 @@ class inverse_trace_prober(_specialized_prober):
                                            **kwargs)

     def _probing_function(self, probe):
-        return direct_dot(probe.conjugate(),
+        return direct_vdot(probe.conjugate(),
                           self.operator.inverse_times(probe))
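These probers implement stochastic trace estimation: for random probes xi with unit covariance, the expectation of the probing function xi^dagger A xi equals tr(A), so averaging it over many probes estimates the trace of the operator (or of its inverse, for inverse_trace_prober). A self-contained NumPy illustration of the idea, independent of the NIFTy classes:

    import numpy as np

    rng = np.random.default_rng(42)

    n = 200
    A = rng.normal(size=(n, n))
    A = A @ A.T + n * np.eye(n)        # some symmetric positive definite operator

    def probing_function(probe):
        # analogue of trace_prober._probing_function for real probes;
        # the NIFTy version conjugates the probe to handle complex fields
        return probe @ (A @ probe)

    # Hutchinson estimator: average probe^T A probe over random +/-1 probes
    nprobes = 2000
    estimates = [probing_function(rng.choice([-1.0, 1.0], size=n))
                 for _ in range(nprobes)]

    print("estimated trace:", np.mean(estimates))
    print("exact trace:    ", np.trace(A))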
test/test_nifty_mpi_data.py

@@ -874,10 +874,9 @@ class Test_list_get_set_data(unittest.TestCase):
         (a, obj) = generate_data(global_shape, dtype,
                                  distribution_strategy_1)
-        w = np.where(a > 30)
+        w = np.where(a > 28)
         p = obj.copy(distribution_strategy=distribution_strategy_2)
-        wo = (p > 30).where()
+        wo = (p > 28).where()
         assert_equal(obj[w].get_full_data(), a[w])
         assert_equal(obj[wo].get_full_data(), a[w])

@@ -903,7 +902,7 @@ class Test_list_get_set_data(unittest.TestCase):
         assert_equal(obj[wo].get_full_data(), a[w])

 #############################################################################
 #
 #############################################################################

 @parameterized.expand(itertools.product(

@@ -1601,22 +1600,23 @@ class Test_comparisons(unittest.TestCase):
 class Test_special_methods(unittest.TestCase):

-    @parameterized.expand(all_distribution_strategies,
+    @parameterized.expand(itertools.product(all_distribution_strategies,
+                                            all_distribution_strategies),
                           testcase_func_name=custom_name_func)
-    def test_bincount(self, distribution_strategy):
-        global_shape = (80,)
+    def test_bincount(self, distribution_strategy_1, distribution_strategy_2):
+        global_shape = (10,)
         dtype = np.dtype('int')
         dtype_weights = np.dtype('float')
         (a, obj) = generate_data(global_shape, dtype,
-                                 distribution_strategy)
+                                 distribution_strategy_1)
         a = abs(a)
         obj = abs(obj)
         (b, p) = generate_data(global_shape, dtype_weights,
-                               distribution_strategy)
+                               distribution_strategy_2)
         b **= 2
         p **= 2
         assert_equal(obj.bincount(weights=p),
                      np.bincount(a, weights=b))
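The tightened test compares d2o's distributed bincount against NumPy's reference behaviour, now over two independent distribution strategies and including the weights argument. For reference, the plain-NumPy semantics the assertion relies on:

    import numpy as np

    a = np.array([0, 1, 1, 3, 2, 1, 3])                  # integer bin indices
    b = np.array([0.5, 1.0, 1.0, 2.0, 1.0, 0.5, 0.5])    # per-entry weights

    print(np.bincount(a))              # counts per bin       -> [1 3 1 2]
    print(np.bincount(a, weights=b))   # weight sums per bin  -> [0.5 2.5 1.  2.5]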