Commit 623eb77f authored by Martin Reinecke's avatar Martin Reinecke
Browse files

Merge branch 'NIFTy_5' into unit-tests-for-models

parents d4140adc 50ffbdcc
......@@ -65,7 +65,6 @@ pages:
- NIFTy_4
before_script:
- export MPLBACKEND="agg"
- python setup.py install --user -f
- python3 setup.py install --user -f
......
......@@ -29,6 +29,9 @@ RUN apt-get update && apt-get install -y python-matplotlib python3-matplotlib \
&& python3 -m pip install --upgrade pip && python3 -m pip install jupyter && python -m pip install --upgrade pip && python -m pip install jupyter \
&& rm -rf /var/lib/apt/lists/*
# Set matplotlib backend
ENV MPLBACKEND agg
# Create user (openmpi does not like to be run as root)
RUN useradd -ms /bin/bash testinguser
USER testinguser
......
......@@ -26,6 +26,7 @@ from .models.model import Model
from .models.multi_model import MultiModel
from .models.variable import Variable
from .operators.central_zero_padder import CentralZeroPadder
from .operators.diagonal_operator import DiagonalOperator
from .operators.dof_distributor import DOFDistributor
from .operators.domain_distributor import DomainDistributor
......
import numpy as np
import itertools
from .. import utilities
from .linear_operator import LinearOperator
from ..domain_tuple import DomainTuple
from ..domains.rg_space import RGSpace
from ..field import Field
from .. import dobj
# MR FIXME: for even axis lengths, we probably should split the value at the
# highest frequency.
class CentralZeroPadder(LinearOperator):
def __init__(self, domain, new_shape, space=0):
super(CentralZeroPadder, self).__init__()
self._domain = DomainTuple.make(domain)
self._space = utilities.infer_space(self._domain, space)
dom = self._domain[self._space]
if not isinstance(dom, RGSpace):
raise TypeError("RGSpace required")
if dom.harmonic:
raise TypeError("RGSpace must not be harmonic")
if len(new_shape) != len(dom.shape):
raise ValueError("Shape mismatch")
if any([a < b for a, b in zip(new_shape, dom.shape)]):
raise ValueError("New shape must be larger than old shape")
tgt = RGSpace(new_shape, dom.distances)
self._target = list(self._domain)
self._target[self._space] = tgt
self._target = DomainTuple.make(self._target)
slicer = []
axes = self._target.axes[self._space]
for i in range(len(self._domain.shape)):
if i in axes:
slicer_fw = slice(0, (self._domain.shape[i]+1)//2)
slicer_bw = slice(-1, -1-(self._domain.shape[i]//2), -1)
slicer.append([slicer_fw, slicer_bw])
self.slicer = list(itertools.product(*slicer))
for i in range(len(self.slicer)):
for j in range(len(self._domain.shape)):
if j not in axes:
tmp = list(self.slicer[i])
tmp.insert(j, slice(None))
self.slicer[i] = tmp
@property
def domain(self):
return self._domain
@property
def target(self):
return self._target
@property
def capability(self):
return self.TIMES | self.ADJOINT_TIMES
def apply(self, x, mode):
self._check_input(x, mode)
x = x.val
dax = dobj.distaxis(x)
shp_in = x.shape
shp_out = self._tgt(mode).shape
axes = self._target.axes[self._space]
if dax in axes:
x = dobj.redistribute(x, nodist=axes)
curax = dobj.distaxis(x)
x = dobj.local_data(x)
if mode == self.TIMES:
y = np.zeros(dobj.local_shape(shp_out, curax), dtype=x.dtype)
for i in self.slicer:
y[i] = x[i]
else:
y = np.empty(dobj.local_shape(shp_out, curax), dtype=x.dtype)
for i in self.slicer:
y[i] = x[i]
y = dobj.from_local_data(shp_out, y, distaxis=curax)
if dax in axes:
y = dobj.redistribute(y, dist=dax)
return Field(self._tgt(mode), val=y)
......@@ -25,7 +25,7 @@ from ..compat import *
from ..domain_tuple import DomainTuple
from ..domains.dof_space import DOFSpace
from ..field import Field
from ..utilities import infer_space
from ..utilities import infer_space, special_add_at
from .linear_operator import LinearOperator
......@@ -116,7 +116,7 @@ class DOFDistributor(LinearOperator):
arr = x.local_data
arr = arr.reshape(self._pshape)
oarr = np.zeros(self._hshape, dtype=x.dtype)
np.add.at(oarr, (slice(None), self._dofdex, slice(None)), arr)
oarr = special_add_at(oarr, 1, self._dofdex, arr)
if dobj.distaxis(x.val) in x.domain.axes[self._space]:
oarr = dobj.np_allreduce_sum(oarr).reshape(self._domain.shape)
res = Field.from_global_data(self._domain, oarr)
......
......@@ -27,13 +27,13 @@ from ..domains.power_space import PowerSpace
from ..domains.rg_space import RGSpace
from ..field import Field
from .linear_operator import LinearOperator
from .. import utilities
from ..utilities import infer_space, special_add_at
class ExpTransform(LinearOperator):
def __init__(self, target, dof, space=0):
self._target = DomainTuple.make(target)
self._space = utilities.infer_space(self._target, space)
self._space = infer_space(self._target, space)
tgt = self._target[self._space]
if not ((isinstance(tgt, RGSpace) and tgt.harmonic) or
isinstance(tgt, PowerSpace)):
......@@ -112,8 +112,8 @@ class ExpTransform(LinearOperator):
shp = list(x.shape)
shp[d] = self._tgt(mode).shape[d]
xnew = np.zeros(shp, dtype=x.dtype)
np.add.at(xnew, idx + (self._bindex[d-d0],), x * (1.-wgt))
np.add.at(xnew, idx + (self._bindex[d-d0]+1,), x * wgt)
xnew = special_add_at(xnew, d, self._bindex[d-d0], x*(1.-wgt))
xnew = special_add_at(xnew, d, self._bindex[d-d0]+1, x*wgt)
else: # TIMES
xnew = x[idx + (self._bindex[d-d0],)] * (1.-wgt)
xnew += x[idx + (self._bindex[d-d0]+1,)] * wgt
......
......@@ -30,7 +30,7 @@ from .compat import *
__all__ = ["get_slice_list", "safe_cast", "parse_spaces", "infer_space",
"memo", "NiftyMetaBase", "fft_prep", "hartley", "my_fftn_r2c",
"my_fftn", "my_sum", "my_lincomb_simple", "my_lincomb",
"my_product", "frozendict"]
"my_product", "frozendict", "special_add_at"]
def my_sum(terms):
......@@ -81,7 +81,7 @@ def get_slice_list(shape, axes):
if axes:
if not all(axis < len(shape) for axis in axes):
raise ValueError("axes(axis) does not match shape.")
axes_select = [0 if x in axes else 1 for x, y in enumerate(shape)]
axes_select = [0 if x in axes else 1 for x in range(len(shape))]
axes_iterables = \
[list(range(y)) for x, y in enumerate(shape) if x not in axes]
for index in product(*axes_iterables):
......@@ -334,3 +334,25 @@ class frozendict(collections.Mapping):
h ^= hash((key, value))
self._hash = h
return self._hash
def special_add_at(a, axis, index, b):
if a.dtype != b.dtype:
raise TypeError("data type mismatch")
sz1 = int(np.prod(a.shape[:axis]))
sz3 = int(np.prod(a.shape[axis+1:]))
a2 = a.reshape([sz1, -1, sz3])
b2 = b.reshape([sz1, -1, sz3])
if np.issubdtype(a.dtype, np.complexfloating):
dt2 = a.real.dtype
a2 = a2.view(dt2)
b2 = b2.view(dt2)
sz3 *= 2
for i1 in range(sz1):
for i3 in range(sz3):
a2[i1, :, i3] += np.bincount(index, b2[i1, :, i3],
minlength=a2.shape[1])
if np.issubdtype(a.dtype, np.complexfloating):
a2 = a2.view(a.dtype)
return a2.reshape(a.shape)
......@@ -209,6 +209,14 @@ class Consistency_Tests(unittest.TestCase):
op = ift.FieldZeroPadder(dom, newshape, space)
ift.extra.consistency_check(op, dtype, dtype)
@expand(product([0, 2], [2, 2.7], [np.float64, np.complex128]))
def testZeroPadder2(self, space, factor, dtype):
dom = (ift.RGSpace(10), ift.UnstructuredDomain(13), ift.RGSpace(7, 12),
ift.HPSpace(4))
newshape = [factor*l for l in dom[space].shape]
op = ift.CentralZeroPadder(dom, newshape, space)
ift.extra.consistency_check(op, dtype, dtype)
@expand(product([(ift.RGSpace(10, harmonic=True), 4, 0),
(ift.RGSpace((24, 31), distances=(0.4, 2.34),
harmonic=True), (4, 3), 0),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment