Commit a3b2f458 authored by Martin Reinecke's avatar Martin Reinecke
Browse files

Merge branch 'misc_changes' into 'NIFTy_5'

Various changes from the "redesign" branch

See merge request ift/nifty-dev!78
parents bb5c7a43 bdb90c94
......@@ -61,6 +61,8 @@ class data_object(object):
self._shape = tuple(shape)
if len(self._shape) == 0:
distaxis = -1
if not isinstance(data, np.ndarray):
data = np.full((), data)
self._distaxis = distaxis
self._data = data
if local_shape(self._shape, self._distaxis) != self._data.shape:
......@@ -262,7 +264,7 @@ def empty_like(a, dtype=None):
def vdot(a, b):
tmp = np.array(np.vdot(a._data, b._data))
if a._distaxis==-1:
if a._distaxis == -1:
return tmp[()]
res = np.empty((), dtype=tmp.dtype)
_comm.Allreduce(tmp, res, MPI.SUM)
......@@ -311,7 +313,7 @@ def from_object(object, dtype, copy, set_locked):
# algorithm.
def from_random(random_type, shape, dtype=np.float64, **kwargs):
generator_function = getattr(Random, random_type)
if shape == ():
if len(shape) == 0:
ldat = generator_function(dtype=dtype, shape=shape, **kwargs)
ldat = _comm.bcast(ldat)
return from_local_data(shape, ldat, distaxis=-1)
......@@ -460,15 +462,16 @@ def redistribute(arr, dist=None, nodist=None):
rbuf = rbuf.reshape(local_shape(arr.shape, dist))
arrnew = from_local_data(arr.shape, rbuf, distaxis=dist)
else:
arrnew = empty(arr.shape, dtype=arr.dtype, distaxis=dist)
arrnew = np.empty(local_shape(arr.shape, dist), dtype=arr.dtype)
rslice = [slice(None)]*arr._data.ndim
ofs = 0
for i in range(ntask):
lo, hi = _shareRange(arr.shape[arr._distaxis], ntask, i)
rslice[arr._distaxis] = slice(lo, hi)
sz = rsz[i]//arr._data.itemsize
arrnew._data[rslice].flat = rbuf[ofs:ofs+sz]
arrnew[rslice].flat = rbuf[ofs:ofs+sz]
ofs += sz
arrnew = from_local_data(arr.shape, arrnew, distaxis=dist)
return arrnew
......@@ -497,15 +500,15 @@ def transpose(arr):
r_msg = [rbuf, (rsz, rdisp), MPI.BYTE]
_comm.Alltoallv(s_msg, r_msg)
del sbuf # free memory
arrnew = empty((arr.shape[1], arr.shape[0]), dtype=arr.dtype, distaxis=0)
ofs = 0
sz2 = _shareSize(arr.shape[1], ntask, rank)
arrnew = np.empty((sz2, arr.shape[0]), dtype=arr.dtype)
ofs = 0
for i in range(ntask):
lo, hi = _shareRange(arr.shape[0], ntask, i)
sz = rsz[i]//arr._data.itemsize
arrnew._data[:, lo:hi] = rbuf[ofs:ofs+sz].reshape(hi-lo, sz2).T
arrnew[:, lo:hi] = rbuf[ofs:ofs+sz].reshape(hi-lo, sz2).T
ofs += sz
return arrnew
return from_local_data((arr.shape[1], arr.shape[0]), arrnew, 0)
def default_distaxis():
......
......@@ -37,6 +37,7 @@ class DomainTuple(object):
via the factory function :attr:`make`!
"""
_tupleCache = {}
_scalarDomain = None
def __init__(self, domain, _callingfrommake=False):
if not _callingfrommake:
......@@ -150,3 +151,9 @@ class DomainTuple(object):
for i in self:
res += "\n" + str(i)
return res
@staticmethod
def scalar_domain():
if DomainTuple._scalarDomain is None:
DomainTuple._scalarDomain = DomainTuple.make(())
return DomainTuple._scalarDomain
......@@ -50,7 +50,10 @@ class Field(object):
if not isinstance(domain, DomainTuple):
raise TypeError("domain must be of type DomainTuple")
if not isinstance(val, dobj.data_object):
raise TypeError("val must be of type dobj.data_object")
if np.isscalar(val):
val = dobj.from_local_data((), np.full((), val))
else:
raise TypeError("val must be of type dobj.data_object")
if domain.shape != val.shape:
raise ValueError("mismatch between the shapes of val and domain")
self._domain = domain
......@@ -378,7 +381,9 @@ class Field(object):
Field
The complex conjugated field.
"""
return Field(self._domain, self._val.conjugate())
if np.issubdtype(self._val.dtype, np.complexfloating):
return Field(self._domain, self._val.conjugate())
return self
# ---General unary/contraction methods---
......@@ -607,6 +612,17 @@ class Field(object):
return False
return (self._val == other._val).all()
def extract(self, dom):
if dom is not self._domain:
raise ValueError("domain mismatch")
return self
def unite(self, other):
return self + other
def positive_tanh(self):
return 0.5*(1.+self.tanh())
for op in ["__add__", "__radd__",
"__sub__", "__rsub__",
......@@ -642,3 +658,11 @@ for op in ["__iadd__", "__isub__", "__imul__", "__idiv__",
"In-place operations are deliberately not supported")
return func2
setattr(Field, op, func(op))
for f in ["sqrt", "exp", "log", "tanh"]:
def func(f):
def func2(self):
fu = getattr(dobj, f)
return Field(domain=self._domain, val=fu(self.val))
return func2
setattr(Field, f, func(f))
......@@ -105,3 +105,15 @@ class MultiDomain(object):
for key, dom in zip(self._keys, self._domains):
res += key+": "+str(dom)+"\n"
return res
@staticmethod
def union(inp):
res = {}
for dom in inp:
for key, subdom in zip(dom._keys, dom._domains):
if key in res:
if res[key] is not subdom:
raise ValueError("domain mismatch")
else:
res[key] = subdom
return MultiDomain.make(res)
......@@ -32,7 +32,7 @@ class MultiField(object):
Parameters
----------
domain: MultiDomain
val: tuple containing Field or None entries
val: tuple containing Field entries
"""
if not isinstance(domain, MultiDomain):
raise TypeError("domain must be of type MultiDomain")
......@@ -44,8 +44,8 @@ class MultiField(object):
if isinstance(v, Field):
if v._domain is not d:
raise ValueError("domain mismatch")
elif v is not None:
raise TypeError("bad entry in val (must be Field or None)")
else:
raise TypeError("bad entry in val (must be Field)")
self._domain = domain
self._val = val
......@@ -54,8 +54,9 @@ class MultiField(object):
if domain is None:
domain = MultiDomain.make({key: v._domain
for key, v in dict.items()})
return MultiField(domain, tuple(dict[key] if key in dict else None
for key in domain.keys()))
res = tuple(dict[key] if key in dict else Field.full(dom, 0)
for key, dom in zip(domain.keys(), domain.domains()))
return MultiField(domain, res)
def to_dict(self):
return {key: val for key, val in zip(self._domain.keys(), self._val)}
......@@ -81,9 +82,7 @@ class MultiField(object):
# return {key: val.dtype for key, val in self._val.items()}
def _transform(self, op):
return MultiField(
self._domain,
tuple(op(v) if v is not None else None for v in self._val))
return MultiField(self._domain, tuple(op(v) for v in self._val))
@property
def real(self):
......@@ -111,8 +110,7 @@ class MultiField(object):
result = 0.
self._check_domain(x)
for v1, v2 in zip(self._val, x._val):
if v1 is not None and v2 is not None:
result += v1.vdot(v2)
result += v1.vdot(v2)
return result
# @staticmethod
......@@ -191,13 +189,13 @@ class MultiField(object):
def all(self):
for v in self._val:
if v is None or not v.all():
if not v.all():
return False
return True
def any(self):
for v in self._val:
if v is not None and v.any():
if v.any():
return True
return False
......@@ -215,45 +213,31 @@ class MultiField(object):
return False
return True
def extract(self, subset):
if isinstance(subset, MultiDomain):
return MultiField(subset,
tuple(self[key] for key in subset.keys()))
else:
return MultiField.from_dict({key: self[key] for key in subset})
for op in ["__add__", "__radd__"]:
def func(op):
def func2(self, other):
if isinstance(other, MultiField):
if self._domain is not other._domain:
raise ValueError("domain mismatch")
val = []
for v1, v2 in zip(self._val, other._val):
if v1 is not None:
val.append(v1 if v2 is None else (v1+v2))
else:
val.append(None if v2 is None else v2)
val = tuple(val)
else:
val = tuple(other if v1 is None else (v1+other)
for v1 in self._val)
return MultiField(self._domain, val)
return func2
setattr(MultiField, op, func(op))
for op in ["__mul__", "__rmul__"]:
def func(op):
def func2(self, other):
if isinstance(other, MultiField):
if self._domain is not other._domain:
raise ValueError("domain mismatch")
val = tuple(None if v1 is None or v2 is None else v1*v2
for v1, v2 in zip(self._val, other._val))
else:
val = tuple(None if v1 is None else (v1*other)
for v1 in self._val)
return MultiField(self._domain, val)
return func2
setattr(MultiField, op, func(op))
def unite(self, other):
return self.combine((self, other))
for op in ["__sub__", "__rsub__",
@staticmethod
def combine(fields):
res = {}
for f in fields:
for key in f.keys():
if key in res:
res[key] = res[key]+f[key]
else:
res[key] = f[key]
return MultiField.from_dict(res)
for op in ["__add__", "__radd__",
"__sub__", "__rsub__",
"__mul__", "__rmul__",
"__div__", "__rdiv__",
"__truediv__", "__rtruediv__",
"__floordiv__", "__rfloordiv__",
......@@ -281,3 +265,13 @@ for op in ["__iadd__", "__isub__", "__imul__", "__idiv__",
"In-place operations are deliberately not supported")
return func2
setattr(MultiField, op, func(op))
for f in ["sqrt", "exp", "log", "tanh"]:
def func(f):
def func2(self):
fu = getattr(dobj, f)
return MultiField(self.domain,
tuple(func2(val) for val in self.values()))
return func2
setattr(MultiField, f, func(f))
......@@ -42,7 +42,7 @@ class CentralZeroPadder(LinearOperator):
if i in axes:
slicer_fw = slice(0, (self._domain.shape[i]+1)//2)
slicer_bw = slice(-1, -1-(self._domain.shape[i]//2), -1)
slicer.append([slicer_fw, slicer_bw])
slicer.append((slicer_fw, slicer_bw))
self.slicer = list(itertools.product(*slicer))
for i in range(len(self.slicer)):
......@@ -50,7 +50,8 @@ class CentralZeroPadder(LinearOperator):
if j not in axes:
tmp = list(self.slicer[i])
tmp.insert(j, slice(None))
self.slicer[i] = tmp
self.slicer[i] = tuple(tmp)
self.slicer = tuple(self.slicer)
@property
def domain(self):
......
......@@ -45,7 +45,7 @@ class NullOperator(LinearOperator):
if isinstance(dom, DomainTuple):
return Field.full(dom, 0)
else:
return MultiField(dom, (None,)*len(dom))
return MultiField.full(dom, 0)
def apply(self, x, mode):
self._check_input(x, mode)
......
......@@ -34,11 +34,12 @@ from .multi.multi_field import MultiField
from .operators.diagonal_operator import DiagonalOperator
from .operators.power_distributor import PowerDistributor
__all__ = ['PS_field', 'power_analyze', 'create_power_operator',
'create_harmonic_smoothing_operator', 'from_random',
'full', 'from_global_data', 'from_local_data',
'makeDomain', 'sqrt', 'exp', 'log', 'tanh', 'conjugate',
'get_signal_variance', 'makeOp']
'makeDomain', 'sqrt', 'exp', 'log', 'tanh', 'positive_tanh',
'conjugate', 'get_signal_variance', 'makeOp', 'domain_union']
def PS_field(pspace, func):
......@@ -242,19 +243,25 @@ def makeOp(input):
input.domain, tuple(makeOp(val) for val in input.values()))
raise NotImplementedError
def domain_union(domains):
if isinstance(domains[0], DomainTuple):
if any(dom is not domains[0] for dom in domains[1:]):
raise ValueError("domain mismatch")
return domains[0]
return MultiDomain.union(domains)
# Arithmetic functions working on Fields
_current_module = sys.modules[__name__]
for f in ["sqrt", "exp", "log", "tanh", "conjugate"]:
for f in ["sqrt", "exp", "log", "tanh", "positive_tanh", "conjugate"]:
def func(f):
def func2(x):
if isinstance(x, MultiField):
return MultiField({key: func2(val) for key, val in x.items()})
elif isinstance(x, Field):
fu = getattr(dobj, f)
return Field(domain=x._domain, val=fu(x.val))
if isinstance(x, (Field, MultiField)):
return getattr(x, f)()
else:
return getattr(np, f)(x)
return func2
......
......@@ -33,8 +33,8 @@ __all__ = ["get_slice_list", "safe_cast", "parse_spaces", "infer_space",
"my_product", "frozendict", "special_add_at"]
def my_sum(terms):
return reduce(lambda x, y: x+y, terms)
def my_sum(iterable):
return reduce(lambda x, y: x+y, iterable)
def my_lincomb_simple(terms, factors):
......@@ -86,10 +86,10 @@ def get_slice_list(shape, axes):
[list(range(y)) for x, y in enumerate(shape) if x not in axes]
for index in product(*axes_iterables):
it_iter = iter(index)
slice_list = [
slice_list = tuple(
next(it_iter)
if axis else slice(None, None) for axis in axes_select
]
)
yield slice_list
else:
yield [slice(None, None)]
......@@ -159,7 +159,7 @@ class _DocStringInheritor(type):
if doc:
clsdict['__doc__'] = doc
break
for attr, attribute in list(clsdict.items()):
for attr, attribute in clsdict.items():
if not attribute.__doc__:
for mro_cls in (mro_cls for base in bases
for mro_cls in base.mro()
......@@ -223,7 +223,7 @@ def hartley(a, axes=None):
axes = tuple(range(tmp.ndim))
lastaxis = axes[-1]
ntmplast = tmp.shape[lastaxis]
slice1 = [slice(None)]*lastaxis + [slice(0, ntmplast)]
slice1 = (slice(None),)*lastaxis + (slice(0, ntmplast),)
np.add(tmp.real, tmp.imag, out=res[slice1])
def _fill_upper_half(tmp, res, axes):
......@@ -236,9 +236,11 @@ def hartley(a, axes=None):
for i in axes[:-1]:
slice1[i] = slice(1, None)
slice2[i] = slice(None, 0, -1)
slice1 = tuple(slice1)
slice2 = tuple(slice2)
np.subtract(tmp[slice2].real, tmp[slice2].imag, out=res[slice1])
for i, ax in enumerate(axes[:-1]):
dim1 = [slice(None)]*ax + [slice(0, 1)]
dim1 = (slice(None),)*ax + (slice(0, 1),)
axes2 = axes[:i] + axes[i+1:]
_fill_upper_half(tmp[dim1], res[dim1], axes2)
......
......@@ -22,7 +22,7 @@ from test.common import expand
import nifty5 as ift
import numpy as np
from nose.plugins.skip import SkipTest
from unittest import SkipTest
from numpy.testing import assert_allclose, assert_equal
IC = ift.GradientNormController(tol_abs_gradnorm=1e-5, iteration_limit=1000)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment