Commit c3dd0ba2 authored by Martin Reinecke's avatar Martin Reinecke
Browse files

fixes

parent 3858e173
......@@ -391,7 +391,7 @@ def from_global_data(arr, sum_up=False, distaxis=0):
lo, hi = _shareRange(arr.shape[distaxis], ntask, rank)
sl = [slice(None)]*len(arr.shape)
sl[distaxis] = slice(lo, hi)
return data_object(arr.shape, arr[sl], distaxis)
return data_object(arr.shape, arr[tuple(sl)], distaxis)
def to_global_data(arr):
......@@ -467,7 +467,7 @@ def redistribute(arr, dist=None, nodist=None):
lo, hi = _shareRange(arr.shape[dist], ntask, i)
sslice[dist] = slice(lo, hi)
ssz[i] = ssz0*(hi-lo)
sbuf[ofs:ofs+ssz[i]] = arr._data[sslice].flat
sbuf[ofs:ofs+ssz[i]] = arr._data[tuple(sslice)].flat
ofs += ssz[i]
rsz[i] = rsz0*_shareSize(arr.shape[arr._distaxis], ntask, i)
ssz *= arr._data.itemsize
......@@ -489,7 +489,7 @@ def redistribute(arr, dist=None, nodist=None):
lo, hi = _shareRange(arr.shape[arr._distaxis], ntask, i)
rslice[arr._distaxis] = slice(lo, hi)
sz = rsz[i]//arr._data.itemsize
arrnew[rslice].flat = rbuf[ofs:ofs+sz]
arrnew[tuple(rslice)].flat = rbuf[ofs:ofs+sz]
ofs += sz
arrnew = from_local_data(arr.shape, arrnew, distaxis=dist)
return arrnew
......
......@@ -477,7 +477,7 @@ class Field(object):
swgt = self.scalar_weight(spaces)
if swgt is not None:
res = self.sum(spaces)
res = res * swgt
res = res*swgt
return res
tmp = self.weight(1, spaces=spaces)
return tmp.sum(spaces)
......
......@@ -58,7 +58,7 @@ class OuterProduct(LinearOperator):
def apply(self, x, mode):
self._check_input(x, mode)
if mode == self.TIMES:
return Field(self._target, np.multiply.outer(self._field.to_global_data(), x.to_global_data()))
return Field.from_global_data(self._target, np.multiply.outer(self._field.to_global_data(), x.to_global_data()))
axes = len(self._field.shape)
return Field(self._domain, val=np.tensordot(self._field.to_global_data(), x.to_global_data(), axes))
return Field.from_global_data(self._domain, val=np.tensordot(self._field.to_global_data(), x.to_global_data(), axes))
......@@ -70,6 +70,7 @@ class SumReductionOperator(LinearOperator):
else:
for i in self._spaces:
ns = self._domain._dom[i]
# FIXME: nested use of "i"
ps = tuple(i - 1 for i in ns.shape)
dtfi = DomainTupleFieldInserter(domain=self._target, new_space=ns, index=i, position=ps)
x = dtfi(x)
......@@ -78,13 +79,15 @@ class SumReductionOperator(LinearOperator):
class IntegralReductionOperator(LinearOperator):
def __init__(self, domain, spaces=None):
self._spaces = spaces
self._domain = domain
if spaces is None:
self._domain = DomainTuple.make(domain)
self._spaces = utilities.parse_spaces(spaces, len(self._domain))
if len(self._spaces) == len(self._domain):
self._spaces = None
if self._spaces is None:
self._target = DomainTuple.scalar_domain()
else:
self._target = makeDomain(tuple(dom for i, dom in enumerate(self._domain) if not(i == spaces)))
self._marg_space = makeDomain(tuple(dom for i, dom in enumerate(self._domain) if (i == spaces)))
self._target = makeDomain(tuple(dom for i, dom in enumerate(self._domain) if not(i in self._spaces)))
self._marg_space = makeDomain(tuple(dom for i, dom in enumerate(self._domain) if (i in self._spaces)))
self._capability = self.TIMES | self.ADJOINT_TIMES
def apply(self, x, mode):
......@@ -111,6 +114,7 @@ class IntegralReductionOperator(LinearOperator):
sp = self._spaces
for i in sp:
ns = self._domain._dom[i]
# FIXME: nested use of "i"
ps = tuple(i - 1 for i in ns.shape)
dtfi = DomainTupleFieldInserter(domain=self._target, new_space=ns, index=i, position=ps)
x = dtfi(x)
......
......@@ -269,7 +269,7 @@ def my_fftn_r2c(a, axes=None):
lastaxis = axes[-1]
ntmplast = tmp.shape[lastaxis]
slice1 = [slice(None)]*lastaxis + [slice(0, ntmplast)]
res[slice1] = tmp
res[tuple(slice1)] = tmp
def _fill_upper_half_complex(tmp, res, axes):
lastaxis = axes[-1]
......@@ -282,9 +282,9 @@ def my_fftn_r2c(a, axes=None):
slice1[i] = slice(1, None)
slice2[i] = slice(None, 0, -1)
# np.conjugate(tmp[slice2], out=res[slice1])
res[slice1] = np.conjugate(tmp[slice2])
res[tuple(slice1)] = np.conjugate(tmp[tuple(slice2)])
for i, ax in enumerate(axes[:-1]):
dim1 = [slice(None)]*ax + [slice(0, 1)]
dim1 = tuple([slice(None)]*ax + [slice(0, 1)])
axes2 = axes[:i] + axes[i+1:]
_fill_upper_half_complex(tmp[dim1], res[dim1], axes2)
......
......@@ -142,7 +142,7 @@ class Test_Functionality(unittest.TestCase):
m1 = ift.Field.full(x1, .5)
m2 = ift.Field.full(x2, 3.)
res = m1.outer(m2)
assert_allclose(res.local_data, np.full((9, 3,), 1.5))
assert_allclose(res.to_global_data(), np.full((9, 3,), 1.5))
def test_dataconv(self):
s1 = ift.RGSpace((10,))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment