Commit 0b26ae98 authored by Philipp Arras's avatar Philipp Arras
Browse files

Cosmetics, documentation fixes and __repr__ tweaks

parent e47bd9ed
......@@ -46,7 +46,6 @@ class BlockDiagonalOperator(EndomorphicOperator):
else:
raise TypeError("LinearOperator expected")
def apply(self, x, mode):
self._check_input(x, mode)
val = tuple(op.apply(v, mode=mode) if op is not None else v
......
......@@ -48,6 +48,10 @@ def _check_sampling_dtype(domain, dtypes):
raise TypeError
def _iscomplex(dtype):
return np.issubdtype(dtype, np.complexfloating)
def _field_to_dtype(field):
if isinstance(field, Field):
dt = field.dtype
......@@ -127,10 +131,10 @@ class VariableCovarianceGaussianEnergy(EnergyOperator):
The covariance is assumed to be diagonal.
.. math ::
E(s,D) = - \\log G(s, D) = 0.5 (s)^\\dagger D^{-1} (s) + 0.5 tr log(D),
E(s,D) = - \\log G(s, C) = 0.5 (s)^\\dagger C (s) - 0.5 tr log(C),
an information energy for a Gaussian distribution with residual s and
diagonal covariance D.
inverse diagonal covariance C.
The domain of this energy will be a MultiDomain with two keys,
the target will be the scalar domain.
......@@ -139,10 +143,10 @@ class VariableCovarianceGaussianEnergy(EnergyOperator):
domain : Domain, DomainTuple, tuple of Domain
domain of the residual and domain of the covariance diagonal.
residual : key
residual_key : key
Residual key of the Gaussian.
inverse_covariance : key
inverse_covariance_key : key
Inverse covariance diagonal key of the Gaussian.
sampling_dtype : np.dtype
......@@ -156,7 +160,7 @@ class VariableCovarianceGaussianEnergy(EnergyOperator):
self._domain = MultiDomain.make({self._kr: dom, self._ki: dom})
self._dt = {self._kr: sampling_dtype, self._ki: np.float64}
_check_sampling_dtype(self._domain, self._dt)
self._cplx = np.issubdtype(sampling_dtype, np.complexfloating)
self._cplx = _iscomplex(sampling_dtype)
def apply(self, x):
self._check_input(x)
......
......@@ -276,6 +276,8 @@ class Operator(metaclass=NiftyMeta):
if c_inp is None:
return None, self
# Convention: If c_inp is MultiField, it needs to be defined on a
# subdomain of self._domain
if isinstance(self.domain, MultiDomain):
assert isinstance(c_inp.domain, MultiDomain)
if set(c_inp.keys()) > set(self.domain.keys()):
......
......@@ -102,10 +102,8 @@ class SliceOperator(LinearOperator):
return Field.from_raw(self.domain, res)
def __str__(self):
ss = (
f"{self.__class__.__name__}"
f"({self.domain.shape} -> {self.target.shape})"
)
ss = (f"{self.__class__.__name__}"
f"({self.domain.shape} -> {self.target.shape})")
return ss
......
......@@ -173,16 +173,9 @@ class FieldAdapter(LinearOperator):
return MultiField(self._tgt(mode), (x,))
def __repr__(self):
s = 'FieldAdapter'
dom = isinstance(self._domain, MultiDomain)
tgt = isinstance(self._target, MultiDomain)
if dom and tgt:
s += ' {} <- {}'.format(self._target.keys(), self._domain.keys())
elif dom:
s += ' <- {}'.format(self._domain.keys())
elif tgt:
s += ' {} <-'.format(self._target.keys())
return s
dom = self.domain.keys() if isinstance(self.domain, MultiDomain) else '()'
tgt = self.target.keys() if isinstance(self.target, MultiDomain) else '()'
return f'{tgt} <- {dom}'
class _SlowFieldAdapter(LinearOperator):
......@@ -378,3 +371,6 @@ class PartialExtractor(LinearOperator):
res0 = MultiField.from_dict({key: x[key] for key in x.domain.keys()})
res1 = MultiField.full(self._compldomain, 0.)
return res0.unite(res1)
def __repr__(self):
return f'{self.target.keys()} <- {self.domain.keys()}'
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment