From 1107973c64a09f67c6000ea2cf053840a34d0610 Mon Sep 17 00:00:00 2001 From: Martin Reinecke <martin@mpa-garching.mpg.de> Date: Sat, 17 Feb 2018 12:42:42 +0100 Subject: [PATCH] more docs, various tweaks --- demos/wiener_filter_via_hamiltonian.py | 3 --- nifty4/data_objects/distributed_do.py | 18 +++++++++++--- nifty4/data_objects/numpy_do.py | 16 ++++++++++-- nifty4/domain_tuple.py | 33 +++++++++++++++++++++++++ nifty4/domains/power_space.py | 2 +- nifty4/field.py | 18 ++++++++------ nifty4/library/critical_power_energy.py | 1 + nifty4/sugar.py | 3 +-- 8 files changed, 75 insertions(+), 19 deletions(-) diff --git a/demos/wiener_filter_via_hamiltonian.py b/demos/wiener_filter_via_hamiltonian.py index 57a49d11e..1eedca8d4 100644 --- a/demos/wiener_filter_via_hamiltonian.py +++ b/demos/wiener_filter_via_hamiltonian.py @@ -76,9 +76,6 @@ if __name__ == "__main__": ift.plot(ht(m), name="reconstruction.png", **plotdict) # Sample uncertainty map - sample_variance = ift.Field.zeros(s_space) - sample_mean = ift.Field.zeros(s_space) - mean, variance = ift.probe_with_posterior_samples(curv, ht, 50) ift.plot(variance, name="posterior_variance.png", **plotdict2) ift.plot(mean+ht(m), name="posterior_mean.png", **plotdict) diff --git a/nifty4/data_objects/distributed_do.py b/nifty4/data_objects/distributed_do.py index 8020a7cf2..aebe19fce 100644 --- a/nifty4/data_objects/distributed_do.py +++ b/nifty4/data_objects/distributed_do.py @@ -323,10 +323,20 @@ def sqrt(a, out=None): return _math_helper(a, np.sqrt, out) -def from_object(object, dtype=None, copy=True): - return data_object(object._shape, np.array(object._data, dtype=dtype, - copy=copy), - distaxis=object._distaxis) +def from_object(object, dtype, copy, set_locked): + if dtype is None: + dtype = object.dtype + dtypes_equal = dtype == object.dtype + if set_locked and dtypes_equal and locked(object): + return object + if not dtypes_equal and not copy: + raise ValueError("cannot change data type without copying") + if set_locked and not copy: + raise ValueError("cannot lock object without copying") + data = np.array(object._data, dtype=dtype, copy=copy) + if set_locked: + lock(data) + return data_object(object._shape, data, distaxis=object._distaxis) # This function draws all random numbers on all tasks, to produce the same diff --git a/nifty4/data_objects/numpy_do.py b/nifty4/data_objects/numpy_do.py index 5683a0ff8..9d13bffdc 100644 --- a/nifty4/data_objects/numpy_do.py +++ b/nifty4/data_objects/numpy_do.py @@ -38,8 +38,20 @@ def mprint(*args): print(*args) -def from_object(object, dtype=None, copy=True): - return np.array(object, dtype=dtype, copy=copy) +def from_object(object, dtype, copy, set_locked): + if dtype is None: + dtype = object.dtype + dtypes_equal = dtype == object.dtype + if set_locked and dtypes_equal and locked(object): + return object + if not dtypes_equal and not copy: + raise ValueError("cannot change data type without copying") + if set_locked and not copy: + raise ValueError("cannot lock object without copying") + res = np.array(object, dtype=dtype, copy=copy) + if set_locked: + lock(res) + return res def from_random(random_type, shape, dtype=np.float64, **kwargs): diff --git a/nifty4/domain_tuple.py b/nifty4/domain_tuple.py index ef8c73c26..901624c00 100644 --- a/nifty4/domain_tuple.py +++ b/nifty4/domain_tuple.py @@ -21,6 +21,17 @@ from .domains.domain import Domain class DomainTuple(object): + """Ordered sequence of Domain objects. + + This class holds a set of :class:`Domain` objects, which together form the + space on which a :class:`Field` is defined. + + Notes + ----- + + DomainTuples should never be created using the constructor, but rather + via the factory function :attr:`make`! + """ _tupleCache = {} def __init__(self, domain): @@ -44,6 +55,18 @@ class DomainTuple(object): @staticmethod def make(domain): + """Returns a DomainTuple matching `domain`. + + This function checks whether a matching DomainTuple already exists. + If yes, this object is returned, otherwise a new DomainTuple object + is created and returned. + + Parameters + ---------- + domain : Domain or tuple of Domain or DomainTuple + The geometrical structure for which the DomainTuple shall be + obtained. + """ if isinstance(domain, DomainTuple): return domain domain = DomainTuple._parse_domain(domain) @@ -75,14 +98,24 @@ class DomainTuple(object): @property def shape(self): + """tuple of int: number of pixels along each axis + + The shape of the array-like object required to store information + living on the DomainTuple. + """ return self._shape @property def size(self): + """int : total number of pixels. + + Equivalent to the products over all entries in the object's shape. + """ return self._size @property def axes(self): + """tuple of tuple of int : shapes of the underlying domains""" return self._axtuple def __len__(self): diff --git a/nifty4/domains/power_space.py b/nifty4/domains/power_space.py index fd49ba1de..8f5f4b9e5 100644 --- a/nifty4/domains/power_space.py +++ b/nifty4/domains/power_space.py @@ -239,5 +239,5 @@ class PowerSpace(StructuredDomain): @property def k_lengths(self): - """numpy.ndarray(float) : sorted array of all k-vector lengths.""" + """numpy.ndarray(float) : k-vector length for each bin.""" return self._k_lengths diff --git a/nifty4/field.py b/nifty4/field.py index a9391e681..0403b0431 100644 --- a/nifty4/field.py +++ b/nifty4/field.py @@ -45,23 +45,26 @@ class Field(object): dtype : type A numpy.type. Most common are float and complex. - copy: bool + copy : bool """ - def __init__(self, domain=None, val=None, dtype=None, copy=False): + def __init__(self, domain=None, val=None, dtype=None, copy=False, + locked=False): self._domain = self._infer_domain(domain=domain, val=val) dtype = self._infer_dtype(dtype=dtype, val=val) if isinstance(val, Field): if self._domain != val._domain: raise ValueError("Domain mismatch") - self._val = dobj.from_object(val.val, dtype=dtype, copy=copy) + self._val = dobj.from_object(val.val, dtype=dtype, copy=copy, + set_locked=locked) elif (np.isscalar(val)): self._val = dobj.full(self._domain.shape, dtype=dtype, fill_value=val) elif isinstance(val, dobj.data_object): if self._domain.shape == val.shape: - self._val = dobj.from_object(val, dtype=dtype, copy=copy) + self._val = dobj.from_object(val, dtype=dtype, copy=copy, + set_locked=locked) else: raise ValueError("Shape mismatch") elif val is None: @@ -69,6 +72,9 @@ class Field(object): else: raise TypeError("unknown source type") + if locked: + dobj.lock(self._val) + @staticmethod def full(domain, val, dtype=None): if not np.isscalar(val): @@ -251,9 +257,7 @@ class Field(object): """ if self.locked: return self - res = Field(val=self, copy=True) - res.lock() - return res + return Field(val=self, copy=True, locked=True) def scalar_weight(self, spaces=None): if np.isscalar(spaces): diff --git a/nifty4/library/critical_power_energy.py b/nifty4/library/critical_power_energy.py index 5ec4d32f9..3fc6cca4f 100644 --- a/nifty4/library/critical_power_energy.py +++ b/nifty4/library/critical_power_energy.py @@ -90,6 +90,7 @@ class CriticalPowerEnergy(Energy): Dist = PowerDistributor(target=self.m.domain, power_space=self.position.domain[0]) if self.D is not None: + # MR FIXME: we should use stuff from probing.utils for this w = Field.zeros(self.position.domain, dtype=self.m.dtype) for i in range(self.samples): sample = self.D.draw_sample() + self.m diff --git a/nifty4/sugar.py b/nifty4/sugar.py index a2f665efc..f8cccf0e4 100644 --- a/nifty4/sugar.py +++ b/nifty4/sugar.py @@ -239,8 +239,7 @@ def create_power_operator(domain, power_spectrum, space=None, dtype=None): space = int(space) return DiagonalOperator( create_power_field(domain[space], power_spectrum, dtype), - domain=domain, - spaces=space) + domain=domain, spaces=space) def create_composed_ht_operator(domain, codomain=None): -- GitLab