Commit 80dfc14d authored by Martin Reinecke's avatar Martin Reinecke

more safe merges from dobj_experiments

parent 1b7e9fb4
Pipeline #21564 passed with stage
in 4 minutes and 21 seconds
......@@ -17,7 +17,6 @@
# and financially supported by the Studienstiftung des deutschen Volkes.
from __future__ import division
import numpy as np
from .field import Field
from . import dobj
......@@ -29,7 +28,7 @@ def _math_helper(x, function, out):
if not isinstance(x, Field):
raise TypeError("This function only accepts Field objects.")
if out is not None:
if not isinstance(out, Field) or x.domain!=out.domain:
if not isinstance(out, Field) or x.domain != out.domain:
raise ValueError("Bad 'out' argument")
function(x.val, out=out.val)
return out
......
......@@ -19,6 +19,7 @@
from functools import reduce
from .domain_object import DomainObject
class DomainTuple(object):
_tupleCache = {}
......@@ -122,7 +123,6 @@ class DomainTuple(object):
of pixels in the requested space in res[1], and the remaining pixels in
res[2].
"""
dims = (dom.dim for dom in self._dom)
return (self._accdims[ispace],
self._accdims[ispace+1]//self._accdims[ispace],
self._accdims[-1]//self._accdims[ispace+1])
......@@ -16,8 +16,6 @@
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes.
from __future__ import print_function
class LineEnergy(object):
""" Evaluates an underlying Energy along a certain line direction.
......@@ -114,6 +112,7 @@ class LineEnergy(object):
def directional_derivative(self):
res = self.energy.gradient.vdot(self.line_direction)
if abs(res.imag) / max(abs(res.real), 1.) > 1e-12:
print("directional derivative has non-negligible "
"imaginary part:", res)
from ..dobj import mprint
mprint("directional derivative has non-negligible "
"imaginary part:", res)
return res.real
......@@ -56,6 +56,7 @@ if not special_hartley:
_fill_upper_half(tmp, res, axes)
return res
def hartley(a, axes=None):
# Check if the axes provided are valid given the shape
if axes is not None and \
......@@ -77,28 +78,28 @@ def hartley(a, axes=None):
if use_numba:
from numba import complex128 as ncplx, float64 as nflt, vectorize as nvct
@nvct([nflt(nflt,nflt,nflt), ncplx(nflt,ncplx,ncplx)], nopython=True,
@nvct([nflt(nflt, nflt, nflt), ncplx(nflt, ncplx, ncplx)], nopython=True,
target="cpu")
def _general_axpy(a, x, y):
return a*x + y
def general_axpy(a, x, y, out):
if x.domain != y.domain or x.domain != out.domain:
raise ValueError ("Incompatible domains")
raise ValueError("Incompatible domains")
return _general_axpy(a, x.val, y.val, out.val)
else:
def general_axpy(a,x,y,out):
def general_axpy(a, x, y, out):
if x.domain != y.domain or x.domain != out.domain:
raise ValueError ("Incompatible domains")
raise ValueError("Incompatible domains")
if out is x:
if a != 1.:
out*=a
out *= a
out += y
elif out is y:
if a!=1.:
if a != 1.:
out += a*x
else:
out += x
......
......@@ -18,10 +18,10 @@
from __future__ import division
from .minimizer import Minimizer
import numpy as np
from .. import Field
from ..low_level_library import general_axpy
class ConjugateGradient(Minimizer):
""" Implementation of the Conjugate Gradient scheme.
......@@ -82,7 +82,7 @@ class ConjugateGradient(Minimizer):
if previous_gamma == 0:
return energy, controller.CONVERGED
tpos = Field(d.domain,dtype=d.dtype) # temporary buffer
tpos = Field(d.domain, dtype=d.dtype) # temporary buffer
while True:
q = energy.curvature(d)
ddotq = d.vdot(q).real
......
......@@ -16,8 +16,8 @@
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes.
from __future__ import print_function
from .iteration_controller import IterationController
from ... import dobj
class GradientNormController(IterationController):
......@@ -64,7 +64,7 @@ class GradientNormController(IterationController):
msg += " energy=" + str(energy.value)
msg += " gradnorm=" + str(energy.gradient_norm)
msg += " clvl=" + str(self._ccount)
print(msg)
dobj.mprint(msg)
# self.logger.info(msg)
# Are we done?
......
......@@ -16,13 +16,13 @@
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes.
from __future__ import print_function
from __future__ import division
from builtins import range
import numpy as np
from .line_search import LineSearch
from ...energies import LineEnergy
from ... import dobj
class LineSearchStrongWolfe(LineSearch):
......@@ -164,7 +164,7 @@ class LineSearchStrongWolfe(LineSearch):
phi_alpha0 = phi_alpha1
phiprime_alpha0 = phiprime_alpha1
else:
print("max iterations reached")
dobj.mprint("max iterations reached")
return le_alpha1.energy
return result_energy
......@@ -261,7 +261,7 @@ class LineSearchStrongWolfe(LineSearch):
phiprime_alphaj)
else:
print("The line search algorithm (zoom) did not converge.")
dobj.mprint("The line search algorithm (zoom) did not converge.")
return le_alphaj
def _cubicmin(self, a, fa, fpa, b, fb, c, fc):
......
......@@ -18,7 +18,6 @@
from builtins import range
from .linear_operator import LinearOperator
from .. import DomainTuple
class ComposedOperator(LinearOperator):
......
from builtins import range
from .. import Field,\
FieldArray,\
DomainTuple
from .. import Field, FieldArray, DomainTuple
from .linear_operator import LinearOperator
from .fft_smoothing_operator import FFTSmoothingOperator
from .composed_operator import ComposedOperator
......
from ..spaces.power_space import PowerSpace
from .endomorphic_operator import EndomorphicOperator
from .laplace_operator import LaplaceOperator
from .. import Field, DomainTuple
from .. import Field
class SmoothnessOperator(EndomorphicOperator):
......
from __future__ import division
import numpy as np
from ..import Field, RGSpace, HPSpace, GLSpace, PowerSpace
from ..import Field, RGSpace, HPSpace, GLSpace, PowerSpace, dobj
import os
# relevant properties:
......@@ -45,6 +45,8 @@ def _find_closest(A, target):
def _makeplot(name):
import matplotlib.pyplot as plt
if dobj.rank != 0:
return
if name is None:
plt.show()
return
......@@ -173,7 +175,7 @@ def plot(f, **kwargs):
npoints = dom.shape[0]
dist = dom.distances[0]
xcoord = np.arange(npoints, dtype=np.float64)*dist
ycoord = f.val
ycoord = dobj.to_global_data(f.val)
plt.plot(xcoord, ycoord)
_limit_xy(**kwargs)
_makeplot(kwargs.get("name"))
......@@ -185,7 +187,8 @@ def plot(f, **kwargs):
dy = dom.distances[1]
xc = np.arange(nx, dtype=np.float64)*dx
yc = np.arange(ny, dtype=np.float64)*dy
im = ax.imshow(f.val, extent=[xc[0], xc[-1], yc[0], yc[-1]],
im = ax.imshow(dobj.to_global_data(f.val),
extent=[xc[0], xc[-1], yc[0], yc[-1]],
vmin=kwargs.get("zmin"),
vmax=kwargs.get("zmax"), cmap=cmap, origin="lower")
# from mpl_toolkits.axes_grid1 import make_axes_locatable
......@@ -198,7 +201,7 @@ def plot(f, **kwargs):
return
elif isinstance(dom, PowerSpace):
xcoord = dom.k_lengths
ycoord = f.val
ycoord = dobj.to_global_data(f.val)
plt.xscale('log')
plt.yscale('log')
plt.title('power')
......@@ -215,7 +218,7 @@ def plot(f, **kwargs):
ptg[:, 0] = theta
ptg[:, 1] = phi
base = pyHealpix.Healpix_Base(int(np.sqrt(f.val.size//12)), "RING")
res[mask] = f.val[base.ang2pix(ptg)]
res[mask] = dobj.to_global_data(f.val)[base.ang2pix(ptg)]
plt.axis('off')
plt.imshow(res, vmin=kwargs.get("zmin"), vmax=kwargs.get("zmax"),
cmap=cmap, origin="lower")
......@@ -231,7 +234,7 @@ def plot(f, **kwargs):
ilat = _find_closest(dec, theta)
ilon = _find_closest(ra, phi)
ilon = np.where(ilon == dom.nlon, 0, ilon)
res[mask] = f.val[ilat*dom.nlon + ilon]
res[mask] = dobj.to_global_data(f.val)[ilat*dom.nlon + ilon]
plt.axis('off')
plt.imshow(res, vmin=kwargs.get("zmin"), vmax=kwargs.get("zmax"),
......
......@@ -21,6 +21,7 @@ import numpy as np
from .space import Space
from .. import Field
from ..basic_arithmetics import exp
from .. import dobj
class LMSpace(Space):
......@@ -102,7 +103,7 @@ class LMSpace(Space):
for m in range(1, mmax+1):
ldist[idx:idx+2*(lmax+1-m)] = tmp[2*m:]
idx += 2*(lmax+1-m)
return Field((self,), ldist)
return Field((self,), dobj.from_global_data(ldist))
def get_unique_k_lengths(self):
return np.arange(self.lmax+1, dtype=np.float64)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment