Commit 03b8712a authored by Martin Reinecke's avatar Martin Reinecke
Browse files

improvements

parent 315483ac
Pipeline #21594 passed with stage
in 5 minutes and 4 seconds
......@@ -313,8 +313,7 @@ class Field(object):
new_shape[self.domain.axes[ind][0]:
self.domain.axes[ind][-1]+1] = wgt.shape
wgt = wgt.reshape(new_shape)
# FIXME only temporary
if ind == 0:
if ind == 0: # we need to distribute the weights along axis 0
wgt = dobj.local_data(dobj.from_global_data(wgt))
out *= wgt**power
fct = fct**power
......
......@@ -95,38 +95,42 @@ class DiagonalOperator(EndomorphicOperator):
if self._spaces == tuple(range(len(self._domain.domains))):
self._spaces = None # shortcut
self._diagonal = diagonal.copy()
if self._spaces is not None:
active_axes = []
for space_index in self._spaces:
active_axes += self._domain.axes[space_index]
if self._spaces[0] == 0:
self._ldiag = dobj.local_data(self._diagonal.val)
else:
self._ldiag = dobj.to_global_data(self._diagonal.val)
locshape = dobj.local_shape(self._domain.shape, 0)
self._reshaper = [shp if i in active_axes else 1
for i, shp in enumerate(self._domain.shape)]
for i, shp in enumerate(locshape)]
self._ldiag = self._ldiag.reshape(self._reshaper)
else:
self._ldiag = dobj.local_data(self._diagonal.val)
self._diagonal = diagonal.copy()
self._self_adjoint = None
self._unitary = None
def _times(self, x):
return self._times_helper(x, self._diagonal)
return Field(x.domain, val=x.val*self._ldiag)
def _adjoint_times(self, x):
return self._times_helper(x, self._diagonal.conj())
return Field(x.domain, val=x.val*self._ldiag.conj())
def _inverse_times(self, x):
return self._times_helper(x, 1./self._diagonal)
return Field(x.domain, val=x.val/self._ldiag)
def _adjoint_inverse_times(self, x):
return self._times_helper(x, 1./self._diagonal.conj())
return Field(x.domain, val=x.val/self._ldiag.conj())
def diagonal(self):
""" Returns the diagonal of the Operator.
Returns
-------
out : Field
The diagonal of the Operator.
"""
""" Returns the diagonal of the Operator."""
return self._diagonal.copy()
@property
......@@ -147,12 +151,3 @@ class DiagonalOperator(EndomorphicOperator):
if self._unitary is None:
self._unitary = (abs(self._diagonal.val) == 1.).all()
return self._unitary
def _times_helper(self, x, diag):
if self._spaces is None:
return diag*x
reshaped_local_diagonal = np.reshape(dobj.to_global_data(diag.val), self._reshaper)
if 0 in self._spaces:
reshaped_local_diagonal = dobj.local_data(dobj.from_global_data(reshaped_local_diagonal))
return Field(x.domain, val=x.val*reshaped_local_diagonal)
......@@ -62,34 +62,29 @@ class RGRGTransformation(Transformation):
axes = x.domain.axes[self.space]
p2h = x.domain == self.pdom
tdom = self.hdom if p2h else self.pdom
oldax = dobj.distaxis(x.val)
if dobj.distaxis(x.val) in axes:
tmpax = (dobj.distaxis(x.val),)
tmp = dobj.redistribute(x.val, nodist=tmpax)
tmp = dobj.redistribute(x.val, nodist=(oldax,))
newax = dobj.distaxis(tmp)
ldat = dobj.local_data(tmp)
if len(axes) == 1: # only one transform needed
ldat = hartley(ldat, axes=tmpax)
tmp = dobj.from_local_data(tmp.shape, ldat, distaxis=dobj.distaxis(tmp))
tmp = dobj.redistribute(tmp, dist=tmpax[0])
else: # two separate transforms
ldat = fftn(ldat, axes=tmpax)
tmp = dobj.from_local_data(tmp.shape, ldat, distaxis=dobj.distaxis(tmp))
tmp = dobj.redistribute(tmp, dist=tmpax[0])
tmpax = tuple(i for i in axes if i not in tmpax)
ldat = hartley(ldat, axes=(oldax,))
tmp = dobj.from_local_data(tmp.shape, ldat, distaxis=newax)
tmp = dobj.redistribute(tmp, dist=oldax)
else: # two separate transforms needed, "real" FFT required
ldat = fftn(ldat, axes=(oldax,))
tmp = dobj.from_local_data(tmp.shape, ldat, distaxis=newax)
tmp = dobj.redistribute(tmp, dist=oldax)
rem_axes = tuple(i for i in axes if i != oldax)
ldat = dobj.local_data(tmp)
ldat = fftn(ldat, axes=tmpax)
ldat = fftn(ldat, axes=rem_axes)
ldat = ldat.real+ldat.imag
tmp = dobj.from_local_data(tmp.shape, ldat, distaxis=dobj.distaxis(tmp))
Tval = Field(tdom, tmp)
tmp = dobj.from_local_data(tmp.shape, ldat, distaxis=oldax)
else:
ldat = dobj.local_data(x.val)
# these two alternatives are equivalent, with the second being faster
if False:
ldat = fftn(ldat, axes=axes)
ldat = ldat.real+ldat.imag
else:
ldat = hartley(ldat, axes=axes)
tmp = dobj.from_local_data(x.val.shape, ldat, distaxis=dobj.distaxis(x.val))
Tval = Field(tdom, tmp)
ldat = hartley(ldat, axes=axes)
tmp = dobj.from_local_data(x.val.shape, ldat, distaxis=oldax)
Tval = Field(tdom, tmp)
fct = self.fct_p2h if p2h else self.fct_h2p
if fct != 1:
Tval *= fct
......@@ -144,21 +139,14 @@ class SphericalTransformation(Transformation):
distaxis = dobj.distaxis(tval)
p2h = x.domain == self.pdom
tdom = self.hdom if p2h else self.pdom
func = self._slice_p2h if p2h else self._slice_h2p
idat = dobj.local_data(tval)
if p2h:
odat = np.empty(dobj.local_shape(self.hdom.shape, distaxis=distaxis), dtype=x.dtype)
for slice in utilities.get_slice_list(idat.shape, axes):
odat[slice] = self._slice_p2h(idat[slice])
odat = dobj.from_local_data(self.hdom.shape, odat, distaxis)
if distaxis != dobj.distaxis(x.val):
odat = dobj.redistribute(odat, dist=dobj.distaxis(x.val))
return Field(self.hdom, odat)
else:
odat = np.empty(dobj.local_shape(self.pdom.shape, distaxis=distaxis), dtype=x.dtype)
for slice in utilities.get_slice_list(idat.shape, axes):
odat[slice] = self._slice_h2p(idat[slice])
odat = dobj.from_local_data(self.pdom.shape, odat, distaxis)
if distaxis != dobj.distaxis(x.val):
odat = dobj.redistribute(odat, dist=dobj.distaxis(x.val))
return Field(self.pdom, odat)
odat = np.empty(dobj.local_shape(tdom.shape, distaxis=distaxis),
dtype=x.dtype)
for slice in utilities.get_slice_list(idat.shape, axes):
odat[slice] = func(idat[slice])
odat = dobj.from_local_data(tdom.shape, odat, distaxis)
if distaxis != dobj.distaxis(x.val):
odat = dobj.redistribute(odat, dist=dobj.distaxis(x.val))
return Field(tdom, odat)
......@@ -50,54 +50,45 @@ class PowerProjectionOperator(LinearOperator):
tgt[self._space] = power_space
self._target = DomainTuple.make(tgt)
# shopping list:
# 1) make sure that pindex is distributed in the same way as in the Field living on self.domain.
# 2) if the operated-on space is not distributed (i.e. if it is not space 0), _no_ further communication is necessary
def _times(self, x):
# harmonic field goes in
# pindex must be distributed in the same way as harmonic field
# power field must be available in full
pindex = self._target[self._space].pindex
if dobj.distaxis(x.val) in x.domain.axes[self._space]: # the distributed axis is part of the projected space
if dobj.default_distaxis() in self.domain.axes[self._space]:
pindex = dobj.local_data(pindex)
else: # pindex must be available fully on every task
pindex = dobj.to_global_data(pindex)
pindex.reshape((1, pindex.size, 1))
self._pindex = pindex.ravel()
firstaxis = self._domain.axes[self._space][0]
lastaxis = self._domain.axes[self._space][-1]
arrshape = dobj.local_shape(self._domain.shape, 0)
presize = np.prod(arrshape[0:firstaxis], dtype=np.int)
postsize = np.prod(arrshape[lastaxis+1:], dtype=np.int)
self._hshape = (presize, self._target[self._space].shape[0], postsize)
self._pshape = (presize, self._pindex.size, postsize)
def _times(self, x):
arr = dobj.local_data(x.weight(1).val)
firstaxis = x.domain.axes[self._space][0]
lastaxis = x.domain.axes[self._space][-1]
presize = np.prod(arr.shape[0:firstaxis], dtype=np.int)
postsize = np.prod(arr.shape[lastaxis+1:], dtype=np.int)
arr = arr.reshape((presize, pindex.size, postsize))
oarr = np.zeros((presize, self._target[self._space].shape[0], postsize), dtype=x.dtype)
np.add.at(oarr, (slice(None), pindex.ravel(), slice(None)), arr)
arr = arr.reshape(self._pshape)
oarr = np.zeros(self._hshape, dtype=x.dtype)
np.add.at(oarr, (slice(None), self._pindex, slice(None)), arr)
if dobj.distaxis(x.val) in x.domain.axes[self._space]:
oarr = dobj.np_allreduce_sum(oarr)
oarr = oarr.reshape(self._target.shape)
oarr = dobj.np_allreduce_sum(oarr).reshape(self._target.shape)
res = Field(self._target, dobj.from_global_data(oarr))
else:
oarr = oarr.reshape(dobj.local_shape(self._target.shape, dobj.distaxis(x.val)))
res = Field(self._target, dobj.from_local_data(self._target.shape, oarr, dobj.default_distaxis()))
oarr = oarr.reshape(dobj.local_shape(self._target.shape,
dobj.distaxis(x.val)))
res = Field(self._target,
dobj.from_local_data(self._target.shape, oarr,
dobj.default_distaxis()))
return res.weight(-1, spaces=self._space)
def _adjoint_times(self, x):
pindex = self._target[self._space].pindex
res = Field.empty(self._domain, dtype=x.dtype)
if dobj.distaxis(x.val) in x.domain.axes[self._space]: # the distributed axis is part of the projected space
pindex = dobj.local_data(pindex)
if dobj.distaxis(x.val) in x.domain.axes[self._space]:
arr = dobj.to_global_data(x.val)
else:
pindex = dobj.to_global_data(pindex)
arr = dobj.local_data(x.val)
pindex = pindex.reshape((1, pindex.size, 1))
firstaxis = x.domain.axes[self._space][0]
lastaxis = x.domain.axes[self._space][-1]
presize = np.prod(arr.shape[0:firstaxis], dtype=np.int)
postsize = np.prod(arr.shape[lastaxis+1:], dtype=np.int)
arr = arr.reshape((presize, self._target[self._space].shape[0], postsize))
oarr = dobj.local_data(res.val).reshape((presize, pindex.size, postsize))
oarr[()] = arr[(slice(None), pindex.ravel(), slice(None))]
arr = arr.reshape(self._hshape)
oarr = dobj.local_data(res.val).reshape(self._pshape)
oarr[()] = arr[(slice(None), self._pindex, slice(None))]
return res
@property
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment