Commit 7f240569 authored by Ultima's avatar Ultima
Browse files

Created a bunch of tests for distributed_data_object.

-> Many minor and major bugfixes
-> Reworked indexing (now works with negative step sizes, too)

Updated the naming of domain, codomain, target and cotarget in the spaces, fields and operators.
-> The propagator_operator is now 3 times faster
parent e94b3f41
...@@ -22,7 +22,7 @@ ...@@ -22,7 +22,7 @@
from __future__ import division from __future__ import division
import matplotlib as mpl import matplotlib as mpl
mpl.use('Agg') #mpl.use('Agg')
from nifty_about import about from nifty_about import about
from nifty_cmaps import ncmap from nifty_cmaps import ncmap
......
...@@ -83,7 +83,8 @@ class problem(object): ...@@ -83,7 +83,8 @@ class problem(object):
self.d = self.R(self.s) + n self.d = self.R(self.s) + n
## define information source ## define information source
self.j = self.R.adjoint_times(self.N.inverse_times(self.d), target=self.k) #self.j = self.R.adjoint_times(self.N.inverse_times(self.d), target=self.k)
self.j = self.R.adjoint_times(self.N.inverse_times(self.d))
## define information propagator ## define information propagator
self.D = propagator_operator(S=self.S, N=self.N, R=self.R) self.D = propagator_operator(S=self.S, N=self.N, R=self.R)
......
...@@ -42,6 +42,7 @@ x_space = rg_space([1280, 1280], datamodel = 'd2o') ...@@ -42,6 +42,7 @@ x_space = rg_space([1280, 1280], datamodel = 'd2o')
#x_space = gl_space(96) #x_space = gl_space(96)
k_space = x_space.get_codomain() # get conjugate space k_space = x_space.get_codomain() # get conjugate space
y_space = point_space(1280*1280, datamodel='d2o')
# some power spectrum # some power spectrum
power = (lambda k: 42 / (k + 1) ** 3) power = (lambda k: 42 / (k + 1) ** 3)
...@@ -49,7 +50,7 @@ power = (lambda k: 42 / (k + 1) ** 3) ...@@ -49,7 +50,7 @@ power = (lambda k: 42 / (k + 1) ** 3)
S = power_operator(k_space, spec=power) # define signal covariance S = power_operator(k_space, spec=power) # define signal covariance
s = S.get_random_field(domain=x_space) # generate signal s = S.get_random_field(domain=x_space) # generate signal
R = response_operator(x_space, sigma=0.0, mask=1.0, assign=None) # define response R = response_operator(x_space, sigma=0.0, mask=1.0, assign=None, target = y_space) # define response
d_space = R.target # get data space d_space = R.target # get data space
# some noise variance; e.g., signal-to-noise ratio of 1 # some noise variance; e.g., signal-to-noise ratio of 1
...@@ -62,10 +63,10 @@ d = R(s) + n # compute data ...@@ -62,10 +63,10 @@ d = R(s) + n # compute data
j = R.adjoint_times(N.inverse_times(d)) # define information source j = R.adjoint_times(N.inverse_times(d)) # define information source
D = propagator_operator(S=S, N=N, R=R) # define information propagator D = propagator_operator(S=S, N=N, R=R) # define information propagator
m = D(j, W=S, tol=1E-1, note=True) # reconstruct map m = D(j, W=S, tol=1E-2, note=True) # reconstruct map
#s.plot(title="signal", save = 'plot_s.png') # plot signal s.plot(title="signal", save = 'plot_s.png') # plot signal
#d_ = field(x_space, val=d.val, target=k_space) d_ = field(x_space, val=d.val, target=k_space)
#d_.plot(title="data", vmin=s.min(), vmax=s.max(), save = 'plot_d.png') # plot data d_.plot(title="data", vmin=s.min(), vmax=s.max(), save = 'plot_d.png') # plot data
#m.plot(title="reconstructed map", vmin=s.min(), vmax=s.max(), save = 'plot_m.png') # plot map m.plot(title="reconstructed map", vmin=s.min(), vmax=s.max(), save = 'plot_m.png') # plot map
...@@ -128,7 +128,7 @@ class lm_space(point_space): ...@@ -128,7 +128,7 @@ class lm_space(point_space):
vol : numpy.ndarray vol : numpy.ndarray
Pixel volume of the :py:class:`lm_space`, which is always 1. Pixel volume of the :py:class:`lm_space`, which is always 1.
""" """
def __init__(self, lmax, mmax=None, datatype=None): def __init__(self, lmax, mmax=None, datatype=None, datamodel = 'np'):
""" """
Sets the attributes for an lm_space class instance. Sets the attributes for an lm_space class instance.
...@@ -519,6 +519,9 @@ class lm_space(point_space): ...@@ -519,6 +519,9 @@ class lm_space(point_space):
Compatible codomains are instances of :py:class:`lm_space`, Compatible codomains are instances of :py:class:`lm_space`,
:py:class:`gl_space`, and :py:class:`hp_space`. :py:class:`gl_space`, and :py:class:`hp_space`.
""" """
if codomain is None:
return False
if(not isinstance(codomain,space)): if(not isinstance(codomain,space)):
raise TypeError(about._errors.cstring("ERROR: invalid input.")) raise TypeError(about._errors.cstring("ERROR: invalid input."))
...@@ -669,7 +672,7 @@ class lm_space(point_space): ...@@ -669,7 +672,7 @@ class lm_space(point_space):
x : numpy.ndarray x : numpy.ndarray
Array to be transformed. Array to be transformed.
codomain : nifty.space, *optional* codomain : nifty.space, *optional*
Target space to which the transformation shall map codomain space to which the transformation shall map
(default: self). (default: self).
Returns Returns
...@@ -1290,6 +1293,9 @@ class gl_space(point_space): ...@@ -1290,6 +1293,9 @@ class gl_space(point_space):
Compatible codomains are instances of :py:class:`gl_space` and Compatible codomains are instances of :py:class:`gl_space` and
:py:class:`lm_space`. :py:class:`lm_space`.
""" """
if codomain is None:
return False
if(not isinstance(codomain,space)): if(not isinstance(codomain,space)):
raise TypeError(about._errors.cstring("ERROR: invalid input.")) raise TypeError(about._errors.cstring("ERROR: invalid input."))
...@@ -1398,7 +1404,7 @@ class gl_space(point_space): ...@@ -1398,7 +1404,7 @@ class gl_space(point_space):
x : numpy.ndarray x : numpy.ndarray
Array to be transformed. Array to be transformed.
codomain : nifty.space, *optional* codomain : nifty.space, *optional*
Target space to which the transformation shall map codomain space to which the transformation shall map
(default: self). (default: self).
Returns Returns
...@@ -1953,6 +1959,9 @@ class hp_space(point_space): ...@@ -1953,6 +1959,9 @@ class hp_space(point_space):
Compatible codomains are instances of :py:class:`hp_space` and Compatible codomains are instances of :py:class:`hp_space` and
:py:class:`lm_space`. :py:class:`lm_space`.
""" """
if codomain is None:
return False
if(not isinstance(codomain,space)): if(not isinstance(codomain,space)):
raise TypeError(about._errors.cstring("ERROR: invalid input.")) raise TypeError(about._errors.cstring("ERROR: invalid input."))
...@@ -2025,7 +2034,7 @@ class hp_space(point_space): ...@@ -2025,7 +2034,7 @@ class hp_space(point_space):
x : numpy.ndarray x : numpy.ndarray
Array to be transformed. Array to be transformed.
codomain : nifty.space, *optional* codomain : nifty.space, *optional*
Target space to which the transformation shall map codomain space to which the transformation shall map
(default: self). (default: self).
Returns Returns
......
...@@ -271,7 +271,8 @@ class notification(switch): ...@@ -271,7 +271,8 @@ class notification(switch):
String augmented with a color code. String augmented with a color code.
""" """
return self.ccode+str(self._get_caller())+':\n'+str(subject)+self._code return self.ccode + str(self._get_caller()) + ':\n' + \
str(subject) + self._code + '\n'
def cflush(self,subject): def cflush(self,subject):
""" """
......
...@@ -729,7 +729,7 @@ class space(object): ...@@ -729,7 +729,7 @@ class space(object):
x : numpy.ndarray x : numpy.ndarray
Array to be transformed. Array to be transformed.
codomain : nifty.space, *optional* codomain : nifty.space, *optional*
Target space to which the transformation shall map codomain space to which the transformation shall map
(default: self). (default: self).
Returns Returns
...@@ -1795,7 +1795,7 @@ class point_space(space): ...@@ -1795,7 +1795,7 @@ class point_space(space):
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
def check_codomain(self,codomain): def check_codomain(self, codomain):
""" """
Checks whether a given codomain is compatible to the space or not. Checks whether a given codomain is compatible to the space or not.
...@@ -1809,6 +1809,9 @@ class point_space(space): ...@@ -1809,6 +1809,9 @@ class point_space(space):
check : bool check : bool
Whether or not the given codomain is compatible to the space. Whether or not the given codomain is compatible to the space.
""" """
if codomain is None:
return False
if not isinstance(codomain, space): if not isinstance(codomain, space):
raise TypeError(about._errors.cstring( raise TypeError(about._errors.cstring(
"ERROR: invalid input. The given input is no nifty space.")) "ERROR: invalid input. The given input is no nifty space."))
...@@ -1965,7 +1968,7 @@ class point_space(space): ...@@ -1965,7 +1968,7 @@ class point_space(space):
x : numpy.ndarray x : numpy.ndarray
Array to be transformed. Array to be transformed.
codomain : nifty.space, *optional* codomain : nifty.space, *optional*
Target space to which the transformation shall map codomain space to which the transformation shall map
(default: self). (default: self).
Returns Returns
...@@ -2008,8 +2011,10 @@ class point_space(space): ...@@ -2008,8 +2011,10 @@ class point_space(space):
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
def calc_real_Q(self, x): def calc_real_Q(self, x):
try: try:
return x.is_completely_real() return x.isreal().all()
except(AttributeError): except(AttributeError):
return np.all(np.isreal(x)) return np.all(np.isreal(x))
...@@ -2641,7 +2646,7 @@ class nested_space(space): ...@@ -2641,7 +2646,7 @@ class nested_space(space):
Other parameters Other parameters
---------------- ----------------
target : nifty.space, *optional* codomain : nifty.space, *optional*
Space in which the transform of the output field lives Space in which the transform of the output field lives
(default: None). (default: None).
...@@ -2702,7 +2707,7 @@ class nested_space(space): ...@@ -2702,7 +2707,7 @@ class nested_space(space):
x : numpy.ndarray x : numpy.ndarray
Array to be transformed. Array to be transformed.
codomain : nifty.space, *optional* codomain : nifty.space, *optional*
Target space to which the transformation shall map codomain space to which the transformation shall map
(default: self). (default: self).
Returns Returns
...@@ -2870,7 +2875,7 @@ class field(object): ...@@ -2870,7 +2875,7 @@ class field(object):
space defined in domain or to be drawn from a random distribution space defined in domain or to be drawn from a random distribution
controlled by kwargs. controlled by kwargs.
target : space, *optional* codomain : space, *optional*
The space wherein the operator output lives (default: domain). The space wherein the operator output lives (default: domain).
...@@ -2929,11 +2934,11 @@ class field(object): ...@@ -2929,11 +2934,11 @@ class field(object):
space defined in domain or to be drawn from a random distribution space defined in domain or to be drawn from a random distribution
controlled by the keyword arguments. controlled by the keyword arguments.
target : space, *optional* codomain : space, *optional*
The space wherein the operator output lives (default: domain). The space wherein the operator output lives (default: domain).
""" """
def __init__(self, domain, val=None, target=None, **kwargs): def __init__(self, domain, val=None, codomain=None, **kwargs):
""" """
Sets the attributes for a field class instance. Sets the attributes for a field class instance.
...@@ -2948,7 +2953,7 @@ class field(object): ...@@ -2948,7 +2953,7 @@ class field(object):
space defined in domain or to be drawn from a random distribution space defined in domain or to be drawn from a random distribution
controlled by the keyword arguments. controlled by the keyword arguments.
target : space, *optional* codomain : space, *optional*
The space wherein the operator output lives (default: domain). The space wherein the operator output lives (default: domain).
Returns Returns
...@@ -2961,17 +2966,17 @@ class field(object): ...@@ -2961,17 +2966,17 @@ class field(object):
raise TypeError(about._errors.cstring("ERROR: invalid input.")) raise TypeError(about._errors.cstring("ERROR: invalid input."))
self.domain = domain self.domain = domain
## check codomain ## check codomain
if target is None: if codomain is None:
target = domain.get_codomain() codomain = domain.get_codomain()
else: else:
assert(self.domain.check_codomain(target)) assert(self.domain.check_codomain(codomain))
self.target = target self.codomain = codomain
if val == None: if val == None:
if kwargs == {}: if kwargs == {}:
self.val = self.domain.cast(0.) self.val = self.domain.cast(0.)
else: else:
self.val = self.domain.get_random_values(codomain=self.target, self.val = self.domain.get_random_values(codomain=self.codomain,
**kwargs) **kwargs)
else: else:
self.val = val self.val = val
...@@ -2986,18 +2991,18 @@ class field(object): ...@@ -2986,18 +2991,18 @@ class field(object):
self.__val = self.domain.cast(x) self.__val = self.domain.cast(x)
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
def copy(self, domain=None, target=None): def copy(self, domain=None, codomain=None):
new_field = self.copy_empty(domain=domain, target=target) new_field = self.copy_empty(domain=domain, codomain=codomain)
new_field.val = new_field.domain.cast(self.val.copy()) new_field.val = new_field.domain.cast(self.val.copy())
return new_field return new_field
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
def copy_empty(self, domain=None, target=None, **kwargs): def copy_empty(self, domain=None, codomain=None, **kwargs):
if domain == None: if domain == None:
domain = self.domain domain = self.domain
if target == None: if codomain == None:
target = self.target codomain = self.codomain
new_field = field(domain=domain, target=target, **kwargs) new_field = field(domain=domain, codomain=codomain, **kwargs)
return new_field return new_field
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
...@@ -3023,7 +3028,7 @@ class field(object): ...@@ -3023,7 +3028,7 @@ class field(object):
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
def cast_domain(self, newdomain, new_target=None, force=True): def cast_domain(self, newdomain, new_codomain=None, force=True):
""" """
Casts the domain of the field. Casts the domain of the field.
...@@ -3032,9 +3037,9 @@ class field(object): ...@@ -3032,9 +3037,9 @@ class field(object):
newdomain : space newdomain : space
New space wherein the field should live. New space wherein the field should live.
new_target : space, *optional* new_codomain : space, *optional*
Space wherein the transform of the field should live. Space wherein the transform of the field should live.
When not given, target will automatically be the codomain When not given, codomain will automatically be the codomain
of the newly casted domain (default=None). of the newly casted domain (default=None).
force : bool, *optional* force : bool, *optional*
...@@ -3075,20 +3080,20 @@ class field(object): ...@@ -3075,20 +3080,20 @@ class field(object):
## Use the casting of the new domain in order to make the old data fit. ## Use the casting of the new domain in order to make the old data fit.
self.set_val(new_val = self.val) self.set_val(new_val = self.val)
## set the target ## set the codomain
if new_target == None: if new_codomain == None:
if not self.domain.check_codomain(self.target): if not self.domain.check_codomain(self.codomain):
if(force): if(force):
about.infos.cprint("INFO: codomain set to default.") about.infos.cprint("INFO: codomain set to default.")
else: else:
about.warnings.cprint("WARNING: codomain set to default.") about.warnings.cprint("WARNING: codomain set to default.")
self.set_target(new_target = self.domain.get_codomain()) self.set_codomain(new_codomain = self.domain.get_codomain())
else: else:
self.set_target(new_target = new_target, force = force) self.set_codomain(new_codomain = new_codomain, force = force)
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
def set_val(self, new_val): def set_val(self, new_val = None):
""" """
Resets the field values. Resets the field values.
...@@ -3098,7 +3103,8 @@ class field(object): ...@@ -3098,7 +3103,8 @@ class field(object):
New field values either as a constant or an arbitrary array. New field values either as a constant or an arbitrary array.
""" """
self.val = new_val if new_val is not None:
self.val = new_val
return self.val return self.val
def get_val(self): def get_val(self):
...@@ -3107,31 +3113,31 @@ class field(object): ...@@ -3107,31 +3113,31 @@ class field(object):
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
def set_domain(self, new_domain=None, force=False): def set_domain(self, new_domain=None, force=False):
if new_domain is None: if new_domain is None:
new_domain = self.target.get_codomain() new_domain = self.codomain.get_codomain()
elif force == False: elif force == False:
assert(self.target.check_codomain(new_domain)) assert(self.codomain.check_codomain(new_domain))
self.domain = new_domain self.domain = new_domain
return self.domain return self.domain
def set_target(self, new_target=None, force=False): def set_codomain(self, new_codomain=None, force=False):
""" """
Resets the codomain of the field. Resets the codomain of the field.
Parameters Parameters
---------- ----------
new_target : space new_codomain : space
The new space wherein the transform of the field should live. The new space wherein the transform of the field should live.
(default=None). (default=None).
""" """
## check codomain ## check codomain
if new_target is None: if new_codomain is None:
new_target = self.domain.get_codomain() new_codomain = self.domain.get_codomain()
elif force == False: elif force == False:
assert(self.domain.check_codomain(new_target)) assert(self.domain.check_codomain(new_codomain))
self.target = new_target self.codomain = new_codomain
return self.target return self.codomain
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
...@@ -3256,7 +3262,7 @@ class field(object): ...@@ -3256,7 +3262,7 @@ class field(object):
Other Parameters Other Parameters
---------------- ----------------
target : space, *optional* codomain : space, *optional*
space wherein the transform of the output field should live space wherein the transform of the output field should live
(default: None). (default: None).
...@@ -3296,7 +3302,7 @@ class field(object): ...@@ -3296,7 +3302,7 @@ class field(object):
return self.pseudo_dot(x=x.val,**kwargs) return self.pseudo_dot(x=x.val,**kwargs)
except(TypeError,ValueError): except(TypeError,ValueError):
try: try:
return self.pseudo_dot(x=x.transform(target=x.target,overwrite=False).val,**kwargs) return self.pseudo_dot(x=x.transform(codomain=x.codomain,overwrite=False).val,**kwargs)
except(TypeError,ValueError): except(TypeError,ValueError):
raise ValueError(about._errors.cstring("ERROR: incompatible domains.")) raise ValueError(about._errors.cstring("ERROR: incompatible domains."))
## pseudo inner product (calc_pseudo_dot handles weights) ## pseudo inner product (calc_pseudo_dot handles weights)
...@@ -3331,7 +3337,7 @@ class field(object): ...@@ -3331,7 +3337,7 @@ class field(object):
Other Parameters Other Parameters
---------------- ----------------
target : space, *optional* codomain : space, *optional*
space wherein the transform of the output field should live space wherein the transform of the output field should live
(default: None). (default: None).
...@@ -3370,15 +3376,15 @@ class field(object): ...@@ -3370,15 +3376,15 @@ class field(object):
return work_field return work_field
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
def transform(self, target=None, overwrite=False, **kwargs): def transform(self, codomain=None, overwrite=False, **kwargs):
""" """
Computes the transform of the field using the appropriate conjugate Computes the transform of the field using the appropriate conjugate
transformation. transformation.
Parameters Parameters
---------- ----------
target : space, *optional* codomain : space, *optional*
Domain of the transform of the field (default:self.target) Domain of the transform of the field (default:self.codomain)
overwrite : bool, *optional* overwrite : bool, *optional*
Whether to overwrite the field or not (default: False). Whether to overwrite the field or not (default: False).
...@@ -3395,22 +3401,22 @@ class field(object): ...@@ -3395,22 +3401,22 @@ class field(object):
Otherwise, nothing is returned. Otherwise, nothing is returned.
""" """
if(target is None): if(codomain is None):
target = self.target codomain = self.codomain
else: else:
assert(self.domain.check_codomain(target)) assert(self.domain.check_codomain(codomain))
new_val = self.domain.calc_transform(self.val, new_val = self.domain.calc_transform(self.val,
codomain=target, codomain=codomain,
**kwargs) **kwargs)