From 27f4f66b94ae5b68a8044c69cf43fe9eee7e18f8 Mon Sep 17 00:00:00 2001
From: Philipp Arras <parras@mpa-garching.mpg.de>
Date: Wed, 13 May 2020 16:55:24 +0200
Subject: [PATCH] Proper dtype handling

---
 nifty6/multi_field.py                |  5 +++-
 nifty6/operators/energy_operators.py | 39 +++++++++++++---------------
 2 files changed, 22 insertions(+), 22 deletions(-)

diff --git a/nifty6/multi_field.py b/nifty6/multi_field.py
index d6a5aadbf..b3374f3d2 100644
--- a/nifty6/multi_field.py
+++ b/nifty6/multi_field.py
@@ -103,7 +103,10 @@ class MultiField(Operator):
     @staticmethod
     def from_random(random_type, domain, dtype=np.float64, **kwargs):
         domain = MultiDomain.make(domain)
-        if dtype in [np.float64, np.complex128]:
+        if isinstance(dtype, dict):
+            dtype = {kk: np.dtype(dt) for kk, dt in dtype.items()}
+        else:
+            dtype = np.dtype(dtype)
             dtype = {kk: dtype for kk in domain.keys()}
         dct = {kk: Field.from_random(random_type, domain[kk], dtype[kk], **kwargs)
                for kk in domain.keys()}
diff --git a/nifty6/operators/energy_operators.py b/nifty6/operators/energy_operators.py
index 7496ab07f..aab83ca78 100644
--- a/nifty6/operators/energy_operators.py
+++ b/nifty6/operators/energy_operators.py
@@ -34,18 +34,18 @@ def _check_sampling_dtype(domain, dtypes):
     if dtypes is None:
         return
     if isinstance(domain, DomainTuple):
-        dtypes = {'': dtypes}
+        np.dtype(dtypes)
+        return
     elif isinstance(domain, MultiDomain):
-        if dtypes in [np.float64, np.complex128]:
+        if isinstance(dtypes, dict):
+            for dt in dtypes.values():
+                np.dtype(dt)
+            if set(domain.keys()) == set(dtypes.keys()):
+                return
+        else:
+            np.dtype(dtypes)
             return
-        dtypes = dtypes.values()
-        if set(domain.keys()) != set(dtypes.keys()):
-            raise ValueError
-    else:
-        raise TypeError
-    for dt in dtypes.values():
-        if dt not in [np.float64, np.complex128]:
-            raise ValueError
+    raise TypeError
 
 
 def _field_to_dtype(field):
@@ -332,7 +332,8 @@ class StudentTEnergy(EnergyOperator):
         E_\\theta(f) = -\\log \\text{StudentT}_\\theta(f)
                      = \\frac{\\theta + 1}{2} \\log(1 + \\frac{f^2}{\\theta}),
 
-    where f is a field defined on `domain`.
+    where f is a field defined on `domain`. Assumes that the data is `float64`
+    for sampling.
 
     Parameters
     ----------
@@ -342,12 +343,9 @@ class StudentTEnergy(EnergyOperator):
         Degree of freedom parameter for the student t distribution
     """
 
-    def __init__(self, domain, theta, sampling_dtype=np.float64):
+    def __init__(self, domain, theta):
         self._domain = DomainTuple.make(domain)
         self._theta = theta
-        self._sampling_dtype = sampling_dtype
-        if sampling_dtype == np.complex128:
-            raise NotImplementedError('Complex data not supported yet')
 
     def apply(self, x):
         self._check_input(x)
@@ -355,9 +353,7 @@ class StudentTEnergy(EnergyOperator):
         if not x.want_metric:
             return res
         met = makeOp((self._theta+1) / (self._theta+3), self.domain)
-        if self._sampling_dtype is not None:
-            met = SamplingDtypeSetter(met, self._sampling_dtype)
-        return res.add_metric(met)
+        return res.add_metric(SamplingDtypeSetter(met, np.float64))
 
 
 class BernoulliEnergy(EnergyOperator):
@@ -420,7 +416,8 @@ class StandardHamiltonian(EnergyOperator):
     ic_samp : IterationController
         Tells an internal :class:`SamplingEnabler` which convergence criterion
         to use to draw Gaussian samples.
-    sampling_dtype : FIXME
+    prior_dtype : numpy.dtype or dict of numpy.dtype, optional
+        Data type of prior used for sampling.
 
     See also
     --------
@@ -429,9 +426,9 @@ class StandardHamiltonian(EnergyOperator):
     `<https://arxiv.org/abs/1812.04403>`_
     """
 
-    def __init__(self, lh, ic_samp=None, _c_inp=None, sampling_dtype=np.float64):
+    def __init__(self, lh, ic_samp=None, _c_inp=None, prior_dtype=np.float64):
         self._lh = lh
-        self._prior = GaussianEnergy(domain=lh.domain, sampling_dtype=sampling_dtype)
+        self._prior = GaussianEnergy(domain=lh.domain, sampling_dtype=prior_dtype)
         if _c_inp is not None:
             _, self._prior = self._prior.simplify_for_constant_input(_c_inp)
         self._ic_samp = ic_samp
-- 
GitLab