Commit f9ca9bad authored by Martin Reinecke's avatar Martin Reinecke
Browse files

Merge branch 'gauss_sampling_dtype' into 'NIFTy_6'

Gauss sampling dtype

See merge request !401
parents 6fa98577 bb1069ae
Pipeline #67491 passed with stages
in 14 minutes and 56 seconds
......@@ -14,6 +14,7 @@
# Copyright(C) 2013-2019 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.
import numpy as np
from .. import utilities
from ..linearization import Linearization
......@@ -78,7 +79,7 @@ class MetricGaussianKL(Energy):
def __init__(self, mean, hamiltonian, n_samples, constants=[],
point_estimates=[], mirror_samples=False,
napprox=0, _samples=None):
napprox=0, _samples=None, lh_sampling_dtype=np.float64):
super(MetricGaussianKL, self).__init__(mean)
if not isinstance(hamiltonian, StandardHamiltonian):
......@@ -99,7 +100,8 @@ class MetricGaussianKL(Energy):
mean, point_estimates, True)).metric
if napprox > 1:
met._approximation = makeOp(approximation2endo(met, napprox))
_samples = tuple(met.draw_sample(from_inverse=True)
_samples = tuple(met.draw_sample(from_inverse=True,
dtype=lh_sampling_dtype)
for _ in range(n_samples))
if mirror_samples:
_samples += tuple(-s for s in _samples)
......@@ -120,11 +122,13 @@ class MetricGaussianKL(Energy):
self._grad = g * (1./len(self._samples))
self._metric = None
self._napprox = napprox
self._sampdt = lh_sampling_dtype
def at(self, position):
return MetricGaussianKL(position, self._hamiltonian, 0,
self._constants, self._point_estimates,
napprox=self._napprox, _samples=self._samples)
napprox=self._napprox, _samples=self._samples,
lh_sampling_dtype=self._sampdt)
@property
def value(self):
......
......@@ -132,7 +132,8 @@ class MetricGaussianKL_MPI(Energy):
def __init__(self, mean, hamiltonian, n_samples, constants=[],
point_estimates=[], mirror_samples=False,
napprox=0, _samples=None, seed_offset=0):
napprox=0, _samples=None, seed_offset=0,
lh_sampling_dtype=np.float64):
super(MetricGaussianKL_MPI, self).__init__(mean)
if not isinstance(hamiltonian, StandardHamiltonian):
......@@ -167,10 +168,12 @@ class MetricGaussianKL_MPI(Energy):
else:
_samples.append(((i % 2)*2-1) *
met.draw_sample(from_inverse=True))
met.draw_sample(from_inverse=True,
dtype=lh_sampling_dtype))
else:
np.random.seed(i+seed_offset)
_samples.append(met.draw_sample(from_inverse=True))
_samples.append(met.draw_sample(from_inverse=True,
dtype=lh_sampling_dtype))
np.random.set_state(rand_state)
_samples = tuple(_samples)
if mirror_samples:
......@@ -196,12 +199,13 @@ class MetricGaussianKL_MPI(Energy):
self._val = np_allreduce_sum(v)[()] / self._n_samples
self._grad = allreduce_sum_field(g) / self._n_samples
self._metric = None
self._sampdt = lh_sampling_dtype
def at(self, position):
return MetricGaussianKL_MPI(
position, self._hamiltonian, self._n_samples, self._constants,
self._point_estimates, _samples=self._samples,
seed_offset=self._seed_offset)
seed_offset=self._seed_offset, lh_sampling_dtype=self._sampdt)
@property
def value(self):
......
......@@ -71,7 +71,7 @@ class SamplingEnabler(EndomorphicOperator):
else:
s = self._prior.draw_sample(from_inverse=True)
sp = self._prior(s)
nj = self._likelihood.draw_sample()
nj = self._likelihood.draw_sample(dtype=dtype)
energy = QuadraticEnergy(s, self._op, sp + nj,
_grad=self._likelihood(s) - nj)
inverter = ConjugateGradient(self._ic)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment