Commit f507f3aa authored by Philipp Arras

Merge branch 'new_random' into 'NIFTy_6'

Switch to new numpy random generators

See merge request ift/nifty!428
parents 604044cd 90da361e
Pipeline #71313 passed with stages in 17 minutes and 12 seconds
Changes since NIFTy 5:
MPI parallelisation over samples in MetricGaussianKL
====================================================
The classes `MetricGaussianKL` and `MetricGaussianKL_MPI` have been unified
into a single `MetricGaussianKL` class with built-in MPI support.
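For illustration, a minimal sketch of the unified interface (the constructor
arguments appear in the diff further down; the toy model setup here is an
assumption for demonstration purposes, not part of this merge):

``` python
import nifty6 as ift
from mpi4py import MPI

# Illustrative toy model (assumed helper APIs; adapt for a real inference).
dom = ift.UnstructuredDomain(8)
lh = ift.GaussianEnergy(domain=dom) @ ift.ScalingOperator(dom, 1.).ducktape('x')
ic = ift.GradientNormController(iteration_limit=50)
ham = ift.StandardHamiltonian(lh, ic_samp=ic)
mean = ift.from_random('normal', ham.domain)

# One class for both serial (comm=None) and MPI-distributed sampling:
kl = ift.MetricGaussianKL(mean, ham, n_samples=4, mirror_samples=True,
                          comm=MPI.COMM_WORLD)
```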
New approach for random number generation
=========================================
The code now uses `numpy`'s new `SeedSequence` and `Generator` classes for the
production of random numbers (introduced in numpy 1.17). This greatly simplifies
the generation of reproducible random numbers in the presence of MPI parallelism
and leads to cleaner code overall. Please see the documentation of
`nifty6.random` for details.
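For example (a short sketch based on the helpers visible in this merge; see
the `nifty6.random` documentation for the authoritative interface):

``` python
import nifty6 as ift

# Draw from the global generator instead of np.random module functions:
rng = ift.random.current_rng()
x = rng.standard_normal(5)

# Derive reproducible, independent sub-streams, e.g. one per MPI sample:
sseq = ift.random.spawn_sseq(3)    # one SeedSequence per stream
ift.random.push_sseq(sseq[0])      # make sub-stream 0 the current one
y = ift.random.current_rng().normal(0., 1., 4)
ift.random.pop_sseq()              # restore the previous stream
```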
Interface Change for non-linear Operators
=========================================
......@@ -17,7 +32,6 @@ behaviour since both `Operator._check_input()` and
`extra.check_jacobian_consistency()` test that the new conditions are
fulfilled.
Special functions for complete Field reduction operations
=========================================================
......@@ -66,12 +80,3 @@ User-visible changes:
replaced by a single function called `makeField`
- the property `local_shape` has been removed from `Domain` (and subclasses)
and `DomainTuple`.
Transfer of MPI parallelization into operators:
===============================================
As was already the case with the `MetricGaussianKL_MPI` in NIFTy5, MPI
parallelization in NIFTy6 is handled by specialized MPI-enabled operators.
They are accessible via the `nifty6.mpi` namespace, from which they can be
imported directly: `from nifty6.mpi import MPIenabledOperator`.
%% Cell type:markdown id: tags:
# A NIFTy demonstration
%% Cell type:markdown id: tags:
## IFT: Big Picture
IFT starting point:
$$d = Rs+n$$
Typically, $s$ is a continuous field and $d$ a discrete data vector. In particular, $R$ is not invertible.
IFT aims at **inverting** the above non-invertible problem in the **best possible way** using Bayesian statistics.
## NIFTy
NIFTy (Numerical Information Field Theory) is a Python framework in which IFT problems can be tackled easily.
Main Interfaces:
- **Spaces**: Cartesian, 2-Spheres (Healpix, Gauss-Legendre) and their respective harmonic spaces.
- **Fields**: Defined on spaces.
- **Operators**: Acting on fields.
%% Cell type:markdown id: tags:
## Wiener Filter: Formulae
### Assumptions
- $d=Rs+n$, $R$ linear operator.
- $\mathcal P (s) = \mathcal G (s,S)$, $\mathcal P (n) = \mathcal G (n,N)$ where $S, N$ are positive definite matrices.
### Posterior
The Posterior is given by:
$$\mathcal P (s|d) \propto P(s,d) = \mathcal G(d-Rs,N) \,\mathcal G(s,S) \propto \mathcal G (s-m,D) $$
where
$$\begin{align}
m &= Dj \\
D^{-1}&= (S^{-1} +R^\dagger N^{-1} R )\\
j &= R^\dagger N^{-1} d
\end{align}$$
Let us implement this in NIFTy!
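%% Cell type:markdown id: tags:
As a quick cross-check, these formulae can be verified with plain `numpy` on a tiny problem where all operators fit into memory as dense matrices (a minimal sketch with illustrative toy covariances, independent of the NIFTy implementation below):
%% Cell type:code id: tags:
``` python
import numpy as np

# Dense-matrix check of m = D j with D^{-1} = S^{-1} + R^T N^{-1} R.
rng = np.random.default_rng(0)
npix = 64
R = np.eye(npix)                                # trivial response
S = np.diag(1. / (1. + np.arange(npix))**2)     # toy signal covariance
Ninv = np.eye(npix) / 0.2                       # noise covariance N = 0.2*1
s = rng.multivariate_normal(np.zeros(npix), S)  # draw a signal realization
d = R @ s + rng.normal(0., np.sqrt(0.2), npix)  # synthetic data
Dinv = np.linalg.inv(S) + R.T @ Ninv @ R
j = R.T @ Ninv @ d
m = np.linalg.solve(Dinv, j)                    # posterior mean
```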
%% Cell type:markdown id: tags:
## Wiener Filter: Example
- We assume statistical homogeneity and isotropy. Therefore the signal covariance $S$ is diagonal in harmonic space, and is described by a one-dimensional power spectrum, assumed here as $$P(k) = P_0\,\left(1+\left(\frac{k}{k_0}\right)^2\right)^{-\gamma /2},$$
with $P_0 = 0.2, k_0 = 5, \gamma = 4$.
- $N = 0.2 \cdot \mathbb{1}$.
- Number of data points $N_{pix} = 512$.
- Reconstruction in harmonic space.
- Response operator:
$$R = FFT_{\text{harmonic} \rightarrow \text{position}}$$
%% Cell type:code id: tags:
``` python
N_pixels = 512  # Number of pixels

def pow_spec(k):
    P0, k0, gamma = [.2, 5, 4]
    return P0 / ((1. + (k/k0)**2)**(gamma / 2))
```
%% Cell type:markdown id: tags:
## Wiener Filter: Implementation
%% Cell type:markdown id: tags:
### Import Modules
%% Cell type:code id: tags:
``` python
import numpy as np
np.random.seed(40)
import nifty6 as ift
import matplotlib.pyplot as plt
%matplotlib inline
```
%% Cell type:markdown id: tags:
### Implement Propagator
%% Cell type:code id: tags:
``` python
def Curvature(R, N, Sh):
    IC = ift.GradientNormController(iteration_limit=50000,
                                    tol_abs_gradnorm=0.1)
    # WienerFilterCurvature is (R.adjoint*N.inverse*R + Sh.inverse) plus some
    # handy helper methods.
    return ift.WienerFilterCurvature(R, N, Sh, iteration_controller=IC,
                                     iteration_controller_sampling=IC)
```
%% Cell type:markdown id: tags:
### Conjugate Gradient Preconditioning
- $D$ is defined via:
$$D^{-1} = \mathcal S_h^{-1} + R^\dagger N^{-1} R.$$
In the end, we want to apply $D$ to $j$, i.e. we need the inverse action of $D^{-1}$. This is done numerically (algorithm: *Conjugate Gradient*).
<!--
- One can define the *condition number* of a non-singular and normal matrix $A$:
$$\kappa (A) := \frac{|\lambda_{\text{max}}|}{|\lambda_{\text{min}}|},$$
where $\lambda_{\text{max}}$ and $\lambda_{\text{min}}$ are the largest and smallest eigenvalue of $A$, respectively.
- The larger $\kappa$, the slower the Conjugate Gradient method converges.
- By default, conjugate gradient solves: $D^{-1} m = j$ for $m$, where $D^{-1}$ can be badly conditioned. If one knows a non-singular matrix $T$ for which $TD^{-1}$ is better conditioned, one can solve the equivalent problem:
$$\tilde A m = \tilde j,$$
where $\tilde A = T D^{-1}$ and $\tilde j = Tj$.
- In our case $S^{-1}$ is responsible for the bad conditioning of $D$ depending on the chosen power spectrum. Thus, we choose
$$T = \mathcal F^\dagger S_h^{-1} \mathcal F.$$
-->
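%% Cell type:markdown id: tags:
A minimal sketch of the iterative inversion with `scipy` (an assumed stand-in for demonstration; NIFTy ships its own conjugate-gradient solver): solve $D^{-1} m = j$ for a toy diagonal $D^{-1}$.
%% Cell type:code id: tags:
``` python
import numpy as np
from scipy.sparse.linalg import LinearOperator, cg

n = 64
diag = 1. + np.arange(n)                   # stand-in for an ill-conditioned D^{-1}
Dinv = LinearOperator((n, n), matvec=lambda x: diag*x)
j = np.ones(n)
m, info = cg(Dinv, j)                      # info == 0 means CG converged
print(info, np.max(np.abs(diag*m - j)))    # residual of D^{-1} m = j
```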
%% Cell type:markdown id: tags:
### Generate Mock data
- Generate a signal field $s$ and noise $n$ with the given covariances.
- Calculate $d$.
%% Cell type:code id: tags:
``` python
s_space = ift.RGSpace(N_pixels)
h_space = s_space.get_default_codomain()
HT = ift.HarmonicTransformOperator(h_space, target=s_space)
# Operators
Sh = ift.create_power_operator(h_space, power_spectrum=pow_spec)
R = HT #*ift.create_harmonic_smoothing_operator((h_space,), 0, 0.02)
# Fields and data
sh = Sh.draw_sample()
noiseless_data = R(sh)
noise_amplitude = np.sqrt(0.2)
N = ift.ScalingOperator(s_space, noise_amplitude**2)
n = ift.Field.from_random(domain=s_space, random_type='normal',
                          std=noise_amplitude, mean=0)
d = noiseless_data + n
j = R.adjoint_times(N.inverse_times(d))
curv = Curvature(R=R, N=N, Sh=Sh)
D = curv.inverse
```
%% Cell type:markdown id: tags:
### Run Wiener Filter
%% Cell type:code id: tags:
``` python
m = D(j)
```
%% Cell type:markdown id: tags:
### Signal Reconstruction
%% Cell type:code id: tags:
``` python
# Get signal data and reconstruction data
s_data = HT(sh).val
m_data = HT(m).val
d_data = d.val
plt.figure(figsize=(15,10))
plt.plot(s_data, 'r', label="Signal", linewidth=3)
plt.plot(d_data, 'k.', label="Data")
plt.plot(m_data, 'k', label="Reconstruction", linewidth=3)
plt.title("Reconstruction")
plt.legend()
plt.show()
```
%% Cell type:code id: tags:
``` python
plt.figure(figsize=(15,10))
plt.plot(s_data - s_data, 'r', label="Signal", linewidth=3)
plt.plot(d_data - s_data, 'k.', label="Data")
plt.plot(m_data - s_data, 'k', label="Reconstruction", linewidth=3)
plt.axhspan(-noise_amplitude, noise_amplitude, facecolor='0.9', alpha=.5)
plt.title("Residuals")
plt.legend()
plt.show()
```
%% Cell type:markdown id: tags:
### Power Spectrum
%% Cell type:code id: tags:
``` python
s_power_data = ift.power_analyze(sh).val
m_power_data = ift.power_analyze(m).val
plt.figure(figsize=(15,10))
plt.loglog()
plt.xlim(1, int(N_pixels/2))
ymin = min(m_power_data)
plt.ylim(ymin, 1)
xs = np.arange(1,int(N_pixels/2),.1)
plt.plot(xs, pow_spec(xs), label="True Power Spectrum", color='k',alpha=0.5)
plt.plot(s_power_data, 'r', label="Signal")
plt.plot(m_power_data, 'k', label="Reconstruction")
plt.axhline(noise_amplitude**2 / N_pixels, color="k", linestyle='--', label="Noise level", alpha=.5)
plt.axhspan(noise_amplitude**2 / N_pixels, ymin, facecolor='0.9', alpha=.5)
plt.title("Power Spectrum")
plt.legend()
plt.show()
```
%% Cell type:markdown id: tags:
## Wiener Filter on Incomplete Data
%% Cell type:code id: tags:
``` python
# Operators
Sh = ift.create_power_operator(h_space, power_spectrum=pow_spec)
N = ift.ScalingOperator(s_space, noise_amplitude**2)
# R is defined below
# Fields
sh = Sh.draw_sample()
s = HT(sh)
n = ift.Field.from_random(domain=s_space, random_type='normal',
                          std=noise_amplitude, mean=0)
```
%% Cell type:markdown id: tags:
### Partially Lose Data
%% Cell type:code id: tags:
``` python
l = int(N_pixels * 0.2)
h = int(N_pixels * 0.2 * 2)
mask = np.full(s_space.shape, 1.)
mask[l:h] = 0
mask = ift.Field.from_raw(s_space, mask)
R = ift.DiagonalOperator(mask)(HT)
n = n.val_rw()
n[l:h] = 0
n = ift.Field.from_raw(s_space, n)
d = R(sh) + n
```
%% Cell type:code id: tags:
``` python
curv = Curvature(R=R, N=N, Sh=Sh)
D = curv.inverse
j = R.adjoint_times(N.inverse_times(d))
m = D(j)
```
%% Cell type:markdown id: tags:
### Compute Uncertainty
%% Cell type:code id: tags:
``` python
m_mean, m_var = ift.probe_with_posterior_samples(curv, HT, 200)
```
%% Cell type:markdown id: tags:
### Get data
%% Cell type:code id: tags:
``` python
# Get signal data and reconstruction data
s_data = s.val
m_data = HT(m).val
m_var_data = m_var.val
uncertainty = np.sqrt(m_var_data)
d_data = d.val_rw()
# Set lost data to NaN for proper plotting
d_data[d_data == 0] = np.nan
```
%% Cell type:code id: tags:
``` python
fig = plt.figure(figsize=(15,10))
plt.axvspan(l, h, facecolor='0.8',alpha=0.5)
plt.fill_between(range(N_pixels), m_data - uncertainty, m_data + uncertainty, facecolor='0.5', alpha=0.5)
plt.plot(s_data, 'r', label="Signal", alpha=1, linewidth=3)
plt.plot(d_data, 'k.', label="Data")
plt.plot(m_data, 'k', label="Reconstruction", linewidth=3)
plt.title("Reconstruction of incomplete data")
plt.legend()
```
%% Cell type:markdown id: tags:
# 2d Example
%% Cell type:code id: tags:
``` python
N_pixels = 256 # Number of pixels
sigma2 = 2. # Noise variance
def pow_spec(k):
    P0, k0, gamma = [.2, 2, 4]
    return P0 * (1. + (k/k0)**2)**(-gamma/2)
s_space = ift.RGSpace([N_pixels, N_pixels])
```
%% Cell type:code id: tags:
``` python
h_space = s_space.get_default_codomain()
HT = ift.HarmonicTransformOperator(h_space,s_space)
# Operators
Sh = ift.create_power_operator(h_space, power_spectrum=pow_spec)
N = ift.ScalingOperator(s_space, sigma2)
# Fields and data
sh = Sh.draw_sample()
n = ift.Field.from_random(domain=s_space, random_type='normal',
                          std=np.sqrt(sigma2), mean=0)
# Lose some data
l = int(N_pixels * 0.33)
h = int(N_pixels * 0.33 * 2)
mask = np.full(s_space.shape, 1.)
mask[l:h,l:h] = 0.
mask = ift.Field.from_raw(s_space, mask)
R = ift.DiagonalOperator(mask)(HT)
n = n.val_rw()
n[l:h, l:h] = 0
n = ift.Field.from_raw(s_space, n)
curv = Curvature(R=R, N=N, Sh=Sh)
D = curv.inverse
d = R(sh) + n
j = R.adjoint_times(N.inverse_times(d))
# Run Wiener filter
m = D(j)
# Uncertainty
m_mean, m_var = ift.probe_with_posterior_samples(curv, HT, 20)
# Get data
s_data = HT(sh).val
m_data = HT(m).val
m_var_data = m_var.val
d_data = d.val
uncertainty = np.sqrt(np.abs(m_var_data))
```
%% Cell type:code id: tags:
``` python
cm = ['magma', 'inferno', 'plasma', 'viridis'][1]
mi = np.min(s_data)
ma = np.max(s_data)
fig, axes = plt.subplots(1, 2, figsize=(15, 7))
data = [s_data, d_data]
caption = ["Signal", "Data"]
for ax in axes.flat:
    im = ax.imshow(data.pop(0), interpolation='nearest', cmap=cm, vmin=mi,
                   vmax=ma)
    ax.set_title(caption.pop(0))
fig.subplots_adjust(right=0.8)
cbar_ax = fig.add_axes([0.85, 0.15, 0.05, 0.7])
fig.colorbar(im, cax=cbar_ax)
```
%% Cell type:code id: tags:
``` python
mi = np.min(s_data)
ma = np.max(s_data)
fig, axes = plt.subplots(3, 2, figsize=(15, 22.5))
sample = HT(curv.draw_sample(from_inverse=True)+m).val
post_mean = (m_mean + HT(m)).val
data = [s_data, m_data, post_mean, sample, s_data - m_data, uncertainty]
caption = ["Signal", "Reconstruction", "Posterior mean", "Sample", "Residuals", "Uncertainty Map"]
for ax in axes.flat:
    im = ax.imshow(data.pop(0), interpolation='nearest', cmap=cm, vmin=mi, vmax=ma)
    ax.set_title(caption.pop(0))
fig.subplots_adjust(right=0.8)
cbar_ax = fig.add_axes([.85, 0.15, 0.05, 0.7])
fig.colorbar(im, cax=cbar_ax)
```
%% Cell type:markdown id: tags:
### Is the uncertainty map reliable?
%% Cell type:code id: tags:
``` python
precise = (np.abs(s_data-m_data) < uncertainty)
print("Error within uncertainty map bounds: " + str(np.sum(precise) * 100 / N_pixels**2) + "%")
plt.figure(figsize=(15,10))
plt.imshow(precise.astype(float), cmap="brg")
plt.colorbar()
```
%% Cell type:markdown id: tags:
# Start Coding
## NIFTy Repository + Installation guide
https://gitlab.mpcdf.mpg.de/ift/NIFTy
NIFTy v5 is **more or less stable!**
......
......@@ -5,8 +5,6 @@ import numpy as np
 import nifty6 as ift

-np.random.seed(40)
-
 N0s, a0s, b0s, c0s = [], [], [], []
 for ii in range(10, 26):
......@@ -15,15 +13,15 @@ for ii in range(10, 26):
     N = int(2**ii)
     print('N = {}'.format(N))
-    uv = np.random.rand(N, 2) - 0.5
-    vis = np.random.randn(N) + 1j*np.random.randn(N)
+    rng = ift.random.current_rng()
+    uv = rng.uniform(-.5, .5, (N,2))
+    vis = rng.normal(0., 1., N) + 1j*rng.normal(0., 1., N)
     uvspace = ift.RGSpace((nu, nv))
     visspace = ift.UnstructuredDomain(N)
-    img = np.random.randn(nu*nv)
-    img = img.reshape((nu, nv))
+    img = rng.standard_normal((nu, nv))
     img = ift.makeField(uvspace, img)
     t0 = time()
......
......@@ -27,8 +27,6 @@ import numpy as np
 import nifty6 as ift

 if __name__ == '__main__':
-    np.random.seed(41)
-
     # Set up the position space of the signal
     mode = 2
     if mode == 0:
......@@ -62,7 +60,7 @@ if __name__ == '__main__':
     p = R(sky)
     mock_position = ift.from_random('normal', harmonic_space)
     tmp = p(mock_position).val.astype(np.float64)
-    data = np.random.binomial(1, tmp)
+    data = ift.random.current_rng().binomial(1, tmp)
     data = ift.Field.from_raw(R.target, data)

     # Compute likelihood and Hamiltonian
......
......@@ -46,8 +46,6 @@ def make_random_mask():
 if __name__ == '__main__':
-    np.random.seed(42)
-
     # Choose space on which the signal field is defined
     if len(sys.argv) == 2:
         mode = int(sys.argv[1])
......
......@@ -44,8 +44,6 @@ def exposure_2d():
 if __name__ == '__main__':
-    np.random.seed(42)
-
     # Choose space on which the signal field is defined
     if len(sys.argv) == 2:
         mode = int(sys.argv[1])
......@@ -94,7 +92,7 @@ if __name__ == '__main__':
     lamb = R(sky)
     mock_position = ift.from_random('normal', domain)
     data = lamb(mock_position)
-    data = np.random.poisson(data.val.astype(np.float64))
+    data = ift.random.current_rng().poisson(data.val.astype(np.float64))
     data = ift.Field.from_raw(d_space, data)

     likelihood = ift.PoissonianEnergy(data) @ lamb
......
......@@ -33,20 +33,18 @@ import nifty6 as ift
 def random_los(n_los):
-    starts = list(np.random.uniform(0, 1, (n_los, 2)).T)
-    ends = list(np.random.uniform(0, 1, (n_los, 2)).T)
+    starts = list(ift.random.current_rng().random((n_los, 2)).T)
+    ends = list(ift.random.current_rng().random((n_los, 2)).T)
     return starts, ends

 def radial_los(n_los):
-    starts = list(np.random.uniform(0, 1, (n_los, 2)).T)
-    ends = list(0.5 + 0*np.random.uniform(0, 1, (n_los, 2)).T)
+    starts = list(ift.random.current_rng().random((n_los, 2)).T)
+    ends = list(0.5 + 0*ift.random.current_rng().random((n_los, 2)).T)
     return starts, ends

 if __name__ == '__main__':
-    np.random.seed(420)
-
     # Choose between random line-of-sight response (mode=0) and radial lines
     # of sight (mode=1)
     if len(sys.argv) == 2:
......
......@@ -44,20 +44,18 @@ class SingleDomain(ift.LinearOperator):
 def random_los(n_los):
-    starts = list(np.random.uniform(0, 1, (n_los, 2)).T)
-    ends = list(np.random.uniform(0, 1, (n_los, 2)).T)
+    starts = list(ift.random.current_rng().random((n_los, 2)).T)
+    ends = list(ift.random.current_rng().random((n_los, 2)).T)
     return starts, ends

 def radial_los(n_los):
-    starts = list(np.random.uniform(0, 1, (n_los, 2)).T)
-    ends = list(0.5 + 0*np.random.uniform(0, 1, (n_los, 2)).T)
+    starts = list(ift.random.current_rng().random((n_los, 2)).T)
+    ends = list(0.5 + 0*ift.random.current_rng().random((n_los, 2)).T)
     return starts, ends

 if __name__ == '__main__':
-    np.random.seed(43)
-
     # Choose between random line-of-sight response (mode=0) and radial lines
     # of sight (mode=1)
     if len(sys.argv) == 2:
......
......@@ -19,7 +19,6 @@ import matplotlib.pyplot as plt
 import numpy as np
 import nifty6 as ift

-np.random.seed(12)

 def polynomial(coefficients, sampling_points):
......@@ -86,7 +85,7 @@ if __name__ == '__main__':
     N_params = 10
     N_samples = 100
     size = (12,)
-    x = np.random.random(size) * 10
+    x = ift.random.current_rng().random(size) * 10
     y = np.sin(x**2) * x**3
     var = np.full_like(y, y.var() / 10)
     var[-2] *= 4
......
 from .version import __version__

+from . import random
+random.seed(42)
+
 from .domains.domain import Domain
 from .domains.structured_domain import StructuredDomain
......
......@@ -278,17 +278,21 @@ def check_jacobian_consistency(op, loc, tol=1e-8, ntries=100, perf_check=True):
         dir = loc2-loc
         locnext = loc2
         dirnorm = dir.norm()
+        hist = []
         for i in range(50):
             locmid = loc + 0.5*dir
             linmid = op(Linearization.make_var(locmid))
             dirder = linmid.jac(dir)
             numgrad = (lin2.val-lin.val)
             xtol = tol * dirder.norm() / np.sqrt(dirder.size)
+            hist.append((numgrad-dirder).norm())
+            # print(len(hist),hist[-1])
             if (abs(numgrad-dirder) <= xtol).s_all():
                 break
             dir = dir*0.5
             dirnorm *= 0.5
             loc2, lin2 = locmid, linmid
         else:
+            print(hist)
             raise ValueError("gradient and value seem inconsistent")
         loc = locnext
......@@ -25,6 +25,7 @@ from ..field import Field
 from ..linearization import Linearization
 from ..operators.operator import Operator
 from ..sugar import makeOp
+from .. import random

 def _f_on_np(f, arr):
......@@ -67,7 +68,8 @@ class _InterpolationOperator(Operator):
         if table_func is not None:
             if inv_table_func is None:
                 raise ValueError
-            a = func(np.random.randn(10))
+            # MR FIXME: not sure whether we should have this in production code
+            a = func(random.current_rng().random(10))
             a1 = _f_on_np(lambda x: inv_table_func(table_func(x)), a)
             np.testing.assert_allclose(a, a1)
             self._table = _f_on_np(table_func, self._table)
......
......@@ -11,19 +11,66 @@
 # You should have received a copy of the GNU General Public License
 # along with this program. If not, see <http://www.gnu.org/licenses/>.
 #
-# Copyright(C) 2013-2019 Max-Planck-Society
+# Copyright(C) 2013-2020 Max-Planck-Society
 #
 # NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.

 import numpy as np

-from .. import utilities
+from .. import random, utilities
 from ..field import Field
 from ..linearization import Linearization
+from ..multi_field import MultiField
+from ..operators.endomorphic_operator import EndomorphicOperator
 from ..operators.energy_operators import StandardHamiltonian
 from ..probing import approximation2endo
-from ..sugar import makeOp
+from ..sugar import full, makeOp
 from .energy import Energy


+def _shareRange(nwork, nshares, myshare):
+    nbase = nwork//nshares
+    additional = nwork % nshares
+    lo = myshare*nbase + min(myshare, additional)
+    hi = lo + nbase + int(myshare < additional)
+    return lo, hi
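A quick illustration of how `_shareRange` distributes `nwork` items across
`nshares` tasks (a standalone check, not part of the diff):

``` python
def _shareRange(nwork, nshares, myshare):
    nbase = nwork//nshares
    additional = nwork % nshares
    lo = myshare*nbase + min(myshare, additional)
    hi = lo + nbase + int(myshare < additional)
    return lo, hi

# 10 samples on 4 tasks -> (0, 3) (3, 6) (6, 8) (8, 10):
# the first 10 % 4 = 2 tasks receive one extra sample.
for rank in range(4):
    print(rank, _shareRange(10, 4, rank))
```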
+def _np_allreduce_sum(comm, arr):
+    if comm is None:
+        return arr
+    from mpi4py import MPI
+    arr = np.array(arr)
+    res = np.empty_like(arr)
+    comm.Allreduce(arr, res, MPI.SUM)
+    return res
+
+
+def _allreduce_sum_field(comm, fld):
+    if comm is None:
+        return fld
+    if isinstance(fld, Field):
+        return Field(fld.domain, _np_allreduce_sum(comm, fld.val))
+    res = tuple(
+        Field(f.domain, _np_allreduce_sum(comm, f.val))
+        for f in fld.values())
+    return MultiField(fld.domain, res)
+
+class _KLMetric(EndomorphicOperator):
+    def __init__(self, KL):
+        self._KL = KL
+        self._capability = self.TIMES | self.ADJOINT_TIMES
+        self._domain = KL.position.domain
+
+    def apply(self, x, mode):
+        self._check_input(x, mode)
+        return self._KL.apply_metric(x)
+
+    def draw_sample(self, from_inverse=False, dtype=np.float64):
+        return self._KL._metric_sample(from_inverse, dtype)
 class MetricGaussianKL(Energy):
     """Provides the sampled Kullback-Leibler divergence between a distribution
     and a Metric Gaussian.
......@@ -38,6 +85,7 @@ class MetricGaussianKL(Energy):
     typically nonlinear structure of the true distribution these samples have
     to be updated eventually by instantiating `MetricGaussianKL` again. For the
     true probability distribution the standard parametrization is assumed.
+    The samples of this class can be distributed among MPI tasks.

     Parameters
     ----------
......@@ -62,6 +110,17 @@
     napprox : int
         Number of samples for computing preconditioner for sampling. No
         preconditioning is done by default.
+    comm : MPI communicator or None
+        If not None, samples will be distributed as evenly as possible
+        across this communicator. If `mirror_samples` is set, then a sample
+        and its mirror image will always reside on the same task.
+    lh_sampling_dtype : type
+        Determines which dtype in data space shall be used for drawing samples
+        from the metric. If the inference is based on complex data,
+        lh_sampling_dtype shall be set to complex accordingly. The reason for
+        the presence of this parameter is that the metric of the likelihood
+        energy is just an `Operator`, which does not know anything about the
+        dtype of the fields on which it acts. Default is float64.
     _samples : None
         Only a parameter for internal use. Typically not to be set by users.
......@@ -79,7 +138,8 @@
     def __init__(self, mean, hamiltonian, n_samples, constants=[],
                  point_estimates=[], mirror_samples=False,
-                 napprox=0, _samples=None, lh_sampling_dtype=np.float64):
+                 napprox=0, comm=None, _samples=None,
+                 lh_sampling_dtype=np.float64):
         super(MetricGaussianKL, self).__init__(mean)
         if not isinstance(hamiltonian, StandardHamiltonian):
......@@ -88,47 +148,72 @@
             raise ValueError
         if not isinstance(n_samples, int):
             raise TypeError
-        self._constants = list(constants)
-        self._point_estimates = list(point_estimates)
+        self._constants = tuple(constants)
+        self._point_estimates = tuple(point_estimates)
         if not isinstance(mirror_samples, bool):
             raise TypeError
         self._hamiltonian = hamiltonian

+        self._n_samples = int(n_samples)
+        if comm is not None:
+            self._comm = comm
+            ntask = self._comm.Get_size()
+            rank = self._comm.Get_rank()
+            self._lo, self._hi = _shareRange(self._n_samples, ntask, rank)
+        else:
+            self._comm = None
+            self._lo, self._hi = 0, self._n_samples
+
+        self._mirror_samples = bool(mirror_samples)
+        self._n_eff_samples = self._n_samples
+        if self._mirror_samples:
+            self._n_eff_samples *= 2
+
         if _samples is None:
             met = hamiltonian(Linearization.make_partial_var(
-                mean, point_estimates, True)).metric
-            if napprox > 1:
+                mean, self._point_estimates, True)).metric
+            if napprox >= 1:
                 met._approximation = makeOp(approximation2endo(met, napprox))
-            _samples = tuple(met.draw_sample(from_inverse=True,
-                                             dtype=lh_sampling_dtype)
-                             for _ in range(n_samples))
-            if mirror_samples:
-                _samples += tuple(-s for s in _samples)
+            _samples = []
+            sseq = random.spawn_sseq(self._n_samples)
+            for i in range(self._lo, self._hi):
+                random.push_sseq(sseq[i])
+                _samples.append(met.draw_sample(from_inverse=True,
+                                                dtype=lh_sampling_dtype))
+                random.pop_sseq()
+            _samples = tuple(_samples)
+        else:
+            if len(_samples) != self._hi-self._lo:
+                raise ValueError("# of samples mismatch")
         self._samples = _samples
         # FIXME Use simplify for constant input instead
-        self._lin = Linearization.make_partial_var(mean, constants)
+        self._lin = Linearization.make_partial_var(mean, self._constants)
         v, g = None, None
-        for s in self._samples:
-            tmp = self._hamiltonian(self._lin+s)
-            if v is None:
-                v = tmp.val.val[()]
-                g = tmp.gradient
-            else:
-                v += tmp.val.val[()]
-                g = g + tmp.gradient
-        self._val = v / len(self._samples)
-        self._grad = g * (1./len(self._samples))
+        if len(self._samples) == 0:  # hack if there are too many MPI tasks
+            tmp = self._hamiltonian(self._lin)
+            v = 0. * tmp.val.val
+            g = 0. * tmp.gradient
+        else:
+            for s in self._samples:
+                tmp = self._hamiltonian(self._lin+s)
+                if self._mirror_samples:
+                    tmp = tmp + self._hamiltonian(self._lin-s)
+                if v is None:
+                    v = tmp.val.val_rw()
+                    g = tmp.gradient
+                else:
+                    v += tmp.val.val
+                    g = g + tmp.gradient
+        self._val = _np_allreduce_sum(self._comm, v)[()] / self._n_eff_samples
+        self._grad = _allreduce_sum_field(self._comm, g) / self._n_eff_samples
         self._metric = None
         self._napprox = napprox
         self._sampdt = lh_sampling_dtype
     def at(self, position):
-        return MetricGaussianKL(position, self._hamiltonian, 0,
-                                self._constants, self._point_estimates,
-                                napprox=self._napprox, _samples=self._samples,
-                                lh_sampling_dtype=self._sampdt)
+        return MetricGaussianKL(
+            position, self._hamiltonian, self._n_samples, self._constants,
+            self._point_estimates, self._mirror_samples, comm=self._comm,
+            _samples=self._samples, lh_sampling_dtype=self._sampdt)
     @property
     def value(self):
......@@ -139,30 +224,48 @@
         return self._grad

     def _get_metric(self):
-        lin = self._lin.with_want_metric()
         if self._metric is None:
+            lin = self._lin.with_want_metric()
-            mymap = map(lambda v: self._hamiltonian(lin+v).metric,
-                        self._samples)
-            self._unscaled_metric = utilities.my_sum(mymap)
-            self._metric = self._unscaled_metric.scale(1./len(self._samples))
-
-    def unscaled_metric(self):
-        self._get_metric()
-        return self._unscaled_metric, 1/len(self._samples)
+            if len(self._samples) == 0:  # hack if there are too many MPI tasks
+                self._metric = self._hamiltonian(lin).metric.scale(0.)
+            else:
+                mymap = map(lambda v: self._hamiltonian(lin+v).metric,
+                            self._samples)
+                self.unscaled_metric = utilities.my_sum(mymap)
+                self._metric = self.unscaled_metric.scale(1./self._n_eff_samples)

     def apply_metric(self, x):
         self._get_metric()
         return self._metric(x)