Commit f2541899 authored by Ultima's avatar Ultima

Improved d2o indexing and field casting.

parent e8fe6581
......@@ -34,8 +34,7 @@ from nifty_mpi_data import distributed_data_object
from nifty_power import *
from nifty_random import random
from nifty_simple_math import *
from nifty_tools import conjugate_gradient,\
steepest_descent
from nifty_paradict import space_paradict,\
point_space_paradict,\
nested_space_paradict
......
......@@ -63,8 +63,8 @@ D = propagator_operator(S=S, N=N, R=R) # define inform
m = D(j, W=S, tol=1E-3, note=True) # reconstruct map
s.plot(title="signal") # plot signal
s.plot(title="signal", save = 'plot_s.png') # plot signal
d_ = field(x_space, val=d.val, target=k_space)
d_.plot(title="data", vmin=s.min(), vmax=s.max()) # plot data
m.plot(title="reconstructed map", vmin=s.min(), vmax=s.max()) # plot map
d_.plot(title="data", vmin=s.min(), vmax=s.max(), save = 'plot_d.png') # plot data
m.plot(title="reconstructed map", vmin=s.min(), vmax=s.max(), save = 'plot_m.png') # plot map
......@@ -1372,6 +1372,13 @@ class gl_space(point_space):
else:
return gl.weight(x,self.vol,p=np.float64(power),nlat=self.para[0],nlon=self.para[1],overwrite=False)
+    def get_weight(self, power = 1):
+        ## TODO: Check if this function is compatible with the rest of the nifty code
+        ## TODO: Can this be done more efficiently?
+        dummy = self.enforce_values(1)
+        weighted_dummy = self.calc_weight(dummy, power = power)
+        return weighted_dummy/dummy
+    ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
def calc_transform(self,x,codomain=None,**kwargs):
......
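A note on the dummy trick in gl_space.get_weight above: on a Gauss-Legendre grid the quadrature weights vary from pixel to pixel, so there is no single scalar vol**power to return; weighting a constant array and dividing the dummy back out recovers the per-pixel weights. A minimal, self-contained sketch of the pattern (the helper name and the toy calc_weight are illustrative, not part of this commit):

import numpy as np

def weights_from_calc_weight(calc_weight, shape, power=1):
    ## recover the per-pixel weights w**power from a black-box weighting routine
    dummy = np.ones(shape)                      ## constant dummy field
    weighted = calc_weight(dummy, power=power)  ## applies w**power pixelwise
    return weighted/dummy                       ## isolates w**power itself

## toy stand-in for gl_space.calc_weight with made-up weights:
toy_weights = np.array([0.25, 0.5, 0.25])
calc_weight = lambda x, power=1: x*toy_weights**power
print(weights_from_calc_weight(calc_weight, (3,)))  ## [ 0.25  0.5   0.25]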
......@@ -682,6 +682,9 @@ class space(object):
"""
raise NotImplementedError(about._errors.cstring("ERROR: no generic instance method 'calc_weight'."))
+    def get_weight(self, power=1):
+        raise NotImplementedError(about._errors.cstring("ERROR: no generic instance method 'get_weight'."))
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
def calc_dot(self,x,y):
......@@ -1608,8 +1611,11 @@ class point_space(space):
"""
x = self.enforce_shape(np.array(x,dtype=self.datatype))
## weight
-        return x*self.vol**power
+        return x*self.get_weight(power = power)
+        #return x*self.vol**power
+    def get_weight(self, power = 1):
+        return self.vol**power
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
def calc_dot(self, x, y):
"""
......@@ -2291,7 +2297,11 @@ class nested_space(space):
"""
x = self.enforce_shape(np.array(x,dtype=self.datatype))
## weight
-        return x*self.get_meta_volume(total=False)**power
+        return x*self.get_weight(power = power)
+    def get_weight(self, power = 1):
+        return self.get_meta_volume(total=False)**power
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
......@@ -2669,7 +2679,7 @@ class field(object):
if val == None:
if kwargs == {}:
-                self.val = self.domain.cast(0)
+                self.val = self.domain.cast(0.)
else:
self.val = self.domain.get_random_values(codomain=self.target,
**kwargs)
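The switch from cast(0) to cast(0.) above plausibly guards the default field value against integer typing: if a domain's cast infers its datatype from the input, an integer zero could produce an integer-valued field on which later divisions truncate. A hedged numpy illustration of the failure mode (numpy's inference standing in for NIFTY's actual cast):

import numpy as np
print(np.asarray(0).dtype)   ## int64 on most platforms -> truncating arithmetic
print(np.asarray(0.).dtype)  ## float64 -> safe default for field values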
......@@ -3349,8 +3359,9 @@ class field(object):
temp = self
else:
temp = self.copy_empty()
-            data_object = self.domain.apply_scalar_function(self.val,\
-                                                            function, inplace)
+            data_object = self.domain.apply_scalar_function(self.val,
+                                                            function,
+                                                            inplace)
temp.set_val(data_object)
return temp
......
......@@ -163,11 +163,11 @@ class distributed_data_object(object):
**kwargs)
return temp_d2o
-    def apply_scalar_function(self, function, inplace=False):
+    def apply_scalar_function(self, function, inplace=False, dtype=None):
if inplace == True:
temp = self
else:
-            temp = self.copy_empty()
+            temp = self.copy_empty(dtype=dtype)
try:
temp.data[:] = function(self.data)
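The new dtype keyword lets the result of apply_scalar_function be allocated with a different datatype than the input, which matters for scalar functions that change the type. A short usage sketch, assuming obj is a complex-valued distributed_data_object:

## np.abs maps complex -> real, so the result should not inherit obj's dtype:
moduli = obj.apply_scalar_function(np.abs, dtype=np.float64)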
......@@ -260,34 +260,54 @@ class distributed_data_object(object):
temp_d2o.set_local_data(data = self.get_local_data().__abs__())
return temp_d2o
-    def __builtin_helper__(self, operator, other):
+    def __builtin_helper__(self, operator, other, inplace=False):
## Case 1: other is not a scalar
if not (np.isscalar(other) or np.shape(other) == (1,)):
## if self.shape != other.shape:
## raise AttributeError(about._errors.cstring(
## "ERROR: Shapes do not match!"))
try:
hermitian_Q = other.hermitian
except(AttributeError):
hermitian_Q = False
## extract the local data from the 'other' object
temp_data = self.distributor.extract_local_data(other)
temp_data = operator(temp_data)
-        else:
+        ## Case 2: other is a real scalar -> preserve hermitianity
+        elif np.isreal(other) or (self.dtype not in (np.complex, np.complex128,
+                                                     np.complex256)):
hermitian_Q = self.hermitian
temp_data = operator(other)
## Case 3: other is complex
else:
hermitian_Q = False
temp_data = operator(other)
## write the new data into a new distributed_data_object
-        temp_d2o = self.copy_empty()
+        if inplace == True:
+            temp_d2o = self
+        else:
+            temp_d2o = self.copy_empty()
temp_d2o.set_local_data(data=temp_data)
temp_d2o.hermitian = hermitian_Q
return temp_d2o
"""
def __inplace_builtin_helper__(self, operator, other):
## Case 1: other is not a scalar
if not (np.isscalar(other) or np.shape(other) == (1,)):
temp_data = self.distributor.extract_local_data(other)
temp_data = operator(temp_data)
else:
## Case 2: other is a real scalar -> preserve hermitianity
elif np.isreal(other):
hermitian_Q = self.hermitian
temp_data = operator(other)
## Case 3: other is complex
else:
temp_data = operator(other)
self.set_local_data(data=temp_data)
self.hermitian = hermitian_Q
return self
"""
def __add__(self, other):
return self.__builtin_helper__(self.get_local_data().__add__, other)
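The rewritten __builtin_helper__ (the string above preserves the superseded in-place variant for reference) tracks Hermitian symmetry through three operand cases: an array operand contributes its own hermitian flag if it has one, a real scalar cannot break the symmetry of a Hermitian array, and a complex scalar always clears the flag. A distilled, runnable paraphrase of that case analysis (the function name is hypothetical; the committed code likewise treats any non-complex self.dtype as symmetry-preserving):

import numpy as np

def hermitian_after(other_is_array, other, self_hermitian, self_dtype,
                    other_hermitian=False):
    ## mirrors the Case 1/2/3 analysis of __builtin_helper__
    if other_is_array:                                            ## Case 1
        return other_hermitian
    elif np.isreal(other) or self_dtype not in (np.complex128,):  ## Case 2
        return self_hermitian
    else:                                                         ## Case 3
        return False

print(hermitian_after(False, 3.0, True, np.complex128))  ## True
print(hermitian_after(False, 1j, True, np.complex128))   ## False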
......@@ -296,8 +316,9 @@ class distributed_data_object(object):
return self.__builtin_helper__(self.get_local_data().__radd__, other)
def __iadd__(self, other):
-        return self.__inplace_builtin_helper__(self.get_local_data().__iadd__,
-                                               other)
+        return self.__builtin_helper__(self.get_local_data().__iadd__,
+                                       other,
+                                       inplace = True)
def __sub__(self, other):
return self.__builtin_helper__(self.get_local_data().__sub__, other)
......@@ -306,8 +327,9 @@ class distributed_data_object(object):
return self.__builtin_helper__(self.get_local_data().__rsub__, other)
def __isub__(self, other):
-        return self.__inplace_builtin_helper__(self.get_local_data().__isub__,
-                                               other)
+        return self.__builtin_helper__(self.get_local_data().__isub__,
+                                       other,
+                                       inplace = True)
def __div__(self, other):
return self.__builtin_helper__(self.get_local_data().__div__, other)
......@@ -316,8 +338,9 @@ class distributed_data_object(object):
return self.__builtin_helper__(self.get_local_data().__rdiv__, other)
def __idiv__(self, other):
-        return self.__inplace_builtin_helper__(self.get_local_data().__idiv__,
-                                               other)
+        return self.__builtin_helper__(self.get_local_data().__idiv__,
+                                       other,
+                                       inplace = True)
def __floordiv__(self, other):
return self.__builtin_helper__(self.get_local_data().__floordiv__,
......@@ -326,8 +349,9 @@ class distributed_data_object(object):
return self.__builtin_helper__(self.get_local_data().__rfloordiv__,
other)
def __ifloordiv__(self, other):
-        return self.__inplace_builtin_helper__(
-                            self.get_local_data().__ifloordiv__, other)
+        return self.__builtin_helper__(
+                            self.get_local_data().__ifloordiv__, other,
+                            inplace = True)
def __mul__(self, other):
return self.__builtin_helper__(self.get_local_data().__mul__, other)
......@@ -336,8 +360,9 @@ class distributed_data_object(object):
return self.__builtin_helper__(self.get_local_data().__rmul__, other)
def __imul__(self, other):
-        return self.__inplace_builtin_helper__(self.get_local_data().__imul__,
-                                               other)
+        return self.__builtin_helper__(self.get_local_data().__imul__,
+                                       other,
+                                       inplace = True)
def __pow__(self, other):
return self.__builtin_helper__(self.get_local_data().__pow__, other)
......@@ -346,8 +371,9 @@ class distributed_data_object(object):
return self.__builtin_helper__(self.get_local_data().__rpow__, other)
def __ipow__(self, other):
-        return self.__inplace_builtin_helper__(self.get_local_data().__ipow__,
-                                               other)
+        return self.__builtin_helper__(self.get_local_data().__ipow__,
+                                       other,
+                                       inplace = True)
def __len__(self):
return self.shape[0]
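With inplace=True every augmented-assignment method above now funnels through the same helper instead of the removed __inplace_builtin_helper__, so a += b mutates a's local data and refreshes its hermitian flag along the same code path as a + b. A toy reduction of the pattern (not NIFTY code):

import numpy as np

class Toy(object):
    ## minimal stand-in reproducing the shared-helper pattern
    def __init__(self, data):
        self.data = np.array(data)
    def _helper(self, operator, other, inplace=False):
        result = self if inplace else Toy(self.data)
        result.data = operator(other)
        return result
    def __add__(self, other):
        return self._helper(self.data.__add__, other)
    def __iadd__(self, other):
        return self._helper(self.data.__iadd__, other, inplace=True)

a = Toy([1., 2.])
b = a + 2.0    ## fresh object; a is untouched
a += 2.0       ## same helper, but a itself is mutated
print(a.data)  ## [ 3.  4.]
print(b.data)  ## [ 3.  4.]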
......@@ -392,24 +418,30 @@ class distributed_data_object(object):
def __setitem__(self, key, data):
self.set_data(data, key)
-    def _minmaxhelper(self, function, **kwargs):
+    def _contraction_helper(self, function, **kwargs):
local = function(self.data, **kwargs)
local_list = self.distributor._allgather(local)
global_ = function(local_list, axis=0)
return global_
def amin(self, **kwargs):
-        return self._minmaxhelper(np.amin, **kwargs)
+        return self._contraction_helper(np.amin, **kwargs)
def nanmin(self, **kwargs):
-        return self._minmaxhelper(np.nanmin, **kwargs)
+        return self._contraction_helper(np.nanmin, **kwargs)
def amax(self, **kwargs):
-        return self._minmaxhelper(np.amax, **kwargs)
+        return self._contraction_helper(np.amax, **kwargs)
def nanmax(self, **kwargs):
-        return self._minmaxhelper(np.nanmax, **kwargs)
+        return self._contraction_helper(np.nanmax, **kwargs)
+    def sum(self, **kwargs):
+        return self._contraction_helper(np.sum, **kwargs)
+    def prod(self, **kwargs):
+        return self._contraction_helper(np.prod, **kwargs)
def mean(self, power=1):
## compute the local means and the weights for the mean-mean.
local_mean = np.mean(self.data**power)
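The renamed _contraction_helper generalizes the old min/max helper to any reduction that may be applied twice: each task reduces its local data, the per-task results are gathered, and the same function reduces the gathered list. This is exact for amin/amax/nanmin/nanmax and for the new sum/prod, but not for a plain mean of per-task means, which is presumably why mean keeps its own weighted implementation. A serial sketch of the two-stage pattern, with a list of arrays standing in for the per-task local data:

import numpy as np

def contraction_helper(local_parts, function):
    ## stage 1: each "task" reduces its own chunk
    local_results = [function(part) for part in local_parts]
    ## stage 2: reduce the gathered per-task results (cf. _allgather)
    return function(local_results, axis=0)

parts = [np.array([1., 2.]), np.array([3., 4.])]
print(contraction_helper(parts, np.sum))   ## 10.0, equals np.sum over all entries
print(contraction_helper(parts, np.prod))  ## 24.0
print(contraction_helper(parts, np.amax))  ## 4.0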
......@@ -731,8 +763,13 @@ class distributed_data_object(object):
for i in sliceified:
if i == True:
temp_shape += (1,)
+                if data.shape[j] == 1:
+                    j += 1
            else:
-                temp_shape += (data.shape[j],)
+                try:
+                    temp_shape += (data.shape[j],)
+                except(IndexError):
+                    temp_shape += (1,)
j += 1
        ## take into account that the sliceified tuple may be too short, because
        ## of a non-exhaustive list of slices
......
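The reworked loop above assembles the output shape of a sliced access: every "sliceified" entry (a scalar index) contributes a broadcastable length-1 axis, and the new try/except tolerates key tuples that are shorter than the data's dimensionality. A schematic, standalone rendering of that logic (with an added bounds guard on the length-1 skip so the sketch runs on its own):

def result_shape(sliceified, data_shape):
    temp_shape = ()
    j = 0
    for is_scalar_index in sliceified:
        if is_scalar_index:
            temp_shape += (1,)  ## scalar index -> length-1 axis
            if j < len(data_shape) and data_shape[j] == 1:
                j += 1          ## consume a matching length-1 axis of the data
        else:
            try:
                temp_shape += (data_shape[j],)
            except IndexError:  ## key shorter than data.shape
                temp_shape += (1,)
            j += 1
    return temp_shape

print(result_shape((True, False), (1, 5)))  ## (1, 5)
print(result_shape((False, True), (4,)))    ## (4, 1)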
## NIFTY (Numerical Information Field Theory) has been developed at the
## Max-Planck-Institute for Astrophysics.
##
## Copyright (C) 2013 Max-Planck-Society
##
## Author: Marco Selig
## Project homepage: <http://www.mpa-garching.mpg.de/ift/nifty/>
##
## This program is free software: you can redistribute it and/or modify
## it under the terms of the GNU General Public License as published by
## the Free Software Foundation, either version 3 of the License, or
## (at your option) any later version.
##
## This program is distributed in the hope that it will be useful,
## but WITHOUT ANY WARRANTY; without even the implied warranty of
## MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
## See the GNU General Public License for more details.
##
## You should have received a copy of the GNU General Public License
## along with this program. If not, see <http://www.gnu.org/licenses/>.
"""
.. __ ____ __
.. /__/ / _/ / /_
.. __ ___ __ / /_ / _/ __ __
.. / _ | / / / _/ / / / / / /
.. / / / / / / / / / /_ / /_/ /
.. /__/ /__/ /__/ /__/ \___/ \___ / tools
.. /______/
This module extends NIFTY with a nifty set of tools including further
operators, namely the :py:class:`invertible_operator` and the
:py:class:`propagator_operator`, and minimization schemes, namely
:py:class:`steepest_descent` and :py:class:`conjugate_gradient`. These
tools are meant to support the user in solving information field
theoretical problems (almost) without numerical pain.
"""
from __future__ import division
#from nifty_core import *
import numpy as np
from nifty_about import notification, about
from nifty_core import field
#from nifty_core import space, \
# field
#from operators import operator, \
# diagonal_operator
##=============================================================================
class conjugate_gradient(object):
"""
.. _______ ____ __
.. / _____/ / _ /
.. / /____ __ / /_/ / __
.. \______//__/ \____ //__/ class
.. /______/
NIFTY tool class for conjugate gradient
This tool solves the linear system :math:`A x = b` for `x`, given `A`
and `b`, using the conjugate gradient method, i.e., a step-by-step
minimization relying on conjugate gradient directions. Further, `A` is
assumed to be a positive definite and self-adjoint operator. The use
of a preconditioner `W` that is roughly the inverse of `A` is
optional. For details on the methodology refer to [#]_; for details on
usage and output, see the notes below.
Parameters
----------
A : {operator, function}
Operator `A` applicable to a field.
b : field
Resulting field of the operation `A(x)`.
W : {operator, function}, *optional*
Operator `W` that is a preconditioner on `A` and is applicable to a
field (default: None).
spam : function, *optional*
Callback function which is given the current `x` and iteration
counter each iteration (default: None).
reset : integer, *optional*
Number of iterations after which to restart; i.e., forget previous
conjugate directions (default: sqrt(b.dim())).
note : bool, *optional*
Indicates whether notes are printed or not (default: False).
See Also
--------
scipy.sparse.linalg.cg
Notes
-----
After initialization by `__init__`, the minimizer is started by calling
it using `__call__`, which takes additional parameters. Notifications,
if enabled, will state the iteration number, current step widths
`alpha` and `beta`, the current relative residual `delta` that is
compared to the tolerance, and the convergence level if changed.
The minimizer exits in one of three states: DEAD if `alpha` becomes
infinite, QUIT if the maximum number of iterations is reached, or DONE
if convergence is achieved. It returns the latest `x` and the latest
convergence level, which can evaluate to ``True`` for the exit states
QUIT and DONE.
References
----------
.. [#] J. R. Shewchuk, 1994, `"An Introduction to the Conjugate
Gradient Method Without the Agonizing Pain"
<http://www.cs.cmu.edu/~quake-papers/painless-conjugate-gradient.pdf>`_
Examples
--------
>>> b = field(point_space(2), val=[1, 9])
>>> A = diagonal_operator(b.domain, diag=[4, 3])
>>> x,convergence = conjugate_gradient(A, b, note=True)(tol=1E-4, clevel=3)
iteration : 00000001 alpha = 3.3E-01 beta = 1.3E-03 delta = 3.6E-02
iteration : 00000002 alpha = 2.5E-01 beta = 7.6E-04 delta = 1.0E-03
iteration : 00000003 alpha = 3.3E-01 beta = 2.5E-04 delta = 1.6E-05 convergence level : 1
iteration : 00000004 alpha = 2.5E-01 beta = 1.8E-06 delta = 2.1E-08 convergence level : 2
iteration : 00000005 alpha = 2.5E-01 beta = 2.2E-03 delta = 1.0E-09 convergence level : 3
... done.
>>> bool(convergence)
True
>>> x.val # yields 1/4 and 9/3
array([ 0.25, 3. ])
Attributes
----------
A : {operator, function}
Operator `A` applicable to a field.
x : field
Current field.
b : field
Resulting field of the operation `A(x)`.
W : {operator, function}
Operator `W` that is a preconditioner on `A` and is applicable to a
field; can be ``None``.
spam : function
Callback function which is given the current `x` and iteration
counter each iteration; can be ``None``.
reset : integer
Number of iterations after which to restart; i.e., forget previous
conjugate directions (default: sqrt(b.dim())).
note : notification
Notification instance.
"""
def __init__(self,A,b,W=None,spam=None,reset=None,note=False):
"""
Initializes the conjugate_gradient and sets the attributes (except
for `x`).
Parameters
----------
A : {operator, function}
Operator `A` applicable to a field.
b : field
Resulting field of the operation `A(x)`.
W : {operator, function}, *optional*
Operator `W` that is a preconditioner on `A` and is applicable to a
field (default: None).
spam : function, *optional*
Callback function which is given the current `x` and iteration
counter each iteration (default: None).
reset : integer, *optional*
Number of iterations after which to restart; i.e., forget previous
conjugate directions (default: sqrt(b.dim())).
note : bool, *optional*
Indicates whether notes are printed or not (default: False).
"""
if(hasattr(A,"__call__")):
self.A = A ## applies A
else:
raise AttributeError(about._errors.cstring("ERROR: invalid input."))
self.b = b
if(W is None)or(hasattr(W,"__call__")):
self.W = W ## applies W ~ A_inverse
else:
raise AttributeError(about._errors.cstring("ERROR: invalid input."))
self.spam = spam ## serves as callback given x and iteration number
if(reset is None): ## 2 < reset ~ sqrt(dim)
self.reset = max(2,int(np.sqrt(b.domain.dim(split=False))))
else:
self.reset = max(2,int(reset))
self.note = notification(default=bool(note))
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
def __call__(self,x0=None,**kwargs): ## > runs cg with/without preconditioner
"""
Runs the conjugate gradient minimization.
Parameters
----------
x0 : field, *optional*
Starting guess for the minimization.
tol : scalar, *optional*
Tolerance specifying convergence; measured by current relative
residual (default: 1E-4).
clevel : integer, *optional*
Number of times the tolerance should be undershot before
exiting (default: 1).
limii : integer, *optional*
Maximum number of iterations performed (default: 10 * b.dim()).
Returns
-------
x : field
Latest `x` of the minimization.
convergence : integer
Latest convergence level indicating whether the minimization
has converged or not.
"""
self.x = field(self.b.domain,val=x0,target=self.b.target)
if(self.W is None):
return self._calc_without(**kwargs)
else:
return self._calc_with(**kwargs)
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
def _calc_without(self,tol=1E-4,clevel=1,limii=None): ## > runs cg without preconditioner
clevel = int(clevel)
if(limii is None):
limii = 10*self.b.domain.dim(split=False)
else:
limii = int(limii)
r = self.b-self.A(self.x)
d = field(self.b.domain,val=np.copy(r.val),target=self.b.target)
gamma = r.dot(d)
if(gamma==0):
return self.x,clevel+1
delta_ = np.absolute(gamma)**(-0.5)
convergence = 0
ii = 1
while(True):
q = self.A(d)
alpha = gamma/d.dot(q) ## positive definite
if(not np.isfinite(alpha)):
self.note.cprint("\niteration : %08u alpha = NAN\n... dead."%ii)
return self.x,0
self.x += alpha*d
if(np.signbit(np.real(alpha))):
about.warnings.cprint("WARNING: positive definiteness of A violated.")
r = self.b-self.A(self.x)
elif(ii%self.reset==0):
r = self.b-self.A(self.x)
else:
r -= alpha*q
gamma_ = gamma
gamma = r.dot(r)
beta = max(0,gamma/gamma_) ## positive definite
d = r+beta*d
delta = delta_*np.absolute(gamma)**0.5
self.note.cflush("\niteration : %08u alpha = %3.1E beta = %3.1E delta = %3.1E"%(ii,np.real(alpha),np.real(beta),np.real(delta)))
if(gamma==0):
convergence = clevel+1
self.note.cprint(" convergence level : INF\n... done.")
break
elif(np.absolute(delta)<tol):
convergence += 1
self.note.cflush(" convergence level : %u"%convergence)
if(convergence==clevel):
self.note.cprint("\n... done.")
break
else:
convergence = max(0,convergence-1)
if(ii==limii):
self.note.cprint("\n... quit.")
break
if(self.spam is not None):
self.spam(self.x,ii)
ii += 1
if(self.spam is not None):
self.spam(self.x,ii)
return self.x,convergence
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
def _calc_with(self,tol=1E-4,clevel=1,limii=None): ## > runs cg with preconditioner
clevel = int(clevel)
if(limii is None):
limii = 10*self.b.domain.dim(split=False)
else:
limii = int(limii)
r = self.b-self.A(self.x)
d = self.W(r)
gamma = r.dot(d)
if(gamma==0):
return self.x,clevel+1
delta_ = np.absolute(gamma)**(-0.5)
convergence = 0
ii = 1
while(True):
q = self.A(d)
alpha = gamma/d.dot(q) ## positive definite
if(not np.isfinite(alpha)):
self.note.cprint("\niteration : %08u alpha = NAN\n... dead."%ii)
return self.x,0
self.x += alpha*d ## update
if(np.signbit(np.real(alpha))):
about.warnings.cprint("WARNING: positive definiteness of A violated.")
r = self.b-self.A(self.x)
elif(ii%self.reset==0):
r = self.b-self.A(self.x)
else:
r -= alpha*q
s = self.W(r)
gamma_ = gamma
gamma = r.dot(s)
if(np.signbit(np.real(gamma))):
about.warnings.cprint("WARNING: positive definiteness of W violated.")
beta = max(0,gamma/gamma_) ## positive definite