diff --git a/nifty5/data_objects/distributed_do.py b/nifty5/data_objects/distributed_do.py index f36e1984b6d7cfa3beaac6df7d18566da2224848..49a85445d4d3a8d848081a4bd88e8dd461128c48 100644 --- a/nifty5/data_objects/distributed_do.py +++ b/nifty5/data_objects/distributed_do.py @@ -61,6 +61,8 @@ class data_object(object): self._shape = tuple(shape) if len(self._shape) == 0: distaxis = -1 + if not isinstance(data, np.ndarray): + data = np.full((), data) self._distaxis = distaxis self._data = data if local_shape(self._shape, self._distaxis) != self._data.shape: @@ -262,7 +264,7 @@ def empty_like(a, dtype=None): def vdot(a, b): tmp = np.array(np.vdot(a._data, b._data)) - if a._distaxis==-1: + if a._distaxis == -1: return tmp[()] res = np.empty((), dtype=tmp.dtype) _comm.Allreduce(tmp, res, MPI.SUM) @@ -311,7 +313,7 @@ def from_object(object, dtype, copy, set_locked): # algorithm. def from_random(random_type, shape, dtype=np.float64, **kwargs): generator_function = getattr(Random, random_type) - if shape == (): + if len(shape) == 0: ldat = generator_function(dtype=dtype, shape=shape, **kwargs) ldat = _comm.bcast(ldat) return from_local_data(shape, ldat, distaxis=-1) @@ -460,15 +462,16 @@ def redistribute(arr, dist=None, nodist=None): rbuf = rbuf.reshape(local_shape(arr.shape, dist)) arrnew = from_local_data(arr.shape, rbuf, distaxis=dist) else: - arrnew = empty(arr.shape, dtype=arr.dtype, distaxis=dist) + arrnew = np.empty(local_shape(arr.shape, dist), dtype=arr.dtype) rslice = [slice(None)]*arr._data.ndim ofs = 0 for i in range(ntask): lo, hi = _shareRange(arr.shape[arr._distaxis], ntask, i) rslice[arr._distaxis] = slice(lo, hi) sz = rsz[i]//arr._data.itemsize - arrnew._data[rslice].flat = rbuf[ofs:ofs+sz] + arrnew[rslice].flat = rbuf[ofs:ofs+sz] ofs += sz + arrnew = from_local_data(arr.shape, arrnew, distaxis=dist) return arrnew @@ -497,15 +500,15 @@ def transpose(arr): r_msg = [rbuf, (rsz, rdisp), MPI.BYTE] _comm.Alltoallv(s_msg, r_msg) del sbuf # free memory - arrnew = empty((arr.shape[1], arr.shape[0]), dtype=arr.dtype, distaxis=0) - ofs = 0 sz2 = _shareSize(arr.shape[1], ntask, rank) + arrnew = np.empty((sz2, arr.shape[0]), dtype=arr.dtype) + ofs = 0 for i in range(ntask): lo, hi = _shareRange(arr.shape[0], ntask, i) sz = rsz[i]//arr._data.itemsize - arrnew._data[:, lo:hi] = rbuf[ofs:ofs+sz].reshape(hi-lo, sz2).T + arrnew[:, lo:hi] = rbuf[ofs:ofs+sz].reshape(hi-lo, sz2).T ofs += sz - return arrnew + return from_local_data((arr.shape[1], arr.shape[0]), arrnew, 0) def default_distaxis(): diff --git a/nifty5/domain_tuple.py b/nifty5/domain_tuple.py index 287192d9fe8b7a7972877c20c23a514f8ea1fb13..64d5b53adc9e303cc8f3641355fde20ab10d9e5f 100644 --- a/nifty5/domain_tuple.py +++ b/nifty5/domain_tuple.py @@ -37,6 +37,7 @@ class DomainTuple(object): via the factory function :attr:`make`! """ _tupleCache = {} + _scalarDomain = None def __init__(self, domain, _callingfrommake=False): if not _callingfrommake: @@ -150,3 +151,9 @@ class DomainTuple(object): for i in self: res += "\n" + str(i) return res + + @staticmethod + def scalar_domain(): + if DomainTuple._scalarDomain is None: + DomainTuple._scalarDomain = DomainTuple.make(()) + return DomainTuple._scalarDomain diff --git a/nifty5/field.py b/nifty5/field.py index 9865d7137c04fe21266abae62edbd0ff4b992b30..56163b1ea7be7ec809b84dd02de3ef7cb1d33e3d 100644 --- a/nifty5/field.py +++ b/nifty5/field.py @@ -50,7 +50,10 @@ class Field(object): if not isinstance(domain, DomainTuple): raise TypeError("domain must be of type DomainTuple") if not isinstance(val, dobj.data_object): - raise TypeError("val must be of type dobj.data_object") + if np.isscalar(val): + val = dobj.from_local_data((), np.full((), val)) + else: + raise TypeError("val must be of type dobj.data_object") if domain.shape != val.shape: raise ValueError("mismatch between the shapes of val and domain") self._domain = domain @@ -378,7 +381,9 @@ class Field(object): Field The complex conjugated field. """ - return Field(self._domain, self._val.conjugate()) + if np.issubdtype(self._val.dtype, np.complexfloating): + return Field(self._domain, self._val.conjugate()) + return self # ---General unary/contraction methods--- @@ -607,6 +612,17 @@ class Field(object): return False return (self._val == other._val).all() + def extract(self, dom): + if dom is not self._domain: + raise ValueError("domain mismatch") + return self + + def unite(self, other): + return self + other + + def positive_tanh(self): + return 0.5*(1.+self.tanh()) + for op in ["__add__", "__radd__", "__sub__", "__rsub__", @@ -642,3 +658,11 @@ for op in ["__iadd__", "__isub__", "__imul__", "__idiv__", "In-place operations are deliberately not supported") return func2 setattr(Field, op, func(op)) + +for f in ["sqrt", "exp", "log", "tanh"]: + def func(f): + def func2(self): + fu = getattr(dobj, f) + return Field(domain=self._domain, val=fu(self.val)) + return func2 + setattr(Field, f, func(f)) diff --git a/nifty5/multi/multi_domain.py b/nifty5/multi/multi_domain.py index 715e51570dd98dcb2d4511ad3847470babcb2a3b..7876184af1c0fadbd77f9e4117bc0c78f3124082 100644 --- a/nifty5/multi/multi_domain.py +++ b/nifty5/multi/multi_domain.py @@ -105,3 +105,15 @@ class MultiDomain(object): for key, dom in zip(self._keys, self._domains): res += key+": "+str(dom)+"\n" return res + + @staticmethod + def union(inp): + res = {} + for dom in inp: + for key, subdom in zip(dom._keys, dom._domains): + if key in res: + if res[key] is not subdom: + raise ValueError("domain mismatch") + else: + res[key] = subdom + return MultiDomain.make(res) diff --git a/nifty5/multi/multi_field.py b/nifty5/multi/multi_field.py index caee498e51865adf33ce9a3ea53960f8986d30f0..652bd1700a5944811b3a4a6c946c48345b5c6aca 100644 --- a/nifty5/multi/multi_field.py +++ b/nifty5/multi/multi_field.py @@ -32,7 +32,7 @@ class MultiField(object): Parameters ---------- domain: MultiDomain - val: tuple containing Field or None entries + val: tuple containing Field entries """ if not isinstance(domain, MultiDomain): raise TypeError("domain must be of type MultiDomain") @@ -44,8 +44,8 @@ class MultiField(object): if isinstance(v, Field): if v._domain is not d: raise ValueError("domain mismatch") - elif v is not None: - raise TypeError("bad entry in val (must be Field or None)") + else: + raise TypeError("bad entry in val (must be Field)") self._domain = domain self._val = val @@ -54,8 +54,9 @@ class MultiField(object): if domain is None: domain = MultiDomain.make({key: v._domain for key, v in dict.items()}) - return MultiField(domain, tuple(dict[key] if key in dict else None - for key in domain.keys())) + res = tuple(dict[key] if key in dict else Field.full(dom, 0) + for key, dom in zip(domain.keys(), domain.domains())) + return MultiField(domain, res) def to_dict(self): return {key: val for key, val in zip(self._domain.keys(), self._val)} @@ -81,9 +82,7 @@ class MultiField(object): # return {key: val.dtype for key, val in self._val.items()} def _transform(self, op): - return MultiField( - self._domain, - tuple(op(v) if v is not None else None for v in self._val)) + return MultiField(self._domain, tuple(op(v) for v in self._val)) @property def real(self): @@ -111,8 +110,7 @@ class MultiField(object): result = 0. self._check_domain(x) for v1, v2 in zip(self._val, x._val): - if v1 is not None and v2 is not None: - result += v1.vdot(v2) + result += v1.vdot(v2) return result # @staticmethod @@ -191,13 +189,13 @@ class MultiField(object): def all(self): for v in self._val: - if v is None or not v.all(): + if not v.all(): return False return True def any(self): for v in self._val: - if v is not None and v.any(): + if v.any(): return True return False @@ -215,45 +213,31 @@ class MultiField(object): return False return True + def extract(self, subset): + if isinstance(subset, MultiDomain): + return MultiField(subset, + tuple(self[key] for key in subset.keys())) + else: + return MultiField.from_dict({key: self[key] for key in subset}) -for op in ["__add__", "__radd__"]: - def func(op): - def func2(self, other): - if isinstance(other, MultiField): - if self._domain is not other._domain: - raise ValueError("domain mismatch") - val = [] - for v1, v2 in zip(self._val, other._val): - if v1 is not None: - val.append(v1 if v2 is None else (v1+v2)) - else: - val.append(None if v2 is None else v2) - val = tuple(val) - else: - val = tuple(other if v1 is None else (v1+other) - for v1 in self._val) - return MultiField(self._domain, val) - return func2 - setattr(MultiField, op, func(op)) - - -for op in ["__mul__", "__rmul__"]: - def func(op): - def func2(self, other): - if isinstance(other, MultiField): - if self._domain is not other._domain: - raise ValueError("domain mismatch") - val = tuple(None if v1 is None or v2 is None else v1*v2 - for v1, v2 in zip(self._val, other._val)) - else: - val = tuple(None if v1 is None else (v1*other) - for v1 in self._val) - return MultiField(self._domain, val) - return func2 - setattr(MultiField, op, func(op)) - + def unite(self, other): + return self.combine((self, other)) -for op in ["__sub__", "__rsub__", + @staticmethod + def combine(fields): + res = {} + for f in fields: + for key in f.keys(): + if key in res: + res[key] = res[key]+f[key] + else: + res[key] = f[key] + return MultiField.from_dict(res) + + +for op in ["__add__", "__radd__", + "__sub__", "__rsub__", + "__mul__", "__rmul__", "__div__", "__rdiv__", "__truediv__", "__rtruediv__", "__floordiv__", "__rfloordiv__", @@ -281,3 +265,13 @@ for op in ["__iadd__", "__isub__", "__imul__", "__idiv__", "In-place operations are deliberately not supported") return func2 setattr(MultiField, op, func(op)) + + +for f in ["sqrt", "exp", "log", "tanh"]: + def func(f): + def func2(self): + fu = getattr(dobj, f) + return MultiField(self.domain, + tuple(func2(val) for val in self.values())) + return func2 + setattr(MultiField, f, func(f)) diff --git a/nifty5/operators/central_zero_padder.py b/nifty5/operators/central_zero_padder.py index 7ba66d2adb28611f5cdede0cb0517b5831e125de..2a7f0649caaec908fe7ddc306f2f2f25519ad975 100644 --- a/nifty5/operators/central_zero_padder.py +++ b/nifty5/operators/central_zero_padder.py @@ -42,7 +42,7 @@ class CentralZeroPadder(LinearOperator): if i in axes: slicer_fw = slice(0, (self._domain.shape[i]+1)//2) slicer_bw = slice(-1, -1-(self._domain.shape[i]//2), -1) - slicer.append([slicer_fw, slicer_bw]) + slicer.append((slicer_fw, slicer_bw)) self.slicer = list(itertools.product(*slicer)) for i in range(len(self.slicer)): @@ -50,7 +50,8 @@ class CentralZeroPadder(LinearOperator): if j not in axes: tmp = list(self.slicer[i]) tmp.insert(j, slice(None)) - self.slicer[i] = tmp + self.slicer[i] = tuple(tmp) + self.slicer = tuple(self.slicer) @property def domain(self): diff --git a/nifty5/operators/null_operator.py b/nifty5/operators/null_operator.py index 41505733fe33e7b6f1a8832eb7433e1e681ee3d9..733184a1af4bb6f58e4939787be296ef88a4bb12 100644 --- a/nifty5/operators/null_operator.py +++ b/nifty5/operators/null_operator.py @@ -45,7 +45,7 @@ class NullOperator(LinearOperator): if isinstance(dom, DomainTuple): return Field.full(dom, 0) else: - return MultiField(dom, (None,)*len(dom)) + return MultiField.full(dom, 0) def apply(self, x, mode): self._check_input(x, mode) diff --git a/nifty5/sugar.py b/nifty5/sugar.py index 942d11a4e71e7c9eba4f85ef6f7c15ab86aa14e8..5f50fc6351b24454f852b06553c163a050f75361 100644 --- a/nifty5/sugar.py +++ b/nifty5/sugar.py @@ -34,11 +34,12 @@ from .multi.multi_field import MultiField from .operators.diagonal_operator import DiagonalOperator from .operators.power_distributor import PowerDistributor + __all__ = ['PS_field', 'power_analyze', 'create_power_operator', 'create_harmonic_smoothing_operator', 'from_random', 'full', 'from_global_data', 'from_local_data', - 'makeDomain', 'sqrt', 'exp', 'log', 'tanh', 'conjugate', - 'get_signal_variance', 'makeOp'] + 'makeDomain', 'sqrt', 'exp', 'log', 'tanh', 'positive_tanh', + 'conjugate', 'get_signal_variance', 'makeOp', 'domain_union'] def PS_field(pspace, func): @@ -242,19 +243,26 @@ def makeOp(input): input.domain, tuple(makeOp(val) for val in input.values())) raise NotImplementedError + +def domain_union(domains): + if isinstance(domains[0], DomainTuple): + if any(dom is not domains[0] for dom in domains[1:]): + raise ValueError("domain mismatch") + return domains[0] + return MultiDomain.union(domains) + + # Arithmetic functions working on Fields _current_module = sys.modules[__name__] -for f in ["sqrt", "exp", "log", "tanh", "conjugate"]: +for f in ["sqrt", "exp", "log", "tanh", "positive_tanh", "conjugate"]: def func(f): def func2(x): - if isinstance(x, MultiField): - return MultiField({key: func2(val) for key, val in x.items()}) - elif isinstance(x, Field): - fu = getattr(dobj, f) - return Field(domain=x._domain, val=fu(x.val)) + from .linearization import Linearization + if isinstance(x, (Field, MultiField, Linearization)): + return getattr(x, f)() else: return getattr(np, f)(x) return func2 diff --git a/nifty5/utilities.py b/nifty5/utilities.py index 8538b409e5c8235e977a51498a4526d3e18b1829..2ae4364c3a6436c85e9fc7a079b682426707893e 100644 --- a/nifty5/utilities.py +++ b/nifty5/utilities.py @@ -33,8 +33,8 @@ __all__ = ["get_slice_list", "safe_cast", "parse_spaces", "infer_space", "my_product", "frozendict", "special_add_at"] -def my_sum(terms): - return reduce(lambda x, y: x+y, terms) +def my_sum(iterable): + return reduce(lambda x, y: x+y, iterable) def my_lincomb_simple(terms, factors): @@ -86,10 +86,10 @@ def get_slice_list(shape, axes): [list(range(y)) for x, y in enumerate(shape) if x not in axes] for index in product(*axes_iterables): it_iter = iter(index) - slice_list = [ + slice_list = tuple( next(it_iter) if axis else slice(None, None) for axis in axes_select - ] + ) yield slice_list else: yield [slice(None, None)] @@ -159,7 +159,7 @@ class _DocStringInheritor(type): if doc: clsdict['__doc__'] = doc break - for attr, attribute in list(clsdict.items()): + for attr, attribute in clsdict.items(): if not attribute.__doc__: for mro_cls in (mro_cls for base in bases for mro_cls in base.mro() @@ -223,7 +223,7 @@ def hartley(a, axes=None): axes = tuple(range(tmp.ndim)) lastaxis = axes[-1] ntmplast = tmp.shape[lastaxis] - slice1 = [slice(None)]*lastaxis + [slice(0, ntmplast)] + slice1 = (slice(None),)*lastaxis + (slice(0, ntmplast),) np.add(tmp.real, tmp.imag, out=res[slice1]) def _fill_upper_half(tmp, res, axes): @@ -236,9 +236,11 @@ def hartley(a, axes=None): for i in axes[:-1]: slice1[i] = slice(1, None) slice2[i] = slice(None, 0, -1) + slice1 = tuple(slice1) + slice2 = tuple(slice2) np.subtract(tmp[slice2].real, tmp[slice2].imag, out=res[slice1]) for i, ax in enumerate(axes[:-1]): - dim1 = [slice(None)]*ax + [slice(0, 1)] + dim1 = (slice(None),)*ax + (slice(0, 1),) axes2 = axes[:i] + axes[i+1:] _fill_upper_half(tmp[dim1], res[dim1], axes2) diff --git a/test/test_minimization/test_minimizers.py b/test/test_minimization/test_minimizers.py index e0b3601d1aa46546843f9773fae4863beff215c4..357c583c11dc145d648909ecc57f39835b0d3b33 100644 --- a/test/test_minimization/test_minimizers.py +++ b/test/test_minimization/test_minimizers.py @@ -22,7 +22,7 @@ from test.common import expand import nifty5 as ift import numpy as np -from nose.plugins.skip import SkipTest +from unittest import SkipTest from numpy.testing import assert_allclose, assert_equal IC = ift.GradientNormController(tol_abs_gradnorm=1e-5, iteration_limit=1000)