Commit 99ebd010 authored by Martin Reinecke's avatar Martin Reinecke
Browse files

more simplifications

parent ecd9ca14
Pipeline #25092 passed with stages
in 6 minutes and 39 seconds
...@@ -160,7 +160,7 @@ class PowerSpace(StructuredDomain): ...@@ -160,7 +160,7 @@ class PowerSpace(StructuredDomain):
tbb = 0.5*(tmp[:-1]+tmp[1:]) tbb = 0.5*(tmp[:-1]+tmp[1:])
else: else:
tbb = binbounds tbb = binbounds
locdat = np.searchsorted(tbb, dobj.local_data(k_length_array.val)) locdat = np.searchsorted(tbb, k_length_array.local_data)
temp_pindex = dobj.from_local_data( temp_pindex = dobj.from_local_data(
k_length_array.val.shape, locdat, k_length_array.val.shape, locdat,
dobj.distaxis(k_length_array.val)) dobj.distaxis(k_length_array.val))
...@@ -175,7 +175,7 @@ class PowerSpace(StructuredDomain): ...@@ -175,7 +175,7 @@ class PowerSpace(StructuredDomain):
# floating-point weights are present ... # floating-point weights are present ...
temp_k_lengths = np.bincount( temp_k_lengths = np.bincount(
dobj.local_data(temp_pindex).ravel(), dobj.local_data(temp_pindex).ravel(),
weights=dobj.local_data(k_length_array.val).ravel(), weights=k_length_array.local_data.ravel(),
minlength=nbin).astype(np.float64, copy=False) minlength=nbin).astype(np.float64, copy=False)
temp_k_lengths = dobj.np_allreduce_sum(temp_k_lengths) / temp_rho temp_k_lengths = dobj.np_allreduce_sum(temp_k_lengths) / temp_rho
temp_k_lengths.flags.writeable = False temp_k_lengths.flags.writeable = False
......
...@@ -96,7 +96,7 @@ class RGSpace(StructuredDomain): ...@@ -96,7 +96,7 @@ class RGSpace(StructuredDomain):
if (not self.harmonic): if (not self.harmonic):
raise NotImplementedError raise NotImplementedError
out = Field((self,), dtype=np.float64) out = Field((self,), dtype=np.float64)
oloc = dobj.local_data(out.val) oloc = out.local_data
ibegin = dobj.ibegin(out.val) ibegin = dobj.ibegin(out.val)
res = np.arange(oloc.shape[0], dtype=np.float64) + ibegin[0] res = np.arange(oloc.shape[0], dtype=np.float64) + ibegin[0]
res = np.minimum(res, self.shape[0]-res)*self.distances[0] res = np.minimum(res, self.shape[0]-res)*self.distances[0]
......
...@@ -18,7 +18,6 @@ ...@@ -18,7 +18,6 @@
import numpy as np import numpy as np
from ..field import Field from ..field import Field
from .. import dobj
__all__ = ["consistency_check"] __all__ = ["consistency_check"]
......
...@@ -130,6 +130,10 @@ class Field(object): ...@@ -130,6 +130,10 @@ class Field(object):
def to_global_data(self): def to_global_data(self):
return dobj.to_global_data(self._val) return dobj.to_global_data(self._val)
@property
def local_data(self):
return dobj.local_data(self._val)
def cast_domain(self, new_domain): def cast_domain(self, new_domain):
return Field(new_domain, self._val) return Field(new_domain, self._val)
...@@ -337,8 +341,7 @@ class Field(object): ...@@ -337,8 +341,7 @@ class Field(object):
if dobj.distaxis(self._val) >= 0 and ind == 0: if dobj.distaxis(self._val) >= 0 and ind == 0:
# we need to distribute the weights along axis 0 # we need to distribute the weights along axis 0
wgt = dobj.local_data(dobj.from_global_data(wgt)) wgt = dobj.local_data(dobj.from_global_data(wgt))
lout = dobj.local_data(out.val) out.local_data[()] *= wgt**power
lout *= wgt**power
fct = fct**power fct = fct**power
if fct != 1.: if fct != 1.:
out *= fct out *= fct
...@@ -490,7 +493,7 @@ class Field(object): ...@@ -490,7 +493,7 @@ class Field(object):
raise TypeError("argument must be a Field") raise TypeError("argument must be a Field")
if other._domain != self._domain: if other._domain != self._domain:
raise ValueError("domains are incompatible.") raise ValueError("domains are incompatible.")
dobj.local_data(self.val)[()] = dobj.local_data(other.val)[()] self.local_data[()] = other.local_data[()]
def _binary_helper(self, other, op): def _binary_helper(self, other, op):
# if other is a field, make sure that the domains match # if other is a field, make sure that the domains match
......
...@@ -82,16 +82,16 @@ class DiagonalOperator(EndomorphicOperator): ...@@ -82,16 +82,16 @@ class DiagonalOperator(EndomorphicOperator):
active_axes += self._domain.axes[space_index] active_axes += self._domain.axes[space_index]
if self._spaces[0] == 0: if self._spaces[0] == 0:
self._ldiag = dobj.local_data(self._diagonal.val) self._ldiag = self._diagonal.local_data
else: else:
self._ldiag = dobj.to_global_data(self._diagonal.val) self._ldiag = self._diagonal.to_global_data()
locshape = dobj.local_shape(self._domain.shape, 0) locshape = dobj.local_shape(self._domain.shape, 0)
self._reshaper = [shp if i in active_axes else 1 self._reshaper = [shp if i in active_axes else 1
for i, shp in enumerate(locshape)] for i, shp in enumerate(locshape)]
self._ldiag = self._ldiag.reshape(self._reshaper) self._ldiag = self._ldiag.reshape(self._reshaper)
else: else:
self._ldiag = dobj.local_data(self._diagonal.val) self._ldiag = self._diagonal.local_data
def apply(self, x, mode): def apply(self, x, mode):
self._check_input(x, mode) self._check_input(x, mode)
......
...@@ -48,12 +48,11 @@ class DOFDistributor(LinearOperator): ...@@ -48,12 +48,11 @@ class DOFDistributor(LinearOperator):
nbin = dofdex.max() nbin = dofdex.max()
if partner.scalar_dvol is not None: if partner.scalar_dvol is not None:
wgt = np.bincount(dobj.local_data(dofdex.val).ravel(), wgt = np.bincount(dofdex.local_data.ravel(), minlength=nbin)
minlength=nbin)
wgt = wgt*partner.scalar_dvol wgt = wgt*partner.scalar_dvol
else: else:
dvol = dobj.local_data(partner.dvol) dvol = dobj.local_data(partner.dvol)
wgt = np.bincount(dobj.local_data(dofdex.val).ravel(), wgt = np.bincount(dofdex.local_data.ravel(),
minlength=nbin, weights=dvol) minlength=nbin, weights=dvol)
# The explicit conversion to float64 is necessary because bincount # The explicit conversion to float64 is necessary because bincount
# sometimes returns its result as an integer array, even when # sometimes returns its result as an integer array, even when
...@@ -85,7 +84,7 @@ class DOFDistributor(LinearOperator): ...@@ -85,7 +84,7 @@ class DOFDistributor(LinearOperator):
self._pshape = (presize, self._dofdex.size, postsize) self._pshape = (presize, self._dofdex.size, postsize)
def _adjoint_times(self, x): def _adjoint_times(self, x):
arr = dobj.local_data(x.val) arr = x.local_data
arr = arr.reshape(self._pshape) arr = arr.reshape(self._pshape)
oarr = np.zeros(self._hshape, dtype=x.dtype) oarr = np.zeros(self._hshape, dtype=x.dtype)
np.add.at(oarr, (slice(None), self._dofdex, slice(None)), arr) np.add.at(oarr, (slice(None), self._dofdex, slice(None)), arr)
...@@ -105,9 +104,9 @@ class DOFDistributor(LinearOperator): ...@@ -105,9 +104,9 @@ class DOFDistributor(LinearOperator):
if dobj.distaxis(x.val) in x.domain.axes[self._space]: if dobj.distaxis(x.val) in x.domain.axes[self._space]:
arr = x.to_global_data() arr = x.to_global_data()
else: else:
arr = dobj.local_data(x.val) arr = x.local_data
arr = arr.reshape(self._hshape) arr = arr.reshape(self._hshape)
oarr = dobj.local_data(res.val).reshape(self._pshape) oarr = res.local_data.reshape(self._pshape)
oarr[()] = arr[(slice(None), self._dofdex, slice(None))] oarr[()] = arr[(slice(None), self._dofdex, slice(None))]
return res return res
......
...@@ -77,7 +77,7 @@ class FFTOperator(LinearOperator): ...@@ -77,7 +77,7 @@ class FFTOperator(LinearOperator):
tdom = self._target if x.domain == self._domain else self._domain tdom = self._target if x.domain == self._domain else self._domain
oldax = dobj.distaxis(x.val) oldax = dobj.distaxis(x.val)
if oldax not in axes: # straightforward, no redistribution needed if oldax not in axes: # straightforward, no redistribution needed
ldat = dobj.local_data(x.val) ldat = x.local_data
ldat = utilities.hartley(ldat, axes=axes) ldat = utilities.hartley(ldat, axes=axes)
tmp = dobj.from_local_data(x.val.shape, ldat, distaxis=oldax) tmp = dobj.from_local_data(x.val.shape, ldat, distaxis=oldax)
elif len(axes) < len(x.shape) or len(axes) == 1: elif len(axes) < len(x.shape) or len(axes) == 1:
......
...@@ -104,7 +104,7 @@ class HarmonicTransformOperator(LinearOperator): ...@@ -104,7 +104,7 @@ class HarmonicTransformOperator(LinearOperator):
tdom = self._target if x.domain == self._domain else self._domain tdom = self._target if x.domain == self._domain else self._domain
oldax = dobj.distaxis(x.val) oldax = dobj.distaxis(x.val)
if oldax not in axes: # straightforward, no redistribution needed if oldax not in axes: # straightforward, no redistribution needed
ldat = dobj.local_data(x.val) ldat = x.local_data
ldat = utilities.hartley(ldat, axes=axes) ldat = utilities.hartley(ldat, axes=axes)
tmp = dobj.from_local_data(x.val.shape, ldat, distaxis=oldax) tmp = dobj.from_local_data(x.val.shape, ldat, distaxis=oldax)
elif len(axes) < len(x.shape) or len(axes) == 1: elif len(axes) < len(x.shape) or len(axes) == 1:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment