diff --git a/nifty6/multi_field.py b/nifty6/multi_field.py index d6a5aadbfdb8f4f792f326a11ed2a0bb222e2c98..b3374f3d2564170bd283f546840ca4ef811acd51 100644 --- a/nifty6/multi_field.py +++ b/nifty6/multi_field.py @@ -103,7 +103,10 @@ class MultiField(Operator): @staticmethod def from_random(random_type, domain, dtype=np.float64, **kwargs): domain = MultiDomain.make(domain) - if dtype in [np.float64, np.complex128]: + if isinstance(dtype, dict): + dtype = {kk: np.dtype(dt) for kk, dt in dtype.items()} + else: + dtype = np.dtype(dtype) dtype = {kk: dtype for kk in domain.keys()} dct = {kk: Field.from_random(random_type, domain[kk], dtype[kk], **kwargs) for kk in domain.keys()} diff --git a/nifty6/operators/energy_operators.py b/nifty6/operators/energy_operators.py index 7496ab07f7d72921ba135fe9760b758e120d562f..aab83ca78376f496f374dafba79cc29de92e1d53 100644 --- a/nifty6/operators/energy_operators.py +++ b/nifty6/operators/energy_operators.py @@ -34,18 +34,18 @@ def _check_sampling_dtype(domain, dtypes): if dtypes is None: return if isinstance(domain, DomainTuple): - dtypes = {'': dtypes} + np.dtype(dtypes) + return elif isinstance(domain, MultiDomain): - if dtypes in [np.float64, np.complex128]: + if isinstance(dtypes, dict): + for dt in dtypes.values(): + np.dtype(dt) + if set(domain.keys()) == set(dtypes.keys()): + return + else: + np.dtype(dtypes) return - dtypes = dtypes.values() - if set(domain.keys()) != set(dtypes.keys()): - raise ValueError - else: - raise TypeError - for dt in dtypes.values(): - if dt not in [np.float64, np.complex128]: - raise ValueError + raise TypeError def _field_to_dtype(field): @@ -332,7 +332,8 @@ class StudentTEnergy(EnergyOperator): E_\\theta(f) = -\\log \\text{StudentT}_\\theta(f) = \\frac{\\theta + 1}{2} \\log(1 + \\frac{f^2}{\\theta}), - where f is a field defined on `domain`. + where f is a field defined on `domain`. Assumes that the data is `float64` + for sampling. Parameters ---------- @@ -342,12 +343,9 @@ class StudentTEnergy(EnergyOperator): Degree of freedom parameter for the student t distribution """ - def __init__(self, domain, theta, sampling_dtype=np.float64): + def __init__(self, domain, theta): self._domain = DomainTuple.make(domain) self._theta = theta - self._sampling_dtype = sampling_dtype - if sampling_dtype == np.complex128: - raise NotImplementedError('Complex data not supported yet') def apply(self, x): self._check_input(x) @@ -355,9 +353,7 @@ class StudentTEnergy(EnergyOperator): if not x.want_metric: return res met = makeOp((self._theta+1) / (self._theta+3), self.domain) - if self._sampling_dtype is not None: - met = SamplingDtypeSetter(met, self._sampling_dtype) - return res.add_metric(met) + return res.add_metric(SamplingDtypeSetter(met, np.float64)) class BernoulliEnergy(EnergyOperator): @@ -420,7 +416,8 @@ class StandardHamiltonian(EnergyOperator): ic_samp : IterationController Tells an internal :class:`SamplingEnabler` which convergence criterion to use to draw Gaussian samples. - sampling_dtype : FIXME + prior_dtype : numpy.dtype or dict of numpy.dtype, optional + Data type of prior used for sampling. See also -------- @@ -429,9 +426,9 @@ class StandardHamiltonian(EnergyOperator): `<https://arxiv.org/abs/1812.04403>`_ """ - def __init__(self, lh, ic_samp=None, _c_inp=None, sampling_dtype=np.float64): + def __init__(self, lh, ic_samp=None, _c_inp=None, prior_dtype=np.float64): self._lh = lh - self._prior = GaussianEnergy(domain=lh.domain, sampling_dtype=sampling_dtype) + self._prior = GaussianEnergy(domain=lh.domain, sampling_dtype=prior_dtype) if _c_inp is not None: _, self._prior = self._prior.simplify_for_constant_input(_c_inp) self._ic_samp = ic_samp