Commit 00cc61c9 authored by Philipp Arras's avatar Philipp Arras
Browse files

Rename to SamplingDtypeSetter and add docstrings

parent 58c2740f
......@@ -35,7 +35,7 @@ from .operators.field_zero_padder import FieldZeroPadder
from .operators.inversion_enabler import InversionEnabler
from .operators.mask_operator import MaskOperator
from .operators.regridding_operator import RegriddingOperator
from .operators.sampling_enabler import SamplingEnabler
from .operators.sampling_enabler import SamplingEnabler, SamplingDtypeSetter
from .operators.sandwich_operator import SandwichOperator
from .operators.scaling_operator import ScalingOperator
from .operators.block_diagonal_operator import BlockDiagonalOperator
......
......@@ -11,12 +11,12 @@
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2019 Max-Planck-Society
# Copyright(C) 2013-2020 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.
from ..operators.inversion_enabler import InversionEnabler
from ..operators.sampling_enabler import SamplingEnabler
from ..operators.sampling_enabler import SamplingDtypeSetter, SamplingEnabler
from ..operators.sandwich_operator import SandwichOperator
......@@ -48,11 +48,9 @@ def WienerFilterCurvature(R, N, S, iteration_controller=None,
Ninv = N.inverse
Sinv = S.inverse
if data_sampling_dtype is not None:
from ..operators.energy_operators import SamplingDtypeEnabler
Ninv = SamplingDtypeEnabler(Ninv, data_sampling_dtype)
Ninv = SamplingDtypeSetter(Ninv, data_sampling_dtype)
if prior_sampling_dtype is not None:
from ..operators.energy_operators import SamplingDtypeEnabler
Sinv = SamplingDtypeEnabler(Sinv, data_sampling_dtype)
Sinv = SamplingDtypeSetter(Sinv, data_sampling_dtype)
M = SandwichOperator.make(R, Ninv)
if iteration_controller_sampling is not None:
op = SamplingEnabler(M, Sinv, iteration_controller_sampling,
......
......@@ -30,23 +30,47 @@ class EndomorphicOperator(LinearOperator):
for endomorphic operators."""
return self._domain
def draw_sample(self, from_inverse=False):
"""Generates a sample from a Gaussian distribution with zero mean and
covariance given by the operator.
May or may not be implemented. Only optional.
Parameters
----------
from_inverse : bool (default : False)
if True, the sample is drawn from the inverse of the operator
Returns
-------
Field or MultiField
A sample from the Gaussian of given covariance.
"""
raise NotImplementedError
def draw_sample_with_dtype(self, dtype, from_inverse=False):
"""Generate a zero-mean sample
FIXME
"""Generates a sample from a Gaussian distribution with zero mean,
covariance given by the operator and specified data type.
Generates a sample from a Gaussian distribution with zero mean and
covariance given by the operator.
This method is implemented only for operators which actually draw
samples (e.g. `DiagonalOperator`). Operators which process the sample
(like `SandwichOperator`) implement only `draw_sample()`.
May or may not be implemented. Only optional.
Parameters
----------
dtype : numpy datatype FIXME
the data type to be used for the sample
dtype : numpy.dtype or dict of numpy.dtype
Dtype used for sampling from this operator. If the domain of `op`
is a `MultiDomain`, the dtype can either be specified as one value
for all components of the `MultiDomain` or in form of a dictionary
whose keys need to conincide the with keys of the `MultiDomain`.
from_inverse : bool (default : False)
if True, the sample is drawn from the inverse of the operator
Returns
-------
Field
Field or MultiField
A sample from the Gaussian of given covariance.
"""
raise NotImplementedError
......
......@@ -25,10 +25,9 @@ from ..multi_field import MultiField
from ..sugar import makeDomain, makeOp
from .linear_operator import LinearOperator
from .operator import Operator
from .sampling_enabler import SamplingEnabler
from .sampling_enabler import SamplingDtypeSetter, SamplingEnabler
from .scaling_operator import ScalingOperator
from .simple_linear_operators import VdotOperator
from .endomorphic_operator import EndomorphicOperator
def _check_sampling_dtype(domain, dtypes):
......@@ -60,43 +59,6 @@ def _field_to_dtype(field):
return dt
class SamplingDtypeEnabler(EndomorphicOperator):
def __init__(self, endomorphic_operator, dtype):
if not isinstance(endomorphic_operator, EndomorphicOperator):
raise TypeError
if not hasattr(endomorphic_operator, 'draw_sample_with_dtype'):
raise TypeError
dom = endomorphic_operator.domain
if isinstance(dom, MultiDomain):
if dtype in [np.float64, np.complex128]:
dtype = {kk: dtype for kk in dom.keys()}
if set(dtype.keys()) != set(dom.keys()):
raise TypeError
self._dtype = dtype
self._domain = dom
self._capability = endomorphic_operator._capability
self.apply = endomorphic_operator.apply
self._op = endomorphic_operator
def draw_sample(self, from_inverse=False):
"""Generate a zero-mean sample
Generates a sample from a Gaussian distribution with zero mean and
covariance given by the operator.
Parameters
----------
from_inverse : bool (default : False)
if True, the sample is drawn from the inverse of the operator
Returns
-------
Field
A sample from the Gaussian of given covariance.
"""
return self._op.draw_sample_with_dtype(self._dtype, from_inverse=from_inverse)
class EnergyOperator(Operator):
"""Operator which has a scalar domain as target domain.
......@@ -199,7 +161,7 @@ class VariableCovarianceGaussianEnergy(EnergyOperator):
return res
mf = {self._r: x.val[self._icov], self._icov: .5*x.val[self._icov]**(-2)}
met = makeOp(MultiField.from_dict(mf))
return res.add_metric(SamplingDtypeEnabler(met, self._sampling_dtype))
return res.add_metric(SamplingDtypeSetter(met, self._sampling_dtype))
class GaussianEnergy(EnergyOperator):
......@@ -262,7 +224,7 @@ class GaussianEnergy(EnergyOperator):
self._op = QuadraticFormOperator(inverse_covariance)
self._met = inverse_covariance
if sampling_dtype is not None:
self._met = SamplingDtypeEnabler(self._met, sampling_dtype)
self._met = SamplingDtypeSetter(self._met, sampling_dtype)
def _checkEquivalence(self, newdom):
newdom = makeDomain(newdom)
......@@ -313,7 +275,7 @@ class PoissonianEnergy(EnergyOperator):
res = x.sum() - x.ptw("log").vdot(self._d)
if not x.want_metric:
return res
return res.add_metric(SamplingDtypeEnabler(makeOp(1./x.val), np.float64))
return res.add_metric(SamplingDtypeSetter(makeOp(1./x.val), np.float64))
class InverseGammaLikelihood(EnergyOperator):
......@@ -359,7 +321,7 @@ class InverseGammaLikelihood(EnergyOperator):
return res
met = makeOp(self._alphap1/(x.val**2))
if self._sampling_dtype is not None:
met = SamplingDtypeEnabler(met, self._sampling_dtype)
met = SamplingDtypeSetter(met, self._sampling_dtype)
return res.add_metric(met)
......@@ -394,7 +356,7 @@ class StudentTEnergy(EnergyOperator):
return res
met = makeOp((self._theta+1) / (self._theta+3), self.domain)
if self._sampling_dtype is not None:
met = SamplingDtypeEnabler(met, self._sampling_dtype)
met = SamplingDtypeSetter(met, self._sampling_dtype)
return res.add_metric(met)
......@@ -429,7 +391,7 @@ class BernoulliEnergy(EnergyOperator):
if not x.want_metric:
return res
met = makeOp(1./(x.val*(1. - x.val)))
return res.add_metric(SamplingDtypeEnabler(met, np.float64))
return res.add_metric(SamplingDtypeSetter(met, np.float64))
class StandardHamiltonian(EnergyOperator):
......
......@@ -19,6 +19,7 @@ import numpy as np
from ..minimization.conjugate_gradient import ConjugateGradient
from ..minimization.quadratic_energy import QuadraticEnergy
from ..multi_domain import MultiDomain
from .endomorphic_operator import EndomorphicOperator
from .operator import Operator
......@@ -96,3 +97,52 @@ class SamplingEnabler(EndomorphicOperator):
indent("\n".join((
"Likelihood:", self._likelihood.__repr__(),
"Prior:", self._prior.__repr__())))))
class SamplingDtypeSetter(EndomorphicOperator):
"""Class that adds the information whether the operator at hand is the
covariance of a real-valued Gaussian or a complex-valued Gaussian
probability distribution.
This wrapper class shall address the following ambiguity which arises when
drawing a sampling from a Gaussian distribution with zero mean and given
covariance. E.g. a `ScalingOperator` with `1.` on its diagonal can be
viewed as the covariance operator of both a real-valued and complex-valued
Gaussian distribution. `SamplingDtypeSetter` specifies this data type.
Parameters
----------
op : EndomorphicOperator
Operator which shall be supplemented with a dtype for sampling. Needs
to be positive definite, hermitian and needs to implement the method
`draw_sample_with_dtype()`. Note that these three properties are not
checked in the constructor.
dtype : numpy.dtype or dict of numpy.dtype
Dtype used for sampling from this operator. If the domain of `op` is a
`MultiDomain`, the dtype can either be specified as one value for all
components of the `MultiDomain` or in form of a dictionary whose keys
need to conincide the with keys of the `MultiDomain`.
"""
def __init__(self, op, dtype):
if not isinstance(op, EndomorphicOperator):
raise TypeError
if not hasattr(op, 'draw_sample_with_dtype'):
raise TypeError
if isinstance(dtype, dict):
dtype = {kk: np.dtype(vv) for kk, vv in dtype.items()}
else:
dtype = np.dtype(dtype)
if isinstance(op.domain, MultiDomain):
if isinstance(dtype, np.dtype):
dtype = {kk: dtype for kk in op.domain.keys()}
if set(dtype.keys()) != set(op.domain.keys()):
raise TypeError
self._dtype = dtype
self._domain = op.domain
self._capability = op.capability
self.apply = op.apply
self._op = op
def draw_sample(self, from_inverse=False):
return self._op.draw_sample_with_dtype(self._dtype,
from_inverse=from_inverse)
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment