Commit c3dd0ba2 authored by Martin Reinecke's avatar Martin Reinecke
Browse files

fixes

parent 3858e173
...@@ -391,7 +391,7 @@ def from_global_data(arr, sum_up=False, distaxis=0): ...@@ -391,7 +391,7 @@ def from_global_data(arr, sum_up=False, distaxis=0):
lo, hi = _shareRange(arr.shape[distaxis], ntask, rank) lo, hi = _shareRange(arr.shape[distaxis], ntask, rank)
sl = [slice(None)]*len(arr.shape) sl = [slice(None)]*len(arr.shape)
sl[distaxis] = slice(lo, hi) sl[distaxis] = slice(lo, hi)
return data_object(arr.shape, arr[sl], distaxis) return data_object(arr.shape, arr[tuple(sl)], distaxis)
def to_global_data(arr): def to_global_data(arr):
...@@ -467,7 +467,7 @@ def redistribute(arr, dist=None, nodist=None): ...@@ -467,7 +467,7 @@ def redistribute(arr, dist=None, nodist=None):
lo, hi = _shareRange(arr.shape[dist], ntask, i) lo, hi = _shareRange(arr.shape[dist], ntask, i)
sslice[dist] = slice(lo, hi) sslice[dist] = slice(lo, hi)
ssz[i] = ssz0*(hi-lo) ssz[i] = ssz0*(hi-lo)
sbuf[ofs:ofs+ssz[i]] = arr._data[sslice].flat sbuf[ofs:ofs+ssz[i]] = arr._data[tuple(sslice)].flat
ofs += ssz[i] ofs += ssz[i]
rsz[i] = rsz0*_shareSize(arr.shape[arr._distaxis], ntask, i) rsz[i] = rsz0*_shareSize(arr.shape[arr._distaxis], ntask, i)
ssz *= arr._data.itemsize ssz *= arr._data.itemsize
...@@ -489,7 +489,7 @@ def redistribute(arr, dist=None, nodist=None): ...@@ -489,7 +489,7 @@ def redistribute(arr, dist=None, nodist=None):
lo, hi = _shareRange(arr.shape[arr._distaxis], ntask, i) lo, hi = _shareRange(arr.shape[arr._distaxis], ntask, i)
rslice[arr._distaxis] = slice(lo, hi) rslice[arr._distaxis] = slice(lo, hi)
sz = rsz[i]//arr._data.itemsize sz = rsz[i]//arr._data.itemsize
arrnew[rslice].flat = rbuf[ofs:ofs+sz] arrnew[tuple(rslice)].flat = rbuf[ofs:ofs+sz]
ofs += sz ofs += sz
arrnew = from_local_data(arr.shape, arrnew, distaxis=dist) arrnew = from_local_data(arr.shape, arrnew, distaxis=dist)
return arrnew return arrnew
......
...@@ -477,7 +477,7 @@ class Field(object): ...@@ -477,7 +477,7 @@ class Field(object):
swgt = self.scalar_weight(spaces) swgt = self.scalar_weight(spaces)
if swgt is not None: if swgt is not None:
res = self.sum(spaces) res = self.sum(spaces)
res = res * swgt res = res*swgt
return res return res
tmp = self.weight(1, spaces=spaces) tmp = self.weight(1, spaces=spaces)
return tmp.sum(spaces) return tmp.sum(spaces)
......
...@@ -58,7 +58,7 @@ class OuterProduct(LinearOperator): ...@@ -58,7 +58,7 @@ class OuterProduct(LinearOperator):
def apply(self, x, mode): def apply(self, x, mode):
self._check_input(x, mode) self._check_input(x, mode)
if mode == self.TIMES: if mode == self.TIMES:
return Field(self._target, np.multiply.outer(self._field.to_global_data(), x.to_global_data())) return Field.from_global_data(self._target, np.multiply.outer(self._field.to_global_data(), x.to_global_data()))
axes = len(self._field.shape) axes = len(self._field.shape)
return Field(self._domain, val=np.tensordot(self._field.to_global_data(), x.to_global_data(), axes)) return Field.from_global_data(self._domain, val=np.tensordot(self._field.to_global_data(), x.to_global_data(), axes))
...@@ -70,6 +70,7 @@ class SumReductionOperator(LinearOperator): ...@@ -70,6 +70,7 @@ class SumReductionOperator(LinearOperator):
else: else:
for i in self._spaces: for i in self._spaces:
ns = self._domain._dom[i] ns = self._domain._dom[i]
# FIXME: nested use of "i"
ps = tuple(i - 1 for i in ns.shape) ps = tuple(i - 1 for i in ns.shape)
dtfi = DomainTupleFieldInserter(domain=self._target, new_space=ns, index=i, position=ps) dtfi = DomainTupleFieldInserter(domain=self._target, new_space=ns, index=i, position=ps)
x = dtfi(x) x = dtfi(x)
...@@ -78,13 +79,15 @@ class SumReductionOperator(LinearOperator): ...@@ -78,13 +79,15 @@ class SumReductionOperator(LinearOperator):
class IntegralReductionOperator(LinearOperator): class IntegralReductionOperator(LinearOperator):
def __init__(self, domain, spaces=None): def __init__(self, domain, spaces=None):
self._spaces = spaces self._domain = DomainTuple.make(domain)
self._domain = domain self._spaces = utilities.parse_spaces(spaces, len(self._domain))
if spaces is None: if len(self._spaces) == len(self._domain):
self._spaces = None
if self._spaces is None:
self._target = DomainTuple.scalar_domain() self._target = DomainTuple.scalar_domain()
else: else:
self._target = makeDomain(tuple(dom for i, dom in enumerate(self._domain) if not(i == spaces))) self._target = makeDomain(tuple(dom for i, dom in enumerate(self._domain) if not(i in self._spaces)))
self._marg_space = makeDomain(tuple(dom for i, dom in enumerate(self._domain) if (i == spaces))) self._marg_space = makeDomain(tuple(dom for i, dom in enumerate(self._domain) if (i in self._spaces)))
self._capability = self.TIMES | self.ADJOINT_TIMES self._capability = self.TIMES | self.ADJOINT_TIMES
def apply(self, x, mode): def apply(self, x, mode):
...@@ -111,6 +114,7 @@ class IntegralReductionOperator(LinearOperator): ...@@ -111,6 +114,7 @@ class IntegralReductionOperator(LinearOperator):
sp = self._spaces sp = self._spaces
for i in sp: for i in sp:
ns = self._domain._dom[i] ns = self._domain._dom[i]
# FIXME: nested use of "i"
ps = tuple(i - 1 for i in ns.shape) ps = tuple(i - 1 for i in ns.shape)
dtfi = DomainTupleFieldInserter(domain=self._target, new_space=ns, index=i, position=ps) dtfi = DomainTupleFieldInserter(domain=self._target, new_space=ns, index=i, position=ps)
x = dtfi(x) x = dtfi(x)
......
...@@ -269,7 +269,7 @@ def my_fftn_r2c(a, axes=None): ...@@ -269,7 +269,7 @@ def my_fftn_r2c(a, axes=None):
lastaxis = axes[-1] lastaxis = axes[-1]
ntmplast = tmp.shape[lastaxis] ntmplast = tmp.shape[lastaxis]
slice1 = [slice(None)]*lastaxis + [slice(0, ntmplast)] slice1 = [slice(None)]*lastaxis + [slice(0, ntmplast)]
res[slice1] = tmp res[tuple(slice1)] = tmp
def _fill_upper_half_complex(tmp, res, axes): def _fill_upper_half_complex(tmp, res, axes):
lastaxis = axes[-1] lastaxis = axes[-1]
...@@ -282,9 +282,9 @@ def my_fftn_r2c(a, axes=None): ...@@ -282,9 +282,9 @@ def my_fftn_r2c(a, axes=None):
slice1[i] = slice(1, None) slice1[i] = slice(1, None)
slice2[i] = slice(None, 0, -1) slice2[i] = slice(None, 0, -1)
# np.conjugate(tmp[slice2], out=res[slice1]) # np.conjugate(tmp[slice2], out=res[slice1])
res[slice1] = np.conjugate(tmp[slice2]) res[tuple(slice1)] = np.conjugate(tmp[tuple(slice2)])
for i, ax in enumerate(axes[:-1]): for i, ax in enumerate(axes[:-1]):
dim1 = [slice(None)]*ax + [slice(0, 1)] dim1 = tuple([slice(None)]*ax + [slice(0, 1)])
axes2 = axes[:i] + axes[i+1:] axes2 = axes[:i] + axes[i+1:]
_fill_upper_half_complex(tmp[dim1], res[dim1], axes2) _fill_upper_half_complex(tmp[dim1], res[dim1], axes2)
......
...@@ -142,7 +142,7 @@ class Test_Functionality(unittest.TestCase): ...@@ -142,7 +142,7 @@ class Test_Functionality(unittest.TestCase):
m1 = ift.Field.full(x1, .5) m1 = ift.Field.full(x1, .5)
m2 = ift.Field.full(x2, 3.) m2 = ift.Field.full(x2, 3.)
res = m1.outer(m2) res = m1.outer(m2)
assert_allclose(res.local_data, np.full((9, 3,), 1.5)) assert_allclose(res.to_global_data(), np.full((9, 3,), 1.5))
def test_dataconv(self): def test_dataconv(self):
s1 = ift.RGSpace((10,)) s1 = ift.RGSpace((10,))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment