From 27f4f66b94ae5b68a8044c69cf43fe9eee7e18f8 Mon Sep 17 00:00:00 2001 From: Philipp Arras <parras@mpa-garching.mpg.de> Date: Wed, 13 May 2020 16:55:24 +0200 Subject: [PATCH] Proper dtype handling --- nifty6/multi_field.py | 5 +++- nifty6/operators/energy_operators.py | 39 +++++++++++++--------------- 2 files changed, 22 insertions(+), 22 deletions(-) diff --git a/nifty6/multi_field.py b/nifty6/multi_field.py index d6a5aadbf..b3374f3d2 100644 --- a/nifty6/multi_field.py +++ b/nifty6/multi_field.py @@ -103,7 +103,10 @@ class MultiField(Operator): @staticmethod def from_random(random_type, domain, dtype=np.float64, **kwargs): domain = MultiDomain.make(domain) - if dtype in [np.float64, np.complex128]: + if isinstance(dtype, dict): + dtype = {kk: np.dtype(dt) for kk, dt in dtype.items()} + else: + dtype = np.dtype(dtype) dtype = {kk: dtype for kk in domain.keys()} dct = {kk: Field.from_random(random_type, domain[kk], dtype[kk], **kwargs) for kk in domain.keys()} diff --git a/nifty6/operators/energy_operators.py b/nifty6/operators/energy_operators.py index 7496ab07f..aab83ca78 100644 --- a/nifty6/operators/energy_operators.py +++ b/nifty6/operators/energy_operators.py @@ -34,18 +34,18 @@ def _check_sampling_dtype(domain, dtypes): if dtypes is None: return if isinstance(domain, DomainTuple): - dtypes = {'': dtypes} + np.dtype(dtypes) + return elif isinstance(domain, MultiDomain): - if dtypes in [np.float64, np.complex128]: + if isinstance(dtypes, dict): + for dt in dtypes.values(): + np.dtype(dt) + if set(domain.keys()) == set(dtypes.keys()): + return + else: + np.dtype(dtypes) return - dtypes = dtypes.values() - if set(domain.keys()) != set(dtypes.keys()): - raise ValueError - else: - raise TypeError - for dt in dtypes.values(): - if dt not in [np.float64, np.complex128]: - raise ValueError + raise TypeError def _field_to_dtype(field): @@ -332,7 +332,8 @@ class StudentTEnergy(EnergyOperator): E_\\theta(f) = -\\log \\text{StudentT}_\\theta(f) = \\frac{\\theta + 1}{2} \\log(1 + \\frac{f^2}{\\theta}), - where f is a field defined on `domain`. + where f is a field defined on `domain`. Assumes that the data is `float64` + for sampling. Parameters ---------- @@ -342,12 +343,9 @@ class StudentTEnergy(EnergyOperator): Degree of freedom parameter for the student t distribution """ - def __init__(self, domain, theta, sampling_dtype=np.float64): + def __init__(self, domain, theta): self._domain = DomainTuple.make(domain) self._theta = theta - self._sampling_dtype = sampling_dtype - if sampling_dtype == np.complex128: - raise NotImplementedError('Complex data not supported yet') def apply(self, x): self._check_input(x) @@ -355,9 +353,7 @@ class StudentTEnergy(EnergyOperator): if not x.want_metric: return res met = makeOp((self._theta+1) / (self._theta+3), self.domain) - if self._sampling_dtype is not None: - met = SamplingDtypeSetter(met, self._sampling_dtype) - return res.add_metric(met) + return res.add_metric(SamplingDtypeSetter(met, np.float64)) class BernoulliEnergy(EnergyOperator): @@ -420,7 +416,8 @@ class StandardHamiltonian(EnergyOperator): ic_samp : IterationController Tells an internal :class:`SamplingEnabler` which convergence criterion to use to draw Gaussian samples. - sampling_dtype : FIXME + prior_dtype : numpy.dtype or dict of numpy.dtype, optional + Data type of prior used for sampling. See also -------- @@ -429,9 +426,9 @@ class StandardHamiltonian(EnergyOperator): `<https://arxiv.org/abs/1812.04403>`_ """ - def __init__(self, lh, ic_samp=None, _c_inp=None, sampling_dtype=np.float64): + def __init__(self, lh, ic_samp=None, _c_inp=None, prior_dtype=np.float64): self._lh = lh - self._prior = GaussianEnergy(domain=lh.domain, sampling_dtype=sampling_dtype) + self._prior = GaussianEnergy(domain=lh.domain, sampling_dtype=prior_dtype) if _c_inp is not None: _, self._prior = self._prior.simplify_for_constant_input(_c_inp) self._ic_samp = ic_samp -- GitLab