Skip to content
Snippets Groups Projects
Commit 27f4f66b authored by Philipp Arras's avatar Philipp Arras
Browse files

Proper dtype handling

parent d83a042e
Branches
Tags
1 merge request!458Introduce SamplingDtypeEnabler
Pipeline #74867 passed
......@@ -103,7 +103,10 @@ class MultiField(Operator):
@staticmethod
def from_random(random_type, domain, dtype=np.float64, **kwargs):
domain = MultiDomain.make(domain)
if dtype in [np.float64, np.complex128]:
if isinstance(dtype, dict):
dtype = {kk: np.dtype(dt) for kk, dt in dtype.items()}
else:
dtype = np.dtype(dtype)
dtype = {kk: dtype for kk in domain.keys()}
dct = {kk: Field.from_random(random_type, domain[kk], dtype[kk], **kwargs)
for kk in domain.keys()}
......
......@@ -34,18 +34,18 @@ def _check_sampling_dtype(domain, dtypes):
if dtypes is None:
return
if isinstance(domain, DomainTuple):
dtypes = {'': dtypes}
np.dtype(dtypes)
return
elif isinstance(domain, MultiDomain):
if dtypes in [np.float64, np.complex128]:
if isinstance(dtypes, dict):
for dt in dtypes.values():
np.dtype(dt)
if set(domain.keys()) == set(dtypes.keys()):
return
else:
np.dtype(dtypes)
return
dtypes = dtypes.values()
if set(domain.keys()) != set(dtypes.keys()):
raise ValueError
else:
raise TypeError
for dt in dtypes.values():
if dt not in [np.float64, np.complex128]:
raise ValueError
raise TypeError
def _field_to_dtype(field):
......@@ -332,7 +332,8 @@ class StudentTEnergy(EnergyOperator):
E_\\theta(f) = -\\log \\text{StudentT}_\\theta(f)
= \\frac{\\theta + 1}{2} \\log(1 + \\frac{f^2}{\\theta}),
where f is a field defined on `domain`.
where f is a field defined on `domain`. Assumes that the data is `float64`
for sampling.
Parameters
----------
......@@ -342,12 +343,9 @@ class StudentTEnergy(EnergyOperator):
Degree of freedom parameter for the student t distribution
"""
def __init__(self, domain, theta, sampling_dtype=np.float64):
def __init__(self, domain, theta):
self._domain = DomainTuple.make(domain)
self._theta = theta
self._sampling_dtype = sampling_dtype
if sampling_dtype == np.complex128:
raise NotImplementedError('Complex data not supported yet')
def apply(self, x):
self._check_input(x)
......@@ -355,9 +353,7 @@ class StudentTEnergy(EnergyOperator):
if not x.want_metric:
return res
met = makeOp((self._theta+1) / (self._theta+3), self.domain)
if self._sampling_dtype is not None:
met = SamplingDtypeSetter(met, self._sampling_dtype)
return res.add_metric(met)
return res.add_metric(SamplingDtypeSetter(met, np.float64))
class BernoulliEnergy(EnergyOperator):
......@@ -420,7 +416,8 @@ class StandardHamiltonian(EnergyOperator):
ic_samp : IterationController
Tells an internal :class:`SamplingEnabler` which convergence criterion
to use to draw Gaussian samples.
sampling_dtype : FIXME
prior_dtype : numpy.dtype or dict of numpy.dtype, optional
Data type of prior used for sampling.
See also
--------
......@@ -429,9 +426,9 @@ class StandardHamiltonian(EnergyOperator):
`<https://arxiv.org/abs/1812.04403>`_
"""
def __init__(self, lh, ic_samp=None, _c_inp=None, sampling_dtype=np.float64):
def __init__(self, lh, ic_samp=None, _c_inp=None, prior_dtype=np.float64):
self._lh = lh
self._prior = GaussianEnergy(domain=lh.domain, sampling_dtype=sampling_dtype)
self._prior = GaussianEnergy(domain=lh.domain, sampling_dtype=prior_dtype)
if _c_inp is not None:
_, self._prior = self._prior.simplify_for_constant_input(_c_inp)
self._ic_samp = ic_samp
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment