Commit 99ebd010 authored by Martin Reinecke's avatar Martin Reinecke

more simplifications

parent ecd9ca14
Pipeline #25092 passed with stages
in 6 minutes and 39 seconds
......@@ -160,7 +160,7 @@ class PowerSpace(StructuredDomain):
tbb = 0.5*(tmp[:-1]+tmp[1:])
else:
tbb = binbounds
locdat = np.searchsorted(tbb, dobj.local_data(k_length_array.val))
locdat = np.searchsorted(tbb, k_length_array.local_data)
temp_pindex = dobj.from_local_data(
k_length_array.val.shape, locdat,
dobj.distaxis(k_length_array.val))
......@@ -175,7 +175,7 @@ class PowerSpace(StructuredDomain):
# floating-point weights are present ...
temp_k_lengths = np.bincount(
dobj.local_data(temp_pindex).ravel(),
weights=dobj.local_data(k_length_array.val).ravel(),
weights=k_length_array.local_data.ravel(),
minlength=nbin).astype(np.float64, copy=False)
temp_k_lengths = dobj.np_allreduce_sum(temp_k_lengths) / temp_rho
temp_k_lengths.flags.writeable = False
......
......@@ -96,7 +96,7 @@ class RGSpace(StructuredDomain):
if (not self.harmonic):
raise NotImplementedError
out = Field((self,), dtype=np.float64)
oloc = dobj.local_data(out.val)
oloc = out.local_data
ibegin = dobj.ibegin(out.val)
res = np.arange(oloc.shape[0], dtype=np.float64) + ibegin[0]
res = np.minimum(res, self.shape[0]-res)*self.distances[0]
......
......@@ -18,7 +18,6 @@
import numpy as np
from ..field import Field
from .. import dobj
__all__ = ["consistency_check"]
......
......@@ -130,6 +130,10 @@ class Field(object):
def to_global_data(self):
return dobj.to_global_data(self._val)
@property
def local_data(self):
return dobj.local_data(self._val)
def cast_domain(self, new_domain):
return Field(new_domain, self._val)
......@@ -337,8 +341,7 @@ class Field(object):
if dobj.distaxis(self._val) >= 0 and ind == 0:
# we need to distribute the weights along axis 0
wgt = dobj.local_data(dobj.from_global_data(wgt))
lout = dobj.local_data(out.val)
lout *= wgt**power
out.local_data[()] *= wgt**power
fct = fct**power
if fct != 1.:
out *= fct
......@@ -490,7 +493,7 @@ class Field(object):
raise TypeError("argument must be a Field")
if other._domain != self._domain:
raise ValueError("domains are incompatible.")
dobj.local_data(self.val)[()] = dobj.local_data(other.val)[()]
self.local_data[()] = other.local_data[()]
def _binary_helper(self, other, op):
# if other is a field, make sure that the domains match
......
......@@ -82,16 +82,16 @@ class DiagonalOperator(EndomorphicOperator):
active_axes += self._domain.axes[space_index]
if self._spaces[0] == 0:
self._ldiag = dobj.local_data(self._diagonal.val)
self._ldiag = self._diagonal.local_data
else:
self._ldiag = dobj.to_global_data(self._diagonal.val)
self._ldiag = self._diagonal.to_global_data()
locshape = dobj.local_shape(self._domain.shape, 0)
self._reshaper = [shp if i in active_axes else 1
for i, shp in enumerate(locshape)]
self._ldiag = self._ldiag.reshape(self._reshaper)
else:
self._ldiag = dobj.local_data(self._diagonal.val)
self._ldiag = self._diagonal.local_data
def apply(self, x, mode):
self._check_input(x, mode)
......
......@@ -48,12 +48,11 @@ class DOFDistributor(LinearOperator):
nbin = dofdex.max()
if partner.scalar_dvol is not None:
wgt = np.bincount(dobj.local_data(dofdex.val).ravel(),
minlength=nbin)
wgt = np.bincount(dofdex.local_data.ravel(), minlength=nbin)
wgt = wgt*partner.scalar_dvol
else:
dvol = dobj.local_data(partner.dvol)
wgt = np.bincount(dobj.local_data(dofdex.val).ravel(),
wgt = np.bincount(dofdex.local_data.ravel(),
minlength=nbin, weights=dvol)
# The explicit conversion to float64 is necessary because bincount
# sometimes returns its result as an integer array, even when
......@@ -85,7 +84,7 @@ class DOFDistributor(LinearOperator):
self._pshape = (presize, self._dofdex.size, postsize)
def _adjoint_times(self, x):
arr = dobj.local_data(x.val)
arr = x.local_data
arr = arr.reshape(self._pshape)
oarr = np.zeros(self._hshape, dtype=x.dtype)
np.add.at(oarr, (slice(None), self._dofdex, slice(None)), arr)
......@@ -105,9 +104,9 @@ class DOFDistributor(LinearOperator):
if dobj.distaxis(x.val) in x.domain.axes[self._space]:
arr = x.to_global_data()
else:
arr = dobj.local_data(x.val)
arr = x.local_data
arr = arr.reshape(self._hshape)
oarr = dobj.local_data(res.val).reshape(self._pshape)
oarr = res.local_data.reshape(self._pshape)
oarr[()] = arr[(slice(None), self._dofdex, slice(None))]
return res
......
......@@ -77,7 +77,7 @@ class FFTOperator(LinearOperator):
tdom = self._target if x.domain == self._domain else self._domain
oldax = dobj.distaxis(x.val)
if oldax not in axes: # straightforward, no redistribution needed
ldat = dobj.local_data(x.val)
ldat = x.local_data
ldat = utilities.hartley(ldat, axes=axes)
tmp = dobj.from_local_data(x.val.shape, ldat, distaxis=oldax)
elif len(axes) < len(x.shape) or len(axes) == 1:
......
......@@ -104,7 +104,7 @@ class HarmonicTransformOperator(LinearOperator):
tdom = self._target if x.domain == self._domain else self._domain
oldax = dobj.distaxis(x.val)
if oldax not in axes: # straightforward, no redistribution needed
ldat = dobj.local_data(x.val)
ldat = x.local_data
ldat = utilities.hartley(ldat, axes=axes)
tmp = dobj.from_local_data(x.val.shape, ldat, distaxis=oldax)
elif len(axes) < len(x.shape) or len(axes) == 1:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment