Commit 80dfc14d authored by Martin Reinecke's avatar Martin Reinecke

more safe merges from dobj_experiments

parent 1b7e9fb4
Pipeline #21564 passed with stage
in 4 minutes and 21 seconds
...@@ -17,7 +17,6 @@ ...@@ -17,7 +17,6 @@
# and financially supported by the Studienstiftung des deutschen Volkes. # and financially supported by the Studienstiftung des deutschen Volkes.
from __future__ import division from __future__ import division
import numpy as np
from .field import Field from .field import Field
from . import dobj from . import dobj
...@@ -29,7 +28,7 @@ def _math_helper(x, function, out): ...@@ -29,7 +28,7 @@ def _math_helper(x, function, out):
if not isinstance(x, Field): if not isinstance(x, Field):
raise TypeError("This function only accepts Field objects.") raise TypeError("This function only accepts Field objects.")
if out is not None: if out is not None:
if not isinstance(out, Field) or x.domain!=out.domain: if not isinstance(out, Field) or x.domain != out.domain:
raise ValueError("Bad 'out' argument") raise ValueError("Bad 'out' argument")
function(x.val, out=out.val) function(x.val, out=out.val)
return out return out
......
...@@ -19,6 +19,7 @@ ...@@ -19,6 +19,7 @@
from functools import reduce from functools import reduce
from .domain_object import DomainObject from .domain_object import DomainObject
class DomainTuple(object): class DomainTuple(object):
_tupleCache = {} _tupleCache = {}
...@@ -122,7 +123,6 @@ class DomainTuple(object): ...@@ -122,7 +123,6 @@ class DomainTuple(object):
of pixels in the requested space in res[1], and the remaining pixels in of pixels in the requested space in res[1], and the remaining pixels in
res[2]. res[2].
""" """
dims = (dom.dim for dom in self._dom)
return (self._accdims[ispace], return (self._accdims[ispace],
self._accdims[ispace+1]//self._accdims[ispace], self._accdims[ispace+1]//self._accdims[ispace],
self._accdims[-1]//self._accdims[ispace+1]) self._accdims[-1]//self._accdims[ispace+1])
...@@ -16,8 +16,6 @@ ...@@ -16,8 +16,6 @@
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik # NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes. # and financially supported by the Studienstiftung des deutschen Volkes.
from __future__ import print_function
class LineEnergy(object): class LineEnergy(object):
""" Evaluates an underlying Energy along a certain line direction. """ Evaluates an underlying Energy along a certain line direction.
...@@ -114,6 +112,7 @@ class LineEnergy(object): ...@@ -114,6 +112,7 @@ class LineEnergy(object):
def directional_derivative(self): def directional_derivative(self):
res = self.energy.gradient.vdot(self.line_direction) res = self.energy.gradient.vdot(self.line_direction)
if abs(res.imag) / max(abs(res.real), 1.) > 1e-12: if abs(res.imag) / max(abs(res.real), 1.) > 1e-12:
print("directional derivative has non-negligible " from ..dobj import mprint
"imaginary part:", res) mprint("directional derivative has non-negligible "
"imaginary part:", res)
return res.real return res.real
...@@ -56,6 +56,7 @@ if not special_hartley: ...@@ -56,6 +56,7 @@ if not special_hartley:
_fill_upper_half(tmp, res, axes) _fill_upper_half(tmp, res, axes)
return res return res
def hartley(a, axes=None): def hartley(a, axes=None):
# Check if the axes provided are valid given the shape # Check if the axes provided are valid given the shape
if axes is not None and \ if axes is not None and \
...@@ -77,28 +78,28 @@ def hartley(a, axes=None): ...@@ -77,28 +78,28 @@ def hartley(a, axes=None):
if use_numba: if use_numba:
from numba import complex128 as ncplx, float64 as nflt, vectorize as nvct from numba import complex128 as ncplx, float64 as nflt, vectorize as nvct
@nvct([nflt(nflt,nflt,nflt), ncplx(nflt,ncplx,ncplx)], nopython=True, @nvct([nflt(nflt, nflt, nflt), ncplx(nflt, ncplx, ncplx)], nopython=True,
target="cpu") target="cpu")
def _general_axpy(a, x, y): def _general_axpy(a, x, y):
return a*x + y return a*x + y
def general_axpy(a, x, y, out): def general_axpy(a, x, y, out):
if x.domain != y.domain or x.domain != out.domain: if x.domain != y.domain or x.domain != out.domain:
raise ValueError ("Incompatible domains") raise ValueError("Incompatible domains")
return _general_axpy(a, x.val, y.val, out.val) return _general_axpy(a, x.val, y.val, out.val)
else: else:
def general_axpy(a,x,y,out): def general_axpy(a, x, y, out):
if x.domain != y.domain or x.domain != out.domain: if x.domain != y.domain or x.domain != out.domain:
raise ValueError ("Incompatible domains") raise ValueError("Incompatible domains")
if out is x: if out is x:
if a != 1.: if a != 1.:
out*=a out *= a
out += y out += y
elif out is y: elif out is y:
if a!=1.: if a != 1.:
out += a*x out += a*x
else: else:
out += x out += x
......
...@@ -18,10 +18,10 @@ ...@@ -18,10 +18,10 @@
from __future__ import division from __future__ import division
from .minimizer import Minimizer from .minimizer import Minimizer
import numpy as np
from .. import Field from .. import Field
from ..low_level_library import general_axpy from ..low_level_library import general_axpy
class ConjugateGradient(Minimizer): class ConjugateGradient(Minimizer):
""" Implementation of the Conjugate Gradient scheme. """ Implementation of the Conjugate Gradient scheme.
...@@ -82,7 +82,7 @@ class ConjugateGradient(Minimizer): ...@@ -82,7 +82,7 @@ class ConjugateGradient(Minimizer):
if previous_gamma == 0: if previous_gamma == 0:
return energy, controller.CONVERGED return energy, controller.CONVERGED
tpos = Field(d.domain,dtype=d.dtype) # temporary buffer tpos = Field(d.domain, dtype=d.dtype) # temporary buffer
while True: while True:
q = energy.curvature(d) q = energy.curvature(d)
ddotq = d.vdot(q).real ddotq = d.vdot(q).real
......
...@@ -16,8 +16,8 @@ ...@@ -16,8 +16,8 @@
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik # NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes. # and financially supported by the Studienstiftung des deutschen Volkes.
from __future__ import print_function
from .iteration_controller import IterationController from .iteration_controller import IterationController
from ... import dobj
class GradientNormController(IterationController): class GradientNormController(IterationController):
...@@ -64,7 +64,7 @@ class GradientNormController(IterationController): ...@@ -64,7 +64,7 @@ class GradientNormController(IterationController):
msg += " energy=" + str(energy.value) msg += " energy=" + str(energy.value)
msg += " gradnorm=" + str(energy.gradient_norm) msg += " gradnorm=" + str(energy.gradient_norm)
msg += " clvl=" + str(self._ccount) msg += " clvl=" + str(self._ccount)
print(msg) dobj.mprint(msg)
# self.logger.info(msg) # self.logger.info(msg)
# Are we done? # Are we done?
......
...@@ -16,13 +16,13 @@ ...@@ -16,13 +16,13 @@
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik # NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes. # and financially supported by the Studienstiftung des deutschen Volkes.
from __future__ import print_function
from __future__ import division from __future__ import division
from builtins import range from builtins import range
import numpy as np import numpy as np
from .line_search import LineSearch from .line_search import LineSearch
from ...energies import LineEnergy from ...energies import LineEnergy
from ... import dobj
class LineSearchStrongWolfe(LineSearch): class LineSearchStrongWolfe(LineSearch):
...@@ -164,7 +164,7 @@ class LineSearchStrongWolfe(LineSearch): ...@@ -164,7 +164,7 @@ class LineSearchStrongWolfe(LineSearch):
phi_alpha0 = phi_alpha1 phi_alpha0 = phi_alpha1
phiprime_alpha0 = phiprime_alpha1 phiprime_alpha0 = phiprime_alpha1
else: else:
print("max iterations reached") dobj.mprint("max iterations reached")
return le_alpha1.energy return le_alpha1.energy
return result_energy return result_energy
...@@ -261,7 +261,7 @@ class LineSearchStrongWolfe(LineSearch): ...@@ -261,7 +261,7 @@ class LineSearchStrongWolfe(LineSearch):
phiprime_alphaj) phiprime_alphaj)
else: else:
print("The line search algorithm (zoom) did not converge.") dobj.mprint("The line search algorithm (zoom) did not converge.")
return le_alphaj return le_alphaj
def _cubicmin(self, a, fa, fpa, b, fb, c, fc): def _cubicmin(self, a, fa, fpa, b, fb, c, fc):
......
...@@ -18,7 +18,6 @@ ...@@ -18,7 +18,6 @@
from builtins import range from builtins import range
from .linear_operator import LinearOperator from .linear_operator import LinearOperator
from .. import DomainTuple
class ComposedOperator(LinearOperator): class ComposedOperator(LinearOperator):
......
from builtins import range from builtins import range
from .. import Field,\ from .. import Field, FieldArray, DomainTuple
FieldArray,\
DomainTuple
from .linear_operator import LinearOperator from .linear_operator import LinearOperator
from .fft_smoothing_operator import FFTSmoothingOperator from .fft_smoothing_operator import FFTSmoothingOperator
from .composed_operator import ComposedOperator from .composed_operator import ComposedOperator
......
from ..spaces.power_space import PowerSpace
from .endomorphic_operator import EndomorphicOperator from .endomorphic_operator import EndomorphicOperator
from .laplace_operator import LaplaceOperator from .laplace_operator import LaplaceOperator
from .. import Field, DomainTuple from .. import Field
class SmoothnessOperator(EndomorphicOperator): class SmoothnessOperator(EndomorphicOperator):
......
from __future__ import division from __future__ import division
import numpy as np import numpy as np
from ..import Field, RGSpace, HPSpace, GLSpace, PowerSpace from ..import Field, RGSpace, HPSpace, GLSpace, PowerSpace, dobj
import os import os
# relevant properties: # relevant properties:
...@@ -45,6 +45,8 @@ def _find_closest(A, target): ...@@ -45,6 +45,8 @@ def _find_closest(A, target):
def _makeplot(name): def _makeplot(name):
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
if dobj.rank != 0:
return
if name is None: if name is None:
plt.show() plt.show()
return return
...@@ -173,7 +175,7 @@ def plot(f, **kwargs): ...@@ -173,7 +175,7 @@ def plot(f, **kwargs):
npoints = dom.shape[0] npoints = dom.shape[0]
dist = dom.distances[0] dist = dom.distances[0]
xcoord = np.arange(npoints, dtype=np.float64)*dist xcoord = np.arange(npoints, dtype=np.float64)*dist
ycoord = f.val ycoord = dobj.to_global_data(f.val)
plt.plot(xcoord, ycoord) plt.plot(xcoord, ycoord)
_limit_xy(**kwargs) _limit_xy(**kwargs)
_makeplot(kwargs.get("name")) _makeplot(kwargs.get("name"))
...@@ -185,7 +187,8 @@ def plot(f, **kwargs): ...@@ -185,7 +187,8 @@ def plot(f, **kwargs):
dy = dom.distances[1] dy = dom.distances[1]
xc = np.arange(nx, dtype=np.float64)*dx xc = np.arange(nx, dtype=np.float64)*dx
yc = np.arange(ny, dtype=np.float64)*dy yc = np.arange(ny, dtype=np.float64)*dy
im = ax.imshow(f.val, extent=[xc[0], xc[-1], yc[0], yc[-1]], im = ax.imshow(dobj.to_global_data(f.val),
extent=[xc[0], xc[-1], yc[0], yc[-1]],
vmin=kwargs.get("zmin"), vmin=kwargs.get("zmin"),
vmax=kwargs.get("zmax"), cmap=cmap, origin="lower") vmax=kwargs.get("zmax"), cmap=cmap, origin="lower")
# from mpl_toolkits.axes_grid1 import make_axes_locatable # from mpl_toolkits.axes_grid1 import make_axes_locatable
...@@ -198,7 +201,7 @@ def plot(f, **kwargs): ...@@ -198,7 +201,7 @@ def plot(f, **kwargs):
return return
elif isinstance(dom, PowerSpace): elif isinstance(dom, PowerSpace):
xcoord = dom.k_lengths xcoord = dom.k_lengths
ycoord = f.val ycoord = dobj.to_global_data(f.val)
plt.xscale('log') plt.xscale('log')
plt.yscale('log') plt.yscale('log')
plt.title('power') plt.title('power')
...@@ -215,7 +218,7 @@ def plot(f, **kwargs): ...@@ -215,7 +218,7 @@ def plot(f, **kwargs):
ptg[:, 0] = theta ptg[:, 0] = theta
ptg[:, 1] = phi ptg[:, 1] = phi
base = pyHealpix.Healpix_Base(int(np.sqrt(f.val.size//12)), "RING") base = pyHealpix.Healpix_Base(int(np.sqrt(f.val.size//12)), "RING")
res[mask] = f.val[base.ang2pix(ptg)] res[mask] = dobj.to_global_data(f.val)[base.ang2pix(ptg)]
plt.axis('off') plt.axis('off')
plt.imshow(res, vmin=kwargs.get("zmin"), vmax=kwargs.get("zmax"), plt.imshow(res, vmin=kwargs.get("zmin"), vmax=kwargs.get("zmax"),
cmap=cmap, origin="lower") cmap=cmap, origin="lower")
...@@ -231,7 +234,7 @@ def plot(f, **kwargs): ...@@ -231,7 +234,7 @@ def plot(f, **kwargs):
ilat = _find_closest(dec, theta) ilat = _find_closest(dec, theta)
ilon = _find_closest(ra, phi) ilon = _find_closest(ra, phi)
ilon = np.where(ilon == dom.nlon, 0, ilon) ilon = np.where(ilon == dom.nlon, 0, ilon)
res[mask] = f.val[ilat*dom.nlon + ilon] res[mask] = dobj.to_global_data(f.val)[ilat*dom.nlon + ilon]
plt.axis('off') plt.axis('off')
plt.imshow(res, vmin=kwargs.get("zmin"), vmax=kwargs.get("zmax"), plt.imshow(res, vmin=kwargs.get("zmin"), vmax=kwargs.get("zmax"),
......
...@@ -21,6 +21,7 @@ import numpy as np ...@@ -21,6 +21,7 @@ import numpy as np
from .space import Space from .space import Space
from .. import Field from .. import Field
from ..basic_arithmetics import exp from ..basic_arithmetics import exp
from .. import dobj
class LMSpace(Space): class LMSpace(Space):
...@@ -102,7 +103,7 @@ class LMSpace(Space): ...@@ -102,7 +103,7 @@ class LMSpace(Space):
for m in range(1, mmax+1): for m in range(1, mmax+1):
ldist[idx:idx+2*(lmax+1-m)] = tmp[2*m:] ldist[idx:idx+2*(lmax+1-m)] = tmp[2*m:]
idx += 2*(lmax+1-m) idx += 2*(lmax+1-m)
return Field((self,), ldist) return Field((self,), dobj.from_global_data(ldist))
def get_unique_k_lengths(self): def get_unique_k_lengths(self):
return np.arange(self.lmax+1, dtype=np.float64) return np.arange(self.lmax+1, dtype=np.float64)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment