Commit d5faac2e authored by Ultima's avatar Ultima

Started implementing different minimizers.

parent c2fec153
......@@ -50,10 +50,11 @@ from __future__ import division
from nifty import * # version 0.8.0
from nifty.operators.nifty_minimization import steepest_descent_new
# some signal space; e.g., a two-dimensional regular grid
x_space = rg_space([128, 128]) # define signal space
x_space = rg_space([256, 256]) # define signal space
k_space = x_space.get_codomain() # get conjugate space
......@@ -76,15 +77,29 @@ j = R.adjoint_times(N.inverse_times(d)) # define inform
D = propagator_operator(S=S, N=N, R=R) # define information propagator
def energy(x):
DIx = D.inverse_times(x)
H = 0.5 * DIx.dot(x) - j.dot(x)
return H
def gradient(x):
DIx = D.inverse_times(x)
g = DIx - j
return g
def eggs(x):
"""
Calculation of the information Hamiltonian and its gradient.
"""
DIx = D.inverse_times(x)
H = 0.5 * DIx.dot(x) - j.dot(x) # compute information Hamiltonian
g = DIx - j # compute its gradient
return H, g
# DIx = D.inverse_times(x)
# H = 0.5 * DIx.dot(x) - j.dot(x) # compute information Hamiltonian
# g = DIx - j # compute its gradient
# return H, g
return energy(x), gradient(x)
m = field(x_space, codomain=k_space) # reconstruct map
......@@ -92,6 +107,8 @@ m = field(x_space, codomain=k_space) # reconstruct
#with PyCallGraph(output=graphviz, config=config):
m, convergence = steepest_descent(eggs=eggs, note=True)(m, tol=1E-3, clevel=3)
m = field(x_space, codomain=k_space)
m, convergence = steepest_descent_new(energy, gradient, note=True)(m, tol=1E-3, clevel=3)
#s.plot(title="signal") # plot signal
#d_ = field(x_space, val=d.val, target=k_space)
#d_.plot(title="data", vmin=s.min(), vmax=s.max()) # plot data
......
......@@ -2468,6 +2468,9 @@ class field(object):
return np.sum(result, axis=axis)
def vdot(self, *args, **kwargs):
return self.dot(*args, **kwargs)
def outer_dot(self, x=1, axis=None):
# Use the fact that self.val is a numpy array of dtype np.object
......
......@@ -163,6 +163,8 @@ class distributed_data_object(object):
new_copy.__dict__[key] = value
else:
new_copy.__dict__[key] = np.empty_like(value)
new_copy.index = d2o_librarian.register(new_copy)
return new_copy
def copy(self, dtype=None, distribution_strategy=None, **kwargs):
......@@ -503,7 +505,7 @@ class distributed_data_object(object):
# local_vdot_list = self.distributor._allgather(local_vdot)
# global_vdot = np.result_type(self.dtype,
# other.dtype).type(np.sum(local_vdot_list))
return global_vdot
return global_vdot[0]
def __getitem__(self, key):
return self.get_data(key)
......@@ -743,13 +745,19 @@ class distributed_data_object(object):
local_counts = np.bincount(self.get_local_data().flatten(),
weights=local_weights,
minlength=minlength)
if self.distribution_strategy == 'not':
return local_counts
else:
counts = np.empty_like(local_counts)
self.distributor._Allreduce_sum(local_counts, counts)
# list_of_counts = self.distributor._allgather(local_counts)
# counts = np.sum(list_of_counts, axis=0)
# self.distributor._Allreduce_sum(local_counts, counts)
# Potentially faster, but buggy. <- If np.binbount yields
# inconsistent datatypes because of empty arrays on certain nodes,
# the Allreduce produces non-sense results.
list_of_counts = self.distributor._allgather(local_counts)
counts = np.sum(list_of_counts, axis=0)
return counts
def where(self):
......@@ -1764,9 +1772,7 @@ class _slicing_distributor(distributor):
# Check which case we got:
(found, found_boolean) = _infer_key_type(key)
comm = self.comm
if local_keys is False:
return self._collect_data_primitive(data, key, found,
found_boolean, **kwargs)
......@@ -1788,7 +1794,6 @@ class _slicing_distributor(distributor):
else:
index_list = comm.allgather(key.index)
key_list = map(lambda z: d2o_librarian[z], index_list)
i = 0
for temp_key in key_list:
# build the locally fed d2o
......@@ -1844,7 +1849,6 @@ class _slicing_distributor(distributor):
if list_key == []:
raise ValueError(about._errors.cstring(
"ERROR: key == [] is an unsupported key!"))
local_list_key = self._advanced_index_decycler(list_key)
local_result = data[local_list_key]
global_result = distributed_data_object(
......@@ -1922,8 +1926,8 @@ class _slicing_distributor(distributor):
# for i in xrange(len(result) - 1)):
# raise ValueError(about._errors.cstring(
# "ERROR: The first dimemnsion of list_key must be sorted!"))
result = [result]
result = [result]
for ii in xrange(1, len(from_list_key)):
current = from_list_key[ii]
if np.isscalar(current):
......@@ -2174,10 +2178,11 @@ class _slicing_distributor(distributor):
# If the distributor is not exactly the same, check if the
# geometry matches if it is a slicing distributor
# -> comm and local shapes
elif isinstance(data_object.distributor, _slicing_distributor):
if (self.comm is data_object.distributor.comm) and \
np.all(self.all_local_slices ==
data_object.distributor.all_local_slices):
elif (isinstance(data_object.distributor,
_slicing_distributor) and
(self.comm is data_object.distributor.comm) and
(np.all(self.all_local_slices ==
data_object.distributor.all_local_slices))):
extracted_data = data_object.data
else:
......@@ -2925,6 +2930,9 @@ class d2o_iter(object):
else:
raise StopIteration()
def initialize_current_local_data(self):
raise NotImplementedError
def update_current_local_data(self):
raise NotImplementedError
......
......@@ -26,12 +26,27 @@ import numpy as np
from keepers import about
def vdot(x, y):
try:
return x.vdot(y)
except AttributeError:
pass
try:
return y.vdot(x)
except AttributeError:
pass
return np.vdot(x, y)
def _math_helper(x, function):
try:
return x.apply_scalar_function(function)
except(AttributeError):
return function(np.array(x))
def cos(x):
"""
Returns the cos of a given object.
......@@ -60,6 +75,7 @@ def cos(x):
"""
return _math_helper(x, np.cos)
def sin(x):
"""
Returns the sine of a given object.
......@@ -89,6 +105,7 @@ def sin(x):
"""
return _math_helper(x, np.sin)
def cosh(x):
"""
Returns the hyperbolic cosine of a given object.
......@@ -118,6 +135,7 @@ def cosh(x):
"""
return _math_helper(x, np.cosh)
def sinh(x):
"""
Returns the hyperbolic sine of a given object.
......@@ -147,6 +165,7 @@ def sinh(x):
"""
return _math_helper(x, np.sinh)
def tan(x):
"""
Returns the tangent of a given object.
......@@ -176,6 +195,7 @@ def tan(x):
"""
return _math_helper(x, np.tan)
def tanh(x):
"""
Returns the hyperbolic tangent of a given object.
......@@ -322,6 +342,7 @@ def arcsinh(x):
"""
return _math_helper(x, np.arcsinh)
def arctan(x):
"""
Returns the arctan of a given object.
......@@ -350,6 +371,7 @@ def arctan(x):
"""
return _math_helper(x, np.arctan)
def arctanh(x):
"""
Returns the hyperbolic arc tangent of a given object.
......@@ -378,6 +400,7 @@ def arctanh(x):
"""
return _math_helper(x, np.arctanh)
def sqrt(x):
"""
Returns the square root of a given object.
......@@ -402,6 +425,7 @@ def sqrt(x):
"""
return _math_helper(x, np.sqrt)
def exp(x):
"""
Returns the exponential of a given object.
......@@ -430,7 +454,8 @@ def exp(x):
"""
return _math_helper(x, np.exp)
def log(x,base=None):
def log(x, base=None):
"""
Returns the logarithm with respect to a specified base.
......@@ -462,11 +487,12 @@ def log(x,base=None):
return _math_helper(x, np.log)
base = np.array(base)
if(np.all(base>0)):
if np.all(base > 0):
return _math_helper(x, np.log)/np.log(base)
else:
raise ValueError(about._errors.cstring("ERROR: invalid input basis."))
def conjugate(x):
"""
Computes the complex conjugate of a given object.
......@@ -482,9 +508,3 @@ def conjugate(x):
The complex conjugated object.
"""
return _math_helper(x, np.conjugate)
##---------------------------------
\ No newline at end of file
......@@ -71,7 +71,7 @@ def _hermitianize_inverter(x):
return y
def direct_dot(x, y):
def direct_vdot(x, y):
# the input could be fields. Try to extract the data
try:
x = x.get_val()
......
......@@ -42,6 +42,9 @@ class los_response(operator):
starts, ends, sigmas_low,
sigmas_up, zero_point)
self._local_shape = self._init_local_shape()
self._set_extractor_d2o()
self.local_weights_and_indices = self._compute_weights_and_indices()
self.number_of_los = len(self.sigmas_low)
......@@ -212,7 +215,7 @@ class los_response(operator):
"ERROR: The space's datamodel is not supported:" +
str(self.domain.datamodel)))
def _get_local_shape(self):
def _init_local_shape(self):
if self.domain.datamodel == 'np':
return self.domain.get_shape()
elif self.domain.datamodel in STRATEGIES['not']:
......@@ -225,6 +228,9 @@ class los_response(operator):
skip_parsing=True)
return dummy_d2o.distributor.local_shape
def _get_local_shape(self):
return self._local_shape
def _compute_weights_and_indices(self):
# compute the local pixel coordinates for the starts and ends
localized_pixel_starts = self._convert_physical_to_indices(self.starts)
......@@ -258,11 +264,7 @@ class los_response(operator):
return local_indices_and_weights_list
def _multiply(self, input_field):
# extract the local data array from the input field
try:
local_input_data = input_field.val.data
except AttributeError:
local_input_data = input_field.val
local_input_data = self._multiply_preprocessing(input_field)
local_result = np.zeros(self.number_of_los, dtype=self.target.dtype)
......@@ -272,19 +274,33 @@ class los_response(operator):
local_result[los_index] += \
np.sum(local_input_data[indices]*weights)
if self.domain.datamodel == 'np':
global_result = local_result
elif self.domain.datamodel is STRATEGIES['not']:
global_result = local_result
if self.domain.datamodel in STRATEGIES['slicing']:
global_result = np.empty_like(local_result)
self.domain.comm.Allreduce(local_result, global_result, op=MPI.SUM)
global_result = self._multiply_postprocessing(local_result)
result_field = field(self.target,
val=global_result,
codomain=self.cotarget)
return result_field
def _multiply_preprocessing(self, input_field):
if self.domain.datamodel == 'np':
local_input_data = input_field.val
elif self.domain.datamodel in STRATEGIES['not']:
local_input_data = input_field.val.data
elif self.domain.datamodel in STRATEGIES['slicing']:
extractor = self._extractor_d2o.distributor.extract_local_data
local_input_data = extractor(input_field.val)
return local_input_data
def _multiply_postprocessing(self, local_result):
if self.domain.datamodel == 'np':
global_result = local_result
elif self.domain.datamodel in STRATEGIES['not']:
global_result = local_result
elif self.domain.datamodel in STRATEGIES['slicing']:
global_result = np.empty_like(local_result)
self.domain.comm.Allreduce(local_result, global_result, op=MPI.SUM)
return global_result
def _adjoint_multiply(self, input_field):
# get the full data as np.ndarray from the input field
try:
......@@ -321,14 +337,53 @@ class los_response(operator):
return result_field
def _improve_slicing(self):
if self.domain.datamodel not in STRATEGIES['slicing']:
raise ValueError(about._errors.cstring(
"ERROR: distribution strategy of domain is not a " +
"slicing one."))
comm = self.domain.comm
local_weight = np.sum(
[len(los[2]) for los in self.local_weights_and_indices])
local_length = self._get_local_shape()[0]
weights = comm.allgather(local_weight)
lengths = comm.allgather(local_length)
optimized_lengths = self._length_equilibrator(lengths, weights)
new_local_shape = list(self._local_shape)
new_local_shape[0] = optimized_lengths[comm.rank]
self._local_shape = tuple(new_local_shape)
self._set_extractor_d2o()
self.local_weights_and_indices = self._compute_weights_and_indices()
def _length_equilibrator(self, lengths, weights):
lengths = np.array(lengths, dtype=np.float)
weights = np.array(weights, dtype=np.float)
number_of_nodes = len(lengths)
cs_lengths = np.append(0, np.cumsum(lengths))
cs_weights = np.append(0, np.cumsum(weights))
total_weight = cs_weights[-1]
equiweights = np.linspace(0,
total_weight,
number_of_nodes+1)
equiweight_distances = np.interp(equiweights,
cs_weights,
cs_lengths)
equiweight_lengths = np.diff(np.floor(equiweight_distances))
return equiweight_lengths
def _set_extractor_d2o(self):
if self.domain.datamodel in STRATEGIES['slicing']:
temp_d2o = self.domain.cast()
extractor = temp_d2o.copy_empty(local_shape=self._local_shape,
distribution_strategy='freeform')
self._extractor_d2o = extractor
else:
self._extractor_d2o = None
This diff is collapsed.
......@@ -24,7 +24,7 @@ from __future__ import division
from nifty.keepers import about
from nifty.nifty_core import space, \
field
from nifty.nifty_utilities import direct_dot
from nifty.nifty_utilities import direct_vdot
......@@ -468,7 +468,7 @@ class trace_prober(_specialized_prober):
**kwargs)
def _probing_function(self, probe):
return direct_dot(probe.conjugate(), self.operator.times(probe))
return direct_vdot(probe.conjugate(), self.operator.times(probe))
class inverse_trace_prober(_specialized_prober):
......@@ -478,7 +478,7 @@ class inverse_trace_prober(_specialized_prober):
**kwargs)
def _probing_function(self, probe):
return direct_dot(probe.conjugate(),
return direct_vdot(probe.conjugate(),
self.operator.inverse_times(probe))
......
......@@ -874,10 +874,9 @@ class Test_list_get_set_data(unittest.TestCase):
(a, obj) = generate_data(global_shape, dtype,
distribution_strategy_1)
w = np.where(a > 30)
w = np.where(a > 28)
p = obj.copy(distribution_strategy=distribution_strategy_2)
wo = (p > 30).where()
wo = (p > 28).where()
assert_equal(obj[w].get_full_data(), a[w])
assert_equal(obj[wo].get_full_data(), a[w])
......@@ -903,7 +902,7 @@ class Test_list_get_set_data(unittest.TestCase):
assert_equal(obj[wo].get_full_data(), a[w])
##############################################################################
#############################################################################
@parameterized.expand(
itertools.product(
......@@ -1601,22 +1600,23 @@ class Test_comparisons(unittest.TestCase):
class Test_special_methods(unittest.TestCase):
@parameterized.expand(all_distribution_strategies,
@parameterized.expand(
itertools.product(all_distribution_strategies,
all_distribution_strategies),
testcase_func_name=custom_name_func)
def test_bincount(self, distribution_strategy):
global_shape = (80,)
def test_bincount(self, distribution_strategy_1, distribution_strategy_2):
global_shape = (10,)
dtype = np.dtype('int')
dtype_weights = np.dtype('float')
(a, obj) = generate_data(global_shape, dtype,
distribution_strategy)
distribution_strategy_1)
a = abs(a)
obj = abs(obj)
(b, p) = generate_data(global_shape, dtype_weights,
distribution_strategy)
distribution_strategy_2)
b **= 2
p **= 2
assert_equal(obj.bincount(weights=p),
np.bincount(a, weights=b))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment