Commit 1107973c authored by Martin Reinecke's avatar Martin Reinecke

more docs, various tweaks

parent c9444176
Pipeline #25049 failed with stages
in 5 minutes and 1 second
......@@ -76,9 +76,6 @@ if __name__ == "__main__":
ift.plot(ht(m), name="reconstruction.png", **plotdict)
# Sample uncertainty map
sample_variance = ift.Field.zeros(s_space)
sample_mean = ift.Field.zeros(s_space)
mean, variance = ift.probe_with_posterior_samples(curv, ht, 50)
ift.plot(variance, name="posterior_variance.png", **plotdict2)
ift.plot(mean+ht(m), name="posterior_mean.png", **plotdict)
......
......@@ -323,10 +323,20 @@ def sqrt(a, out=None):
return _math_helper(a, np.sqrt, out)
def from_object(object, dtype=None, copy=True):
return data_object(object._shape, np.array(object._data, dtype=dtype,
copy=copy),
distaxis=object._distaxis)
def from_object(object, dtype, copy, set_locked):
if dtype is None:
dtype = object.dtype
dtypes_equal = dtype == object.dtype
if set_locked and dtypes_equal and locked(object):
return object
if not dtypes_equal and not copy:
raise ValueError("cannot change data type without copying")
if set_locked and not copy:
raise ValueError("cannot lock object without copying")
data = np.array(object._data, dtype=dtype, copy=copy)
if set_locked:
lock(data)
return data_object(object._shape, data, distaxis=object._distaxis)
# This function draws all random numbers on all tasks, to produce the same
......
......@@ -38,8 +38,20 @@ def mprint(*args):
print(*args)
def from_object(object, dtype=None, copy=True):
return np.array(object, dtype=dtype, copy=copy)
def from_object(object, dtype, copy, set_locked):
if dtype is None:
dtype = object.dtype
dtypes_equal = dtype == object.dtype
if set_locked and dtypes_equal and locked(object):
return object
if not dtypes_equal and not copy:
raise ValueError("cannot change data type without copying")
if set_locked and not copy:
raise ValueError("cannot lock object without copying")
res = np.array(object, dtype=dtype, copy=copy)
if set_locked:
lock(res)
return res
def from_random(random_type, shape, dtype=np.float64, **kwargs):
......
......@@ -21,6 +21,17 @@ from .domains.domain import Domain
class DomainTuple(object):
"""Ordered sequence of Domain objects.
This class holds a set of :class:`Domain` objects, which together form the
space on which a :class:`Field` is defined.
Notes
-----
DomainTuples should never be created using the constructor, but rather
via the factory function :attr:`make`!
"""
_tupleCache = {}
def __init__(self, domain):
......@@ -44,6 +55,18 @@ class DomainTuple(object):
@staticmethod
def make(domain):
"""Returns a DomainTuple matching `domain`.
This function checks whether a matching DomainTuple already exists.
If yes, this object is returned, otherwise a new DomainTuple object
is created and returned.
Parameters
----------
domain : Domain or tuple of Domain or DomainTuple
The geometrical structure for which the DomainTuple shall be
obtained.
"""
if isinstance(domain, DomainTuple):
return domain
domain = DomainTuple._parse_domain(domain)
......@@ -75,14 +98,24 @@ class DomainTuple(object):
@property
def shape(self):
"""tuple of int: number of pixels along each axis
The shape of the array-like object required to store information
living on the DomainTuple.
"""
return self._shape
@property
def size(self):
"""int : total number of pixels.
Equivalent to the products over all entries in the object's shape.
"""
return self._size
@property
def axes(self):
"""tuple of tuple of int : shapes of the underlying domains"""
return self._axtuple
def __len__(self):
......
......@@ -239,5 +239,5 @@ class PowerSpace(StructuredDomain):
@property
def k_lengths(self):
"""numpy.ndarray(float) : sorted array of all k-vector lengths."""
"""numpy.ndarray(float) : k-vector length for each bin."""
return self._k_lengths
......@@ -45,23 +45,26 @@ class Field(object):
dtype : type
A numpy.type. Most common are float and complex.
copy: bool
copy : bool
"""
def __init__(self, domain=None, val=None, dtype=None, copy=False):
def __init__(self, domain=None, val=None, dtype=None, copy=False,
locked=False):
self._domain = self._infer_domain(domain=domain, val=val)
dtype = self._infer_dtype(dtype=dtype, val=val)
if isinstance(val, Field):
if self._domain != val._domain:
raise ValueError("Domain mismatch")
self._val = dobj.from_object(val.val, dtype=dtype, copy=copy)
self._val = dobj.from_object(val.val, dtype=dtype, copy=copy,
set_locked=locked)
elif (np.isscalar(val)):
self._val = dobj.full(self._domain.shape, dtype=dtype,
fill_value=val)
elif isinstance(val, dobj.data_object):
if self._domain.shape == val.shape:
self._val = dobj.from_object(val, dtype=dtype, copy=copy)
self._val = dobj.from_object(val, dtype=dtype, copy=copy,
set_locked=locked)
else:
raise ValueError("Shape mismatch")
elif val is None:
......@@ -69,6 +72,9 @@ class Field(object):
else:
raise TypeError("unknown source type")
if locked:
dobj.lock(self._val)
@staticmethod
def full(domain, val, dtype=None):
if not np.isscalar(val):
......@@ -251,9 +257,7 @@ class Field(object):
"""
if self.locked:
return self
res = Field(val=self, copy=True)
res.lock()
return res
return Field(val=self, copy=True, locked=True)
def scalar_weight(self, spaces=None):
if np.isscalar(spaces):
......
......@@ -90,6 +90,7 @@ class CriticalPowerEnergy(Energy):
Dist = PowerDistributor(target=self.m.domain,
power_space=self.position.domain[0])
if self.D is not None:
# MR FIXME: we should use stuff from probing.utils for this
w = Field.zeros(self.position.domain, dtype=self.m.dtype)
for i in range(self.samples):
sample = self.D.draw_sample() + self.m
......
......@@ -239,8 +239,7 @@ def create_power_operator(domain, power_spectrum, space=None, dtype=None):
space = int(space)
return DiagonalOperator(
create_power_field(domain[space], power_spectrum, dtype),
domain=domain,
spaces=space)
domain=domain, spaces=space)
def create_composed_ht_operator(domain, codomain=None):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment