From 80dfc14d607db8a152dca4a3c82800025c41e7f8 Mon Sep 17 00:00:00 2001 From: Martin Reinecke Date: Tue, 14 Nov 2017 13:54:41 +0100 Subject: [PATCH] more safe merges from dobj_experiments --- nifty/basic_arithmetics.py | 3 +-- nifty/domain_tuple.py | 2 +- nifty/energies/line_energy.py | 7 +++---- nifty/low_level_library.py | 13 +++++++------ nifty/minimization/conjugate_gradient.py | 4 ++-- .../gradient_norm_controller.py | 4 ++-- .../line_searching/line_search_strong_wolfe.py | 6 +++--- nifty/operators/composed_operator.py | 1 - nifty/operators/response_operator.py | 4 +--- nifty/operators/smoothness_operator.py | 3 +-- nifty/plotting/plot.py | 15 +++++++++------ nifty/spaces/lm_space.py | 3 ++- 12 files changed, 32 insertions(+), 33 deletions(-) diff --git a/nifty/basic_arithmetics.py b/nifty/basic_arithmetics.py index 39a5a576..2ab6158b 100644 --- a/nifty/basic_arithmetics.py +++ b/nifty/basic_arithmetics.py @@ -17,7 +17,6 @@ # and financially supported by the Studienstiftung des deutschen Volkes. from __future__ import division -import numpy as np from .field import Field from . import dobj @@ -29,7 +28,7 @@ def _math_helper(x, function, out): if not isinstance(x, Field): raise TypeError("This function only accepts Field objects.") if out is not None: - if not isinstance(out, Field) or x.domain!=out.domain: + if not isinstance(out, Field) or x.domain != out.domain: raise ValueError("Bad 'out' argument") function(x.val, out=out.val) return out diff --git a/nifty/domain_tuple.py b/nifty/domain_tuple.py index 194b9b0d..87e5abb5 100644 --- a/nifty/domain_tuple.py +++ b/nifty/domain_tuple.py @@ -19,6 +19,7 @@ from functools import reduce from .domain_object import DomainObject + class DomainTuple(object): _tupleCache = {} @@ -122,7 +123,6 @@ class DomainTuple(object): of pixels in the requested space in res[1], and the remaining pixels in res[2]. """ - dims = (dom.dim for dom in self._dom) return (self._accdims[ispace], self._accdims[ispace+1]//self._accdims[ispace], self._accdims[-1]//self._accdims[ispace+1]) diff --git a/nifty/energies/line_energy.py b/nifty/energies/line_energy.py index 7215072b..7454ea8d 100644 --- a/nifty/energies/line_energy.py +++ b/nifty/energies/line_energy.py @@ -16,8 +16,6 @@ # NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik # and financially supported by the Studienstiftung des deutschen Volkes. -from __future__ import print_function - class LineEnergy(object): """ Evaluates an underlying Energy along a certain line direction. @@ -114,6 +112,7 @@ class LineEnergy(object): def directional_derivative(self): res = self.energy.gradient.vdot(self.line_direction) if abs(res.imag) / max(abs(res.real), 1.) > 1e-12: - print("directional derivative has non-negligible " - "imaginary part:", res) + from ..dobj import mprint + mprint("directional derivative has non-negligible " + "imaginary part:", res) return res.real diff --git a/nifty/low_level_library.py b/nifty/low_level_library.py index 82d5f0a9..d71c2841 100644 --- a/nifty/low_level_library.py +++ b/nifty/low_level_library.py @@ -56,6 +56,7 @@ if not special_hartley: _fill_upper_half(tmp, res, axes) return res + def hartley(a, axes=None): # Check if the axes provided are valid given the shape if axes is not None and \ @@ -77,28 +78,28 @@ def hartley(a, axes=None): if use_numba: from numba import complex128 as ncplx, float64 as nflt, vectorize as nvct - @nvct([nflt(nflt,nflt,nflt), ncplx(nflt,ncplx,ncplx)], nopython=True, + @nvct([nflt(nflt, nflt, nflt), ncplx(nflt, ncplx, ncplx)], nopython=True, target="cpu") def _general_axpy(a, x, y): return a*x + y def general_axpy(a, x, y, out): if x.domain != y.domain or x.domain != out.domain: - raise ValueError ("Incompatible domains") + raise ValueError("Incompatible domains") return _general_axpy(a, x.val, y.val, out.val) else: - def general_axpy(a,x,y,out): + def general_axpy(a, x, y, out): if x.domain != y.domain or x.domain != out.domain: - raise ValueError ("Incompatible domains") + raise ValueError("Incompatible domains") if out is x: if a != 1.: - out*=a + out *= a out += y elif out is y: - if a!=1.: + if a != 1.: out += a*x else: out += x diff --git a/nifty/minimization/conjugate_gradient.py b/nifty/minimization/conjugate_gradient.py index ed458d94..e3b71168 100644 --- a/nifty/minimization/conjugate_gradient.py +++ b/nifty/minimization/conjugate_gradient.py @@ -18,10 +18,10 @@ from __future__ import division from .minimizer import Minimizer -import numpy as np from .. import Field from ..low_level_library import general_axpy + class ConjugateGradient(Minimizer): """ Implementation of the Conjugate Gradient scheme. @@ -82,7 +82,7 @@ class ConjugateGradient(Minimizer): if previous_gamma == 0: return energy, controller.CONVERGED - tpos = Field(d.domain,dtype=d.dtype) # temporary buffer + tpos = Field(d.domain, dtype=d.dtype) # temporary buffer while True: q = energy.curvature(d) ddotq = d.vdot(q).real diff --git a/nifty/minimization/iteration_controlling/gradient_norm_controller.py b/nifty/minimization/iteration_controlling/gradient_norm_controller.py index fa46dde4..077cc72f 100644 --- a/nifty/minimization/iteration_controlling/gradient_norm_controller.py +++ b/nifty/minimization/iteration_controlling/gradient_norm_controller.py @@ -16,8 +16,8 @@ # NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik # and financially supported by the Studienstiftung des deutschen Volkes. -from __future__ import print_function from .iteration_controller import IterationController +from ... import dobj class GradientNormController(IterationController): @@ -64,7 +64,7 @@ class GradientNormController(IterationController): msg += " energy=" + str(energy.value) msg += " gradnorm=" + str(energy.gradient_norm) msg += " clvl=" + str(self._ccount) - print(msg) + dobj.mprint(msg) # self.logger.info(msg) # Are we done? diff --git a/nifty/minimization/line_searching/line_search_strong_wolfe.py b/nifty/minimization/line_searching/line_search_strong_wolfe.py index 9bb79420..3658cad4 100644 --- a/nifty/minimization/line_searching/line_search_strong_wolfe.py +++ b/nifty/minimization/line_searching/line_search_strong_wolfe.py @@ -16,13 +16,13 @@ # NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik # and financially supported by the Studienstiftung des deutschen Volkes. -from __future__ import print_function from __future__ import division from builtins import range import numpy as np from .line_search import LineSearch from ...energies import LineEnergy +from ... import dobj class LineSearchStrongWolfe(LineSearch): @@ -164,7 +164,7 @@ class LineSearchStrongWolfe(LineSearch): phi_alpha0 = phi_alpha1 phiprime_alpha0 = phiprime_alpha1 else: - print("max iterations reached") + dobj.mprint("max iterations reached") return le_alpha1.energy return result_energy @@ -261,7 +261,7 @@ class LineSearchStrongWolfe(LineSearch): phiprime_alphaj) else: - print("The line search algorithm (zoom) did not converge.") + dobj.mprint("The line search algorithm (zoom) did not converge.") return le_alphaj def _cubicmin(self, a, fa, fpa, b, fb, c, fc): diff --git a/nifty/operators/composed_operator.py b/nifty/operators/composed_operator.py index 8acff7a8..143adcbe 100644 --- a/nifty/operators/composed_operator.py +++ b/nifty/operators/composed_operator.py @@ -18,7 +18,6 @@ from builtins import range from .linear_operator import LinearOperator -from .. import DomainTuple class ComposedOperator(LinearOperator): diff --git a/nifty/operators/response_operator.py b/nifty/operators/response_operator.py index 7cd30b1b..5dd3e1f7 100644 --- a/nifty/operators/response_operator.py +++ b/nifty/operators/response_operator.py @@ -1,7 +1,5 @@ from builtins import range -from .. import Field,\ - FieldArray,\ - DomainTuple +from .. import Field, FieldArray, DomainTuple from .linear_operator import LinearOperator from .fft_smoothing_operator import FFTSmoothingOperator from .composed_operator import ComposedOperator diff --git a/nifty/operators/smoothness_operator.py b/nifty/operators/smoothness_operator.py index ca1882dd..30a54135 100644 --- a/nifty/operators/smoothness_operator.py +++ b/nifty/operators/smoothness_operator.py @@ -1,7 +1,6 @@ -from ..spaces.power_space import PowerSpace from .endomorphic_operator import EndomorphicOperator from .laplace_operator import LaplaceOperator -from .. import Field, DomainTuple +from .. import Field class SmoothnessOperator(EndomorphicOperator): diff --git a/nifty/plotting/plot.py b/nifty/plotting/plot.py index f5599ec9..ba618339 100644 --- a/nifty/plotting/plot.py +++ b/nifty/plotting/plot.py @@ -1,6 +1,6 @@ from __future__ import division import numpy as np -from ..import Field, RGSpace, HPSpace, GLSpace, PowerSpace +from ..import Field, RGSpace, HPSpace, GLSpace, PowerSpace, dobj import os # relevant properties: @@ -45,6 +45,8 @@ def _find_closest(A, target): def _makeplot(name): import matplotlib.pyplot as plt + if dobj.rank != 0: + return if name is None: plt.show() return @@ -173,7 +175,7 @@ def plot(f, **kwargs): npoints = dom.shape[0] dist = dom.distances[0] xcoord = np.arange(npoints, dtype=np.float64)*dist - ycoord = f.val + ycoord = dobj.to_global_data(f.val) plt.plot(xcoord, ycoord) _limit_xy(**kwargs) _makeplot(kwargs.get("name")) @@ -185,7 +187,8 @@ def plot(f, **kwargs): dy = dom.distances[1] xc = np.arange(nx, dtype=np.float64)*dx yc = np.arange(ny, dtype=np.float64)*dy - im = ax.imshow(f.val, extent=[xc[0], xc[-1], yc[0], yc[-1]], + im = ax.imshow(dobj.to_global_data(f.val), + extent=[xc[0], xc[-1], yc[0], yc[-1]], vmin=kwargs.get("zmin"), vmax=kwargs.get("zmax"), cmap=cmap, origin="lower") # from mpl_toolkits.axes_grid1 import make_axes_locatable @@ -198,7 +201,7 @@ def plot(f, **kwargs): return elif isinstance(dom, PowerSpace): xcoord = dom.k_lengths - ycoord = f.val + ycoord = dobj.to_global_data(f.val) plt.xscale('log') plt.yscale('log') plt.title('power') @@ -215,7 +218,7 @@ def plot(f, **kwargs): ptg[:, 0] = theta ptg[:, 1] = phi base = pyHealpix.Healpix_Base(int(np.sqrt(f.val.size//12)), "RING") - res[mask] = f.val[base.ang2pix(ptg)] + res[mask] = dobj.to_global_data(f.val)[base.ang2pix(ptg)] plt.axis('off') plt.imshow(res, vmin=kwargs.get("zmin"), vmax=kwargs.get("zmax"), cmap=cmap, origin="lower") @@ -231,7 +234,7 @@ def plot(f, **kwargs): ilat = _find_closest(dec, theta) ilon = _find_closest(ra, phi) ilon = np.where(ilon == dom.nlon, 0, ilon) - res[mask] = f.val[ilat*dom.nlon + ilon] + res[mask] = dobj.to_global_data(f.val)[ilat*dom.nlon + ilon] plt.axis('off') plt.imshow(res, vmin=kwargs.get("zmin"), vmax=kwargs.get("zmax"), diff --git a/nifty/spaces/lm_space.py b/nifty/spaces/lm_space.py index 31993cd5..95f88dbf 100644 --- a/nifty/spaces/lm_space.py +++ b/nifty/spaces/lm_space.py @@ -21,6 +21,7 @@ import numpy as np from .space import Space from .. import Field from ..basic_arithmetics import exp +from .. import dobj class LMSpace(Space): @@ -102,7 +103,7 @@ class LMSpace(Space): for m in range(1, mmax+1): ldist[idx:idx+2*(lmax+1-m)] = tmp[2*m:] idx += 2*(lmax+1-m) - return Field((self,), ldist) + return Field((self,), dobj.from_global_data(ldist)) def get_unique_k_lengths(self): return np.arange(self.lmax+1, dtype=np.float64) -- GitLab