Commit 7f240569 authored by Ultima's avatar Ultima

Created a bunch of tests for distributed_data_object.

-> Many minor and major bugfixes
-> Reworked indexing (now works with negative step sizes, too)

Updated the naming of domain, codomain, target and cotarget in the spaces, fields and operators.
-> The propagator_operator is now 3 times faster
parent e94b3f41
......@@ -22,7 +22,7 @@
from __future__ import division
import matplotlib as mpl
mpl.use('Agg')
#mpl.use('Agg')
from nifty_about import about
from nifty_cmaps import ncmap
......
......@@ -83,7 +83,8 @@ class problem(object):
self.d = self.R(self.s) + n
## define information source
self.j = self.R.adjoint_times(self.N.inverse_times(self.d), target=self.k)
#self.j = self.R.adjoint_times(self.N.inverse_times(self.d), target=self.k)
self.j = self.R.adjoint_times(self.N.inverse_times(self.d))
## define information propagator
self.D = propagator_operator(S=self.S, N=self.N, R=self.R)
......
......@@ -42,6 +42,7 @@ x_space = rg_space([1280, 1280], datamodel = 'd2o')
#x_space = gl_space(96)
k_space = x_space.get_codomain() # get conjugate space
y_space = point_space(1280*1280, datamodel='d2o')
# some power spectrum
power = (lambda k: 42 / (k + 1) ** 3)
......@@ -49,7 +50,7 @@ power = (lambda k: 42 / (k + 1) ** 3)
S = power_operator(k_space, spec=power) # define signal covariance
s = S.get_random_field(domain=x_space) # generate signal
R = response_operator(x_space, sigma=0.0, mask=1.0, assign=None) # define response
R = response_operator(x_space, sigma=0.0, mask=1.0, assign=None, target = y_space) # define response
d_space = R.target # get data space
# some noise variance; e.g., signal-to-noise ratio of 1
......@@ -62,10 +63,10 @@ d = R(s) + n # compute data
j = R.adjoint_times(N.inverse_times(d)) # define information source
D = propagator_operator(S=S, N=N, R=R) # define information propagator
m = D(j, W=S, tol=1E-1, note=True) # reconstruct map
m = D(j, W=S, tol=1E-2, note=True) # reconstruct map
#s.plot(title="signal", save = 'plot_s.png') # plot signal
#d_ = field(x_space, val=d.val, target=k_space)
#d_.plot(title="data", vmin=s.min(), vmax=s.max(), save = 'plot_d.png') # plot data
#m.plot(title="reconstructed map", vmin=s.min(), vmax=s.max(), save = 'plot_m.png') # plot map
s.plot(title="signal", save = 'plot_s.png') # plot signal
d_ = field(x_space, val=d.val, target=k_space)
d_.plot(title="data", vmin=s.min(), vmax=s.max(), save = 'plot_d.png') # plot data
m.plot(title="reconstructed map", vmin=s.min(), vmax=s.max(), save = 'plot_m.png') # plot map
......@@ -128,7 +128,7 @@ class lm_space(point_space):
vol : numpy.ndarray
Pixel volume of the :py:class:`lm_space`, which is always 1.
"""
def __init__(self, lmax, mmax=None, datatype=None):
def __init__(self, lmax, mmax=None, datatype=None, datamodel = 'np'):
"""
Sets the attributes for an lm_space class instance.
......@@ -519,6 +519,9 @@ class lm_space(point_space):
Compatible codomains are instances of :py:class:`lm_space`,
:py:class:`gl_space`, and :py:class:`hp_space`.
"""
if codomain is None:
return False
if(not isinstance(codomain,space)):
raise TypeError(about._errors.cstring("ERROR: invalid input."))
......@@ -669,7 +672,7 @@ class lm_space(point_space):
x : numpy.ndarray
Array to be transformed.
codomain : nifty.space, *optional*
Target space to which the transformation shall map
codomain space to which the transformation shall map
(default: self).
Returns
......@@ -1290,6 +1293,9 @@ class gl_space(point_space):
Compatible codomains are instances of :py:class:`gl_space` and
:py:class:`lm_space`.
"""
if codomain is None:
return False
if(not isinstance(codomain,space)):
raise TypeError(about._errors.cstring("ERROR: invalid input."))
......@@ -1398,7 +1404,7 @@ class gl_space(point_space):
x : numpy.ndarray
Array to be transformed.
codomain : nifty.space, *optional*
Target space to which the transformation shall map
codomain space to which the transformation shall map
(default: self).
Returns
......@@ -1953,6 +1959,9 @@ class hp_space(point_space):
Compatible codomains are instances of :py:class:`hp_space` and
:py:class:`lm_space`.
"""
if codomain is None:
return False
if(not isinstance(codomain,space)):
raise TypeError(about._errors.cstring("ERROR: invalid input."))
......@@ -2025,7 +2034,7 @@ class hp_space(point_space):
x : numpy.ndarray
Array to be transformed.
codomain : nifty.space, *optional*
Target space to which the transformation shall map
codomain space to which the transformation shall map
(default: self).
Returns
......
......@@ -271,7 +271,8 @@ class notification(switch):
String augmented with a color code.
"""
return self.ccode+str(self._get_caller())+':\n'+str(subject)+self._code
return self.ccode + str(self._get_caller()) + ':\n' + \
str(subject) + self._code + '\n'
def cflush(self,subject):
"""
......
......@@ -729,7 +729,7 @@ class space(object):
x : numpy.ndarray
Array to be transformed.
codomain : nifty.space, *optional*
Target space to which the transformation shall map
codomain space to which the transformation shall map
(default: self).
Returns
......@@ -1795,7 +1795,7 @@ class point_space(space):
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
def check_codomain(self,codomain):
def check_codomain(self, codomain):
"""
Checks whether a given codomain is compatible to the space or not.
......@@ -1809,6 +1809,9 @@ class point_space(space):
check : bool
Whether or not the given codomain is compatible to the space.
"""
if codomain is None:
return False
if not isinstance(codomain, space):
raise TypeError(about._errors.cstring(
"ERROR: invalid input. The given input is no nifty space."))
......@@ -1965,7 +1968,7 @@ class point_space(space):
x : numpy.ndarray
Array to be transformed.
codomain : nifty.space, *optional*
Target space to which the transformation shall map
codomain space to which the transformation shall map
(default: self).
Returns
......@@ -2008,8 +2011,10 @@ class point_space(space):
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
def calc_real_Q(self, x):
try:
return x.is_completely_real()
return x.isreal().all()
except(AttributeError):
return np.all(np.isreal(x))
......@@ -2641,7 +2646,7 @@ class nested_space(space):
Other parameters
----------------
target : nifty.space, *optional*
codomain : nifty.space, *optional*
Space in which the transform of the output field lives
(default: None).
......@@ -2702,7 +2707,7 @@ class nested_space(space):
x : numpy.ndarray
Array to be transformed.
codomain : nifty.space, *optional*
Target space to which the transformation shall map
codomain space to which the transformation shall map
(default: self).
Returns
......@@ -2870,7 +2875,7 @@ class field(object):
space defined in domain or to be drawn from a random distribution
controlled by kwargs.
target : space, *optional*
codomain : space, *optional*
The space wherein the operator output lives (default: domain).
......@@ -2929,11 +2934,11 @@ class field(object):
space defined in domain or to be drawn from a random distribution
controlled by the keyword arguments.
target : space, *optional*
codomain : space, *optional*
The space wherein the operator output lives (default: domain).
"""
def __init__(self, domain, val=None, target=None, **kwargs):
def __init__(self, domain, val=None, codomain=None, **kwargs):
"""
Sets the attributes for a field class instance.
......@@ -2948,7 +2953,7 @@ class field(object):
space defined in domain or to be drawn from a random distribution
controlled by the keyword arguments.
target : space, *optional*
codomain : space, *optional*
The space wherein the operator output lives (default: domain).
Returns
......@@ -2961,17 +2966,17 @@ class field(object):
raise TypeError(about._errors.cstring("ERROR: invalid input."))
self.domain = domain
## check codomain
if target is None:
target = domain.get_codomain()
if codomain is None:
codomain = domain.get_codomain()
else:
assert(self.domain.check_codomain(target))
self.target = target
assert(self.domain.check_codomain(codomain))
self.codomain = codomain
if val == None:
if kwargs == {}:
self.val = self.domain.cast(0.)
else:
self.val = self.domain.get_random_values(codomain=self.target,
self.val = self.domain.get_random_values(codomain=self.codomain,
**kwargs)
else:
self.val = val
......@@ -2986,18 +2991,18 @@ class field(object):
self.__val = self.domain.cast(x)
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
def copy(self, domain=None, target=None):
new_field = self.copy_empty(domain=domain, target=target)
def copy(self, domain=None, codomain=None):
new_field = self.copy_empty(domain=domain, codomain=codomain)
new_field.val = new_field.domain.cast(self.val.copy())
return new_field
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
def copy_empty(self, domain=None, target=None, **kwargs):
def copy_empty(self, domain=None, codomain=None, **kwargs):
if domain == None:
domain = self.domain
if target == None:
target = self.target
new_field = field(domain=domain, target=target, **kwargs)
if codomain == None:
codomain = self.codomain
new_field = field(domain=domain, codomain=codomain, **kwargs)
return new_field
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
......@@ -3023,7 +3028,7 @@ class field(object):
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
def cast_domain(self, newdomain, new_target=None, force=True):
def cast_domain(self, newdomain, new_codomain=None, force=True):
"""
Casts the domain of the field.
......@@ -3032,9 +3037,9 @@ class field(object):
newdomain : space
New space wherein the field should live.
new_target : space, *optional*
new_codomain : space, *optional*
Space wherein the transform of the field should live.
When not given, target will automatically be the codomain
When not given, codomain will automatically be the codomain
of the newly casted domain (default=None).
force : bool, *optional*
......@@ -3075,20 +3080,20 @@ class field(object):
## Use the casting of the new domain in order to make the old data fit.
self.set_val(new_val = self.val)
## set the target
if new_target == None:
if not self.domain.check_codomain(self.target):
## set the codomain
if new_codomain == None:
if not self.domain.check_codomain(self.codomain):
if(force):
about.infos.cprint("INFO: codomain set to default.")
else:
about.warnings.cprint("WARNING: codomain set to default.")
self.set_target(new_target = self.domain.get_codomain())
self.set_codomain(new_codomain = self.domain.get_codomain())
else:
self.set_target(new_target = new_target, force = force)
self.set_codomain(new_codomain = new_codomain, force = force)
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
def set_val(self, new_val):
def set_val(self, new_val = None):
"""
Resets the field values.
......@@ -3098,7 +3103,8 @@ class field(object):
New field values either as a constant or an arbitrary array.
"""
self.val = new_val
if new_val is not None:
self.val = new_val
return self.val
def get_val(self):
......@@ -3107,31 +3113,31 @@ class field(object):
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
def set_domain(self, new_domain=None, force=False):
if new_domain is None:
new_domain = self.target.get_codomain()
new_domain = self.codomain.get_codomain()
elif force == False:
assert(self.target.check_codomain(new_domain))
assert(self.codomain.check_codomain(new_domain))
self.domain = new_domain
return self.domain
def set_target(self, new_target=None, force=False):
def set_codomain(self, new_codomain=None, force=False):
"""
Resets the codomain of the field.
Parameters
----------
new_target : space
new_codomain : space
The new space wherein the transform of the field should live.
(default=None).
"""
## check codomain
if new_target is None:
new_target = self.domain.get_codomain()
if new_codomain is None:
new_codomain = self.domain.get_codomain()
elif force == False:
assert(self.domain.check_codomain(new_target))
self.target = new_target
return self.target
assert(self.domain.check_codomain(new_codomain))
self.codomain = new_codomain
return self.codomain
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
......@@ -3256,7 +3262,7 @@ class field(object):
Other Parameters
----------------
target : space, *optional*
codomain : space, *optional*
space wherein the transform of the output field should live
(default: None).
......@@ -3296,7 +3302,7 @@ class field(object):
return self.pseudo_dot(x=x.val,**kwargs)
except(TypeError,ValueError):
try:
return self.pseudo_dot(x=x.transform(target=x.target,overwrite=False).val,**kwargs)
return self.pseudo_dot(x=x.transform(codomain=x.codomain,overwrite=False).val,**kwargs)
except(TypeError,ValueError):
raise ValueError(about._errors.cstring("ERROR: incompatible domains."))
## pseudo inner product (calc_pseudo_dot handles weights)
......@@ -3331,7 +3337,7 @@ class field(object):
Other Parameters
----------------
target : space, *optional*
codomain : space, *optional*
space wherein the transform of the output field should live
(default: None).
......@@ -3370,15 +3376,15 @@ class field(object):
return work_field
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
def transform(self, target=None, overwrite=False, **kwargs):
def transform(self, codomain=None, overwrite=False, **kwargs):
"""
Computes the transform of the field using the appropriate conjugate
transformation.
Parameters
----------
target : space, *optional*
Domain of the transform of the field (default:self.target)
codomain : space, *optional*
Domain of the transform of the field (default:self.codomain)
overwrite : bool, *optional*
Whether to overwrite the field or not (default: False).
......@@ -3395,22 +3401,22 @@ class field(object):
Otherwise, nothing is returned.
"""
if(target is None):
target = self.target
if(codomain is None):
codomain = self.codomain
else:
assert(self.domain.check_codomain(target))
assert(self.domain.check_codomain(codomain))
new_val = self.domain.calc_transform(self.val,
codomain=target,
codomain=codomain,
**kwargs)
if overwrite == True:
return_field = self
return_field.set_target(new_target = self.domain, force = True)
return_field.set_domain(new_domain = target, force = True)
return_field.set_codomain(new_codomain = self.domain, force = True)
return_field.set_domain(new_domain = codomain, force = True)
else:
return_field = self.copy_empty(domain = self.target,
target = self.domain)
return_field = self.copy_empty(domain = self.codomain,
codomain = self.domain)
return_field.set_val(new_val = new_val)
return return_field
......@@ -3497,7 +3503,7 @@ class field(object):
about.warnings.cprint("WARNING: codomain was removed from kwargs.")
return self.domain.calc_power(self.get_val(),
codomain = self.target,
codomain = self.codomain,
**kwargs)
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
......@@ -3608,7 +3614,7 @@ class field(object):
about.warnings.cprint("WARNING: codomain was removed from kwargs.")
## draw/save the plot(s)
self.domain.get_plot(self.val, codomain=self.target, **kwargs)
self.domain.get_plot(self.val, codomain=self.codomain, **kwargs)
## restore the pylab interactiveness
pl.matplotlib.interactive(remember_interactive)
......@@ -3626,7 +3632,7 @@ class field(object):
"\n- val = [...]" + \
"\n - min.,max. = " + str(minmax) + \
"\n - mean = " + str(mean) + \
"\n- target = " + repr(self.target)
"\n- codomain = " + repr(self.codomain)
def __len__(self):
......
This diff is collapsed.
This diff is collapsed.
......@@ -348,8 +348,8 @@ class prober(object):
"""
return field(self.domain,
target=self.codomain,
random=self.random)
codomain = self.codomain,
random = self.random)
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
......
......@@ -631,7 +631,7 @@ class rg_space(point_space):
## Check hermitianity/reality
if self.paradict['complexity'] == 0:
if x.is_completely_real == False:
if x.iscomplex().any() == False:
about.warnings.cflush(\
"WARNING: Data is not completely real. Imaginary part "+\
"will be discarded!\n")
......@@ -737,10 +737,10 @@ class rg_space(point_space):
x = np.real(x)
elif self.paradict['complexity'] == 1:
if x.hermitian == False and about.hermitianize.status == True:
if about.hermitianize.status == True:
about.warnings.cflush(\
"WARNING: Data gets hermitianized. This operation is "+\
"rather expensive\n")
"rather expensive.\n")
#temp = x.copy_empty()
#temp.set_full_data(gp.nhermitianize_fast(x.get_full_data(),
# (False, )*len(x.shape)))
......@@ -879,9 +879,12 @@ class rg_space(point_space):
sample = distributed_data_object(global_shape=self.get_shape(),
dtype=self.datatype)
## Should the output be hermitianized?
hermitianizeQ = about.hermitianize.status and \
self.paradict['complexity']
## Should the output be hermitianized? This does not depend on the
## hermitianize boolean in about, as it would yield in wrong,
## not recoverable results
#hermitianizeQ = about.hermitianize.status and self.paradict['complexity']
hermitianizeQ = self.paradict['complexity']
## Case 1: uniform distribution over {-1,+1}/{1,i,-1,-i}
if arg[0] == 'pm1' and hermitianizeQ == False:
......@@ -1368,7 +1371,7 @@ class rg_space(point_space):
x : numpy.ndarray
Array to be transformed.
codomain : nifty.rg_space, *optional*
Target space to which the transformation shall map
codomain space to which the transformation shall map
(default: None).
Returns
......@@ -1403,7 +1406,7 @@ class rg_space(point_space):
Tx = codomain.calc_weight(Tx, power=-1)
## when the target space is purely real, the result of the
## when the codomain space is purely real, the result of the
## transformation must be corrected accordingly. Using the casting
## method of codomain is sufficient
## TODO: Let .transform yield the correct datatype
......@@ -2385,6 +2388,9 @@ class utilities(object):
## make the point inversions
flipped_x = utilities._hermitianize_inverter(x)
flipped_x = flipped_x.conjugate()
## check if x was already hermitian
if (x == flipped_x).all():
return x
## average x and flipped_x.
## Correct the variance by multiplying sqrt(0.5)
x = (x + flipped_x) * np.sqrt(0.5)
......@@ -2405,6 +2411,8 @@ class utilities(object):
return x
@staticmethod
def _hermitianize_inverter(x):
## calculate the number of dimensions the input array has
......
# -*- coding: utf-8 -*-
import unittest
import numpy as np
from nifty.nifty_mpi_data import distributed_data_object
found = {}
try:
from mpi4py import MPI
found[MPI] = True
except(ImportError):
# from mpi4py_dummy import MPI
found[MPI] = False
class TestDistributedData(unittest.TestCase):
def test_full_data_wr(self):
temp_data = np.array(np.arange(1000), dtype=int).reshape((200,5))
obj = distributed_data_object(global_data = temp_data)
np.testing.assert_equal(temp_data, obj.get_full_data())
if __name__ == '__main__':
unittest.main()
comm = MPI.COMM_WORLD
rank = comm.rank
if True:
#if rank == 0:
x = np.arange(10100000).reshape((101,100,1000)).astype(np.complex128)
#print x
#x = np.arange(3)
else:
x = None
obj = distributed_data_object(global_data=x, distribution_strategy='fftw')
#obj.load('myalias', 'mpitest.hdf5')
if MPI.COMM_WORLD.rank==0:
print ('rank', rank, vars(obj.distributor))
MPI.COMM_WORLD.Barrier()
#print ('rank', rank, vars(obj))
MPI.COMM_WORLD.Barrier()
temp_erg =obj.get_full_data(target_rank='all')
print ('rank', rank, 'full data', np.all(temp_erg == x), temp_erg.shape)
"""
MPI.COMM_WORLD.Barrier()
if rank == 0:
print ('erwuenscht', x[slice(1,10,2)])
sl = slice(1,2+rank,1)
print ('slice', rank, sl, obj[sl,2])
print obj[1:5:2,1:3]
if rank == 0:
sl = (slice(1,9,2), slice(1,5,2))
d = [[111, 222],[333,444],[111, 222],[333,444]]
else:
sl = (slice(6,10,2), slice(1,5,2))
d = [[555, 666],[777,888]]
obj[sl] = d
print obj.get_full_data()
"""