Commit 1107973c authored by Martin Reinecke's avatar Martin Reinecke
Browse files

more docs, various tweaks

parent c9444176
Pipeline #25049 failed with stages
in 5 minutes and 1 second
...@@ -76,9 +76,6 @@ if __name__ == "__main__": ...@@ -76,9 +76,6 @@ if __name__ == "__main__":
ift.plot(ht(m), name="reconstruction.png", **plotdict) ift.plot(ht(m), name="reconstruction.png", **plotdict)
# Sample uncertainty map # Sample uncertainty map
sample_variance = ift.Field.zeros(s_space)
sample_mean = ift.Field.zeros(s_space)
mean, variance = ift.probe_with_posterior_samples(curv, ht, 50) mean, variance = ift.probe_with_posterior_samples(curv, ht, 50)
ift.plot(variance, name="posterior_variance.png", **plotdict2) ift.plot(variance, name="posterior_variance.png", **plotdict2)
ift.plot(mean+ht(m), name="posterior_mean.png", **plotdict) ift.plot(mean+ht(m), name="posterior_mean.png", **plotdict)
......
...@@ -323,10 +323,20 @@ def sqrt(a, out=None): ...@@ -323,10 +323,20 @@ def sqrt(a, out=None):
return _math_helper(a, np.sqrt, out) return _math_helper(a, np.sqrt, out)
def from_object(object, dtype=None, copy=True): def from_object(object, dtype, copy, set_locked):
return data_object(object._shape, np.array(object._data, dtype=dtype, if dtype is None:
copy=copy), dtype = object.dtype
distaxis=object._distaxis) dtypes_equal = dtype == object.dtype
if set_locked and dtypes_equal and locked(object):
return object
if not dtypes_equal and not copy:
raise ValueError("cannot change data type without copying")
if set_locked and not copy:
raise ValueError("cannot lock object without copying")
data = np.array(object._data, dtype=dtype, copy=copy)
if set_locked:
lock(data)
return data_object(object._shape, data, distaxis=object._distaxis)
# This function draws all random numbers on all tasks, to produce the same # This function draws all random numbers on all tasks, to produce the same
......
...@@ -38,8 +38,20 @@ def mprint(*args): ...@@ -38,8 +38,20 @@ def mprint(*args):
print(*args) print(*args)
def from_object(object, dtype=None, copy=True): def from_object(object, dtype, copy, set_locked):
return np.array(object, dtype=dtype, copy=copy) if dtype is None:
dtype = object.dtype
dtypes_equal = dtype == object.dtype
if set_locked and dtypes_equal and locked(object):
return object
if not dtypes_equal and not copy:
raise ValueError("cannot change data type without copying")
if set_locked and not copy:
raise ValueError("cannot lock object without copying")
res = np.array(object, dtype=dtype, copy=copy)
if set_locked:
lock(res)
return res
def from_random(random_type, shape, dtype=np.float64, **kwargs): def from_random(random_type, shape, dtype=np.float64, **kwargs):
......
...@@ -21,6 +21,17 @@ from .domains.domain import Domain ...@@ -21,6 +21,17 @@ from .domains.domain import Domain
class DomainTuple(object): class DomainTuple(object):
"""Ordered sequence of Domain objects.
This class holds a set of :class:`Domain` objects, which together form the
space on which a :class:`Field` is defined.
Notes
-----
DomainTuples should never be created using the constructor, but rather
via the factory function :attr:`make`!
"""
_tupleCache = {} _tupleCache = {}
def __init__(self, domain): def __init__(self, domain):
...@@ -44,6 +55,18 @@ class DomainTuple(object): ...@@ -44,6 +55,18 @@ class DomainTuple(object):
@staticmethod @staticmethod
def make(domain): def make(domain):
"""Returns a DomainTuple matching `domain`.
This function checks whether a matching DomainTuple already exists.
If yes, this object is returned, otherwise a new DomainTuple object
is created and returned.
Parameters
----------
domain : Domain or tuple of Domain or DomainTuple
The geometrical structure for which the DomainTuple shall be
obtained.
"""
if isinstance(domain, DomainTuple): if isinstance(domain, DomainTuple):
return domain return domain
domain = DomainTuple._parse_domain(domain) domain = DomainTuple._parse_domain(domain)
...@@ -75,14 +98,24 @@ class DomainTuple(object): ...@@ -75,14 +98,24 @@ class DomainTuple(object):
@property @property
def shape(self): def shape(self):
"""tuple of int: number of pixels along each axis
The shape of the array-like object required to store information
living on the DomainTuple.
"""
return self._shape return self._shape
@property @property
def size(self): def size(self):
"""int : total number of pixels.
Equivalent to the products over all entries in the object's shape.
"""
return self._size return self._size
@property @property
def axes(self): def axes(self):
"""tuple of tuple of int : shapes of the underlying domains"""
return self._axtuple return self._axtuple
def __len__(self): def __len__(self):
......
...@@ -239,5 +239,5 @@ class PowerSpace(StructuredDomain): ...@@ -239,5 +239,5 @@ class PowerSpace(StructuredDomain):
@property @property
def k_lengths(self): def k_lengths(self):
"""numpy.ndarray(float) : sorted array of all k-vector lengths.""" """numpy.ndarray(float) : k-vector length for each bin."""
return self._k_lengths return self._k_lengths
...@@ -45,23 +45,26 @@ class Field(object): ...@@ -45,23 +45,26 @@ class Field(object):
dtype : type dtype : type
A numpy.type. Most common are float and complex. A numpy.type. Most common are float and complex.
copy: bool copy : bool
""" """
def __init__(self, domain=None, val=None, dtype=None, copy=False): def __init__(self, domain=None, val=None, dtype=None, copy=False,
locked=False):
self._domain = self._infer_domain(domain=domain, val=val) self._domain = self._infer_domain(domain=domain, val=val)
dtype = self._infer_dtype(dtype=dtype, val=val) dtype = self._infer_dtype(dtype=dtype, val=val)
if isinstance(val, Field): if isinstance(val, Field):
if self._domain != val._domain: if self._domain != val._domain:
raise ValueError("Domain mismatch") raise ValueError("Domain mismatch")
self._val = dobj.from_object(val.val, dtype=dtype, copy=copy) self._val = dobj.from_object(val.val, dtype=dtype, copy=copy,
set_locked=locked)
elif (np.isscalar(val)): elif (np.isscalar(val)):
self._val = dobj.full(self._domain.shape, dtype=dtype, self._val = dobj.full(self._domain.shape, dtype=dtype,
fill_value=val) fill_value=val)
elif isinstance(val, dobj.data_object): elif isinstance(val, dobj.data_object):
if self._domain.shape == val.shape: if self._domain.shape == val.shape:
self._val = dobj.from_object(val, dtype=dtype, copy=copy) self._val = dobj.from_object(val, dtype=dtype, copy=copy,
set_locked=locked)
else: else:
raise ValueError("Shape mismatch") raise ValueError("Shape mismatch")
elif val is None: elif val is None:
...@@ -69,6 +72,9 @@ class Field(object): ...@@ -69,6 +72,9 @@ class Field(object):
else: else:
raise TypeError("unknown source type") raise TypeError("unknown source type")
if locked:
dobj.lock(self._val)
@staticmethod @staticmethod
def full(domain, val, dtype=None): def full(domain, val, dtype=None):
if not np.isscalar(val): if not np.isscalar(val):
...@@ -251,9 +257,7 @@ class Field(object): ...@@ -251,9 +257,7 @@ class Field(object):
""" """
if self.locked: if self.locked:
return self return self
res = Field(val=self, copy=True) return Field(val=self, copy=True, locked=True)
res.lock()
return res
def scalar_weight(self, spaces=None): def scalar_weight(self, spaces=None):
if np.isscalar(spaces): if np.isscalar(spaces):
......
...@@ -90,6 +90,7 @@ class CriticalPowerEnergy(Energy): ...@@ -90,6 +90,7 @@ class CriticalPowerEnergy(Energy):
Dist = PowerDistributor(target=self.m.domain, Dist = PowerDistributor(target=self.m.domain,
power_space=self.position.domain[0]) power_space=self.position.domain[0])
if self.D is not None: if self.D is not None:
# MR FIXME: we should use stuff from probing.utils for this
w = Field.zeros(self.position.domain, dtype=self.m.dtype) w = Field.zeros(self.position.domain, dtype=self.m.dtype)
for i in range(self.samples): for i in range(self.samples):
sample = self.D.draw_sample() + self.m sample = self.D.draw_sample() + self.m
......
...@@ -239,8 +239,7 @@ def create_power_operator(domain, power_spectrum, space=None, dtype=None): ...@@ -239,8 +239,7 @@ def create_power_operator(domain, power_spectrum, space=None, dtype=None):
space = int(space) space = int(space)
return DiagonalOperator( return DiagonalOperator(
create_power_field(domain[space], power_spectrum, dtype), create_power_field(domain[space], power_spectrum, dtype),
domain=domain, domain=domain, spaces=space)
spaces=space)
def create_composed_ht_operator(domain, codomain=None): def create_composed_ht_operator(domain, codomain=None):
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment