Commit d705ed0e authored by Philipp Arras's avatar Philipp Arras
Browse files

Merge branch 'sample_list_improvements' into 'NIFTy_8'

SampleList improvements

See merge request !696
parents 8c7841c3 5e48aa96
Pipeline #113649 passed with stages
in 30 minutes and 37 seconds
......@@ -6,6 +6,8 @@ docs/source/user/custom_nonlinearities.rst
docs/source/user/getting_started_4_CorrelatedFields.rst
# custom
*.h5
*.hdf5
*.txt
*.pickle
setup.cfg
......
......@@ -89,7 +89,6 @@ run_getting_started_1:
stage: demo_runs
script:
- python3 demos/getting_started_1.py
- mpiexec -n 2 --bind-to none python3 demos/getting_started_1.py 2> /dev/null
artifacts:
paths:
- '*.png'
......@@ -98,7 +97,6 @@ run_getting_started_2:
stage: demo_runs
script:
- python3 demos/getting_started_2.py
- mpiexec -n 2 --bind-to none python3 demos/getting_started_2.py 2> /dev/null
artifacts:
paths:
- '*.png'
......@@ -114,7 +112,7 @@ run_getting_started_3:
run_getting_started_mf:
stage: demo_runs
script:
- python3 demos/getting_started_5_mf.py
- mpiexec -n 2 --bind-to none python3 demos/getting_started_5_mf.py
artifacts:
paths:
- '*.png'
......
......@@ -10,7 +10,7 @@ RUN apt-get update && apt-get install -y \
# Testing dependencies
python3-pytest-cov jupyter \
# Optional NIFTy dependencies
python3-mpi4py python3-matplotlib \
python3-mpi4py python3-matplotlib python3-h5py \
# more optional NIFTy dependencies
&& DUCC0_OPTIMIZATION=portable pip3 install ducc0 finufft jupyter jax jaxlib sphinx pydata-sphinx-theme \
&& rm -rf /var/lib/apt/lists/*
......
......@@ -52,6 +52,7 @@ Optional dependencies:
- [ducc0](https://gitlab.mpcdf.mpg.de/mtr/ducc) for faster FFTs, spherical
harmonic transforms, and radio interferometry gridding support
- [mpi4py](https://mpi4py.scipy.org) (for MPI-parallel execution)
- [h5py](https://www.h5py.org/) (for writing results to HDF5 files)
- [matplotlib](https://matplotlib.org/) (for field plotting)
- [jax](https://github.com/google/jax) (for implementing operators with jax)
......
......@@ -31,6 +31,14 @@ import numpy as np
import nifty8 as ift
try:
from mpi4py import MPI
comm = MPI.COMM_WORLD
master = comm.Get_rank() == 0
except ImportError:
comm = None
master = True
class SingleDomain(ift.LinearOperator):
def __init__(self, domain, target):
......@@ -134,7 +142,7 @@ def main():
for i in range(5):
# Draw new samples and minimize KL
KL = ift.SampledKLEnergy(position, H, N_samples, None)
KL = ift.SampledKLEnergy(position, H, N_samples, None, comm=comm)
KL, convergence = minimizer(KL)
position = KL.position
......@@ -155,8 +163,9 @@ def main():
label=['KL', 'Sampling', 'Newton inversion'],
title='Cumulative energies', s=[None, None, 1],
alpha=[None, 0.2, None])
plot.output(nx=3, ny=2, ysize=10, xsize=15,
name=filename.format("loop_{:02d}".format(i)))
if master:
plot.output(nx=3, ny=2, ysize=10, xsize=15,
name=filename.format("loop_{:02d}".format(i)))
# Plotting
filename_res = filename.format("results")
......@@ -176,7 +185,8 @@ def main():
pspec2.force(mock_position)],
title="Sampled Posterior Power Spectrum 2",
linewidth=[1.]*n_samples + [3., 3.])
plot.output(ny=2, nx=2, xsize=15, ysize=15, name=filename_res)
if master:
plot.output(ny=2, nx=2, xsize=15, ysize=15, name=filename_res)
print("Saved results as '{}'.".format(filename_res))
......
......@@ -205,7 +205,7 @@ class StochasticEnergyAdapter(Energy):
noise = []
sseq = random.spawn_sseq(n_samples)
from .sample_list import SampleListBase
for i in SampleListBase.indices_from_comm(n_samples, comm):
for i in SampleListBase.local_indices(n_samples, comm):
with random.Context(sseq[i]):
rnd = from_random(samdom)
noise.append(rnd)
......
......@@ -135,7 +135,7 @@ def draw_samples(position, H, minimizer, n_samples, mirror_samples, napprox=0,
utilities.check_MPI_synced_random_state(comm)
utilities.check_MPI_equality(sseq, comm)
y = None
for i in SampleListBase.indices_from_comm(len(sseq), comm):
for i in SampleListBase.local_indices(len(sseq), comm):
with random.Context(sseq[i]):
neg = mirror_samples and (i % 2 != 0)
if not neg or y is None: # we really need to draw a sample
......
......@@ -14,9 +14,13 @@
# Copyright(C) 2021 Max-Planck-Society
# Author: Philipp Arras, Philipp Frank
import os
import pickle
import re
from .. import utilities
from ..field import Field
from ..logger import logger
from ..multi_domain import MultiDomain
from ..multi_field import MultiField
......@@ -48,12 +52,19 @@ class SampleListBase:
self._comm = comm
self._domain = makeDomain(domain)
utilities.check_MPI_equality(self._domain, comm)
global_comm, size, _, _ = utilities.get_MPI_params()
if global_comm is not None and size > 1 and comm is None:
raise ValueError("MPI is present. Please pass an MPI communicator to `SampleList`.")
@property
def n_local_samples(self):
"""int: Number of local samples."""
raise NotImplementedError
def n_samples(self):
"""Return number of samples across all MPI tasks."""
return utilities.allreduce_sum([self.n_local_samples], self.comm)
def local_item(self, i):
"""Return ith local sample."""
raise NotImplementedError
......@@ -68,13 +79,17 @@ class SampleListBase:
"""MPI communicator or None: The communicator used for the SampleListBase."""
return self._comm
@property
def mpi_master(self):
return self.comm is None or self.comm.Get_rank() == 0
@property
def domain(self):
"""DomainTuple or MultiDomain: the domain on which the samples are defined."""
return self._domain
@staticmethod
def indices_from_comm(n_samples, comm):
def local_indices(n_samples, comm=None):
"""Return range of global sample indices for local task.
This method calls `utilities.shareRange`
......@@ -94,6 +109,71 @@ class SampleListBase:
ntask, rank, _ = utilities.get_MPI_params_from_comm(comm)
return range(*utilities.shareRange(n_samples, ntask, rank))
def save_to_hdf5(self, file_name, op=None, samples=False, mean=False, std=False,
overwrite=False):
"""Write sample list to HDF5 file.
This function writes sample lists to HDF5 files that contain two
groups: `samples` and `stats`. `samples` contain the sublabels `0`,
`1`, ... that number the labels and `stats` contains the sublabels
`mean` and `standard deviation`. If `self.domain` is an instance of
:class:`~nifty8.multi_domain.MultiDomain`, these sublabels refer
themselves to subgroups. For :class:`~nifty8.field.Field`, the sublabel
refers to an HDF5 data set.
If quanitities are not requested (e.g. by setting `mean=False`), the
respective sublabels are not present in the HDF5 file.
Parameters
----------
file_name : str
File name of output hdf5 file.
op : callable or None
Callable that is applied to each item in the :class:`SampleListBase`
before it is returned. Can be an
:class:`~nifty8.operators.operator.Operator` or any other callable
that takes a :class:`~nifty8.field.Field` as an input. Default:
None.
samples : bool
If True, samples are written into hdf5 file.
mean : bool
If True, mean of samples is written into hdf5 file.
std : bool
If True, standard deviation of samples is written into hdf5 file.
overwrite : bool
If True, a potentially existing file with the same file name as
`file_name`, is overwritten.
"""
import h5py
if os.path.isfile(file_name):
if self.mpi_master and overwrite:
os.remove(file_name)
if not overwrite:
raise RuntimeError(f"File {file_name} already exists. Delete it or use "
"`overwrite=True`")
if not (samples or mean or std):
raise ValueError("Neither samples nor mean nor standard deviation shall be written.")
if self.mpi_master:
f = h5py.File(file_name, "w")
else:
f = utilities.Nop()
if samples:
grp = f.create_group("samples")
for ii, ss in enumerate(self.iterator(op)):
_field2hdf5(grp, ss, str(ii))
if mean or std:
grp = f.create_group("stats")
m, v = self.sample_stat(op)
if mean:
_field2hdf5(grp, m, "mean")
if std:
_field2hdf5(grp, v.sqrt(), "standard deviation")
f.close()
def iterator(self, op=None):
"""Return iterator over all potentially distributed samples.
......@@ -152,10 +232,6 @@ class SampleListBase:
res = [[elem[ii] for elem in res] for ii in range(n_output_elements)]
return tuple(utilities.allreduce_sum(rr, self.comm) / n for rr in res)
def n_samples(self):
"""Return number of samples across all MPI tasks."""
return utilities.allreduce_sum([self.n_local_samples], self.comm)
def sample_stat(self, op=None):
"""Compute mean and variance of samples after applying `op`.
......@@ -193,22 +269,8 @@ class SampleListBase:
"""
raise NotImplementedError
def save_helper(self, file_name_base, obj):
# Helper functions necessary because MPI communicator cannot be pickled
fname = str(file_name_base) + _mpi_file_extension(self.comm) + ".pickle"
with open(fname, "wb") as f:
pickle.dump(obj, f, pickle.HIGHEST_PROTOCOL)
@staticmethod
def load_helper(file_name_base, comm):
# Helper functions necessary because MPI communicator cannot be pickled
fname = str(file_name_base) + _mpi_file_extension(comm) + ".pickle"
with open(fname, "rb") as f:
obj = pickle.load(f)
return obj
@staticmethod
def load(file_name_base, comm=None):
@classmethod
def load(cls, file_name_base, comm=None):
"""Deserialize SampleList from files on disk.
Parameters
......@@ -232,9 +294,34 @@ class SampleListBase:
"""
raise NotImplementedError
@classmethod
def _list_local_sample_files(cls, file_name_base, comm=None):
"""List all sample files that are relevant for the local task.
All sample files that correspond to `file_name_base` are searched and
selected based on which rank the local task has.
Note
----
This function makes sure that the all file numbers between 0 and the
maximal found number are present. If this is not the case, a
`RuntimeError` is raised.
"""
base_dir = os.path.abspath(os.path.dirname(file_name_base))
files = [ff for ff in os.listdir(base_dir)
if re.match(f"{file_name_base}.[0-9]+.pickle", ff) ]
if len(files) == 0:
raise RuntimeError(f"No files matching `{file_name_base}.*.pickle`")
n_samples = max(list(map(lambda x: int(x.split(".")[-2]), files))) + 1
files = [f"{file_name_base}.{ii}.pickle" for ii in cls.local_indices(n_samples, comm)]
for ff in files:
if not os.path.isfile(ff):
raise RuntimeError(f"File {ff} not found")
return files
class ResidualSampleList(SampleListBase):
def __init__(self, mean, residuals, neg, comm):
def __init__(self, mean, residuals, neg, comm=None):
"""Store samples in terms of a mean and a residual deviation thereof.
......@@ -320,13 +407,25 @@ class ResidualSampleList(SampleListBase):
return ResidualSampleList(mean, self._r, self._n, self.comm)
def save(self, file_name_base):
obj = self._m, self._r, self._n
self.save_helper(file_name_base, obj)
@staticmethod
def load(file_name_base, comm=None):
args = SampleListBase.load_helper(file_name_base, comm)
return ResidualSampleList(*args, comm=comm)
nsample = self.n_samples()
local_indices = self.local_indices(nsample, self.comm)
for ii, isample in enumerate(local_indices):
obj = [self._r[ii], self._n[ii]]
fname = _sample_file_name(file_name_base, isample)
_save_to_disk(fname, obj)
if self.mpi_master:
_save_to_disk(f"{file_name_base}.mean.pickle", self._m)
@classmethod
def load(cls, file_name_base, comm=None):
if comm is not None:
comm.Barrier()
files = cls._list_local_sample_files(file_name_base, comm)
tmp = [_load_from_disk(ff) for ff in files]
res = [aa[0] for aa in tmp]
neg = [aa[1] for aa in tmp]
mean = _load_from_disk(f"{file_name_base}.mean.pickle")
return cls(mean, res, neg, comm=comm)
class SampleList(SampleListBase):
......@@ -355,12 +454,22 @@ class SampleList(SampleListBase):
return len(self._s)
def save(self, file_name_base):
self.save_helper(file_name_base, self._s)
@staticmethod
def load(file_name_base, comm=None):
s = SampleListBase.load_helper(file_name_base, comm)
return SampleList(s, comm=comm)
nsample = self.n_samples()
local_indices = self.local_indices(nsample, self.comm)
lo = local_indices[0]
for isample in range(nsample):
if isample in local_indices:
obj = self._s[isample-lo]
fname = _sample_file_name(file_name_base, isample)
_save_to_disk(fname, obj)
@classmethod
def load(cls, file_name_base, comm=None):
if comm is not None:
comm.Barrier()
files = cls._list_local_sample_files(file_name_base, comm)
samples = [_load_from_disk(ff) for ff in files]
return cls(samples, comm=comm)
def _none_to_id(obj):
......@@ -388,18 +497,42 @@ def _bcast(obj, comm, root):
return comm.bcast(data, root=root)
def _mpi_file_extension(comm):
"""Return MPI-configuration unique string.
def _sample_file_name(file_name_base, isample):
"""Return sample-unique file name.
This string that can be used to uniquely determine the number of MPI tasks
for distributed saving of files.
This file name that can be used to uniquely write samples potentially from all MPI tasks.
Parameters
comm : MPI communicator or None
If None, an empty string is returned.
----------
file_name_base : str
User-defined first part of file name.
isample : int
Number of sample
"""
if comm is None:
return ""
ntask, rank, _ = utilities.get_MPI_params_from_comm(comm)
return f"{rank}.{ntask}"
if not isinstance(isample, int):
raise TypeError
return f"{file_name_base}.{isample}.pickle"
def _load_from_disk(file_name):
with open(file_name, "rb") as f:
obj = pickle.load(f)
return obj
def _save_to_disk(file_name, obj):
with open(file_name, "wb") as f:
pickle.dump(obj, f, pickle.HIGHEST_PROTOCOL)
def _field2hdf5(file_handle, obj, name):
if not isinstance(name, str):
raise TypeError
if isinstance(obj, MultiField):
grp = file_handle.create_group(name)
for kk, fld in obj.items():
_field2hdf5(grp, fld, kk)
return
if not isinstance(obj, Field):
raise TypeError
file_handle.create_dataset(name, data=obj.val)
......@@ -502,3 +502,10 @@ def check_dtype_or_none(obj, domain=None):
s = "Need to pass floating dtype (e.g. np.float64, complex) "
s += f"or `None` to this function.\nHave recieved:\n{obj}"
raise TypeError(s)
class Nop:
def nop(*args, **kw):
return Nop()
def __getattr__(self, _):
return self.nop
......@@ -11,11 +11,13 @@
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2020 Max-Planck-Society
# Copyright(C) 2013-2021 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.
import pytest
from glob import glob
from os import remove
def list2fixture(lst):
......@@ -29,6 +31,12 @@ def list2fixture(lst):
def setup_function():
import nifty8 as ift
ift.random.push_sseq_from_seed(42)
comm, _, _, master = ift.utilities.get_MPI_params()
if master:
for ff in glob("*.pickle") + glob("*.png") + glob("*.h5"):
remove(ff)
if comm is not None:
comm.Barrier()
def teardown_function():
......
......@@ -30,6 +30,9 @@ rank = comm.Get_rank()
master = (rank == 0)
mpi = ntask > 1
if not mpi:
comm = None
pmp = pytest.mark.parametrize
pms = pytest.mark.skipif
......@@ -38,10 +41,9 @@ pms = pytest.mark.skipif
@pmp('constants', ([], ['a'], ['b'], ['a', 'b']))
@pmp('point_estimates', ([], ['a'], ['b'], ['a', 'b']))
@pmp('mirror_samples', (False, True))
@pmp('mode', (0, 1))
@pmp('mf', (False, True))
@pmp('geo', (False, True))
def test_kl(constants, point_estimates, mirror_samples, mode, mf, geo):
def test_kl(constants, point_estimates, mirror_samples, mf, geo):
if not mf and (len(point_estimates) != 0 or len(constants) != 0):
return
dom = ift.RGSpace((12,), (2.12))
......@@ -66,7 +68,7 @@ def test_kl(constants, point_estimates, mirror_samples, mode, mf, geo):
ift.SampledKLEnergy(**args, comm=comm)
return
kl0 = ift.SampledKLEnergy(**args, comm=comm if mode == 0 else None)
kl0 = ift.SampledKLEnergy(**args, comm=comm)
if isinstance(mean0, ift.MultiField):
invariant = list(set(constants).intersection(point_estimates))
_, tmph = h.simplify_for_constant_input(mean0.extract_by_keys(invariant))
......@@ -78,7 +80,8 @@ def test_kl(constants, point_estimates, mirror_samples, mode, mf, geo):
invariant = None
samp = kl0._sample_list
ift.extra.assert_allclose(tmpmean, samp._m)
if mode == 1:
if not mpi:
samples = tuple(s for s in samp._r)
ii = len(samples)//2
slc = slice(None, ii) if rank == 0 else slice(ii, None)
......
......@@ -14,32 +14,43 @@
# Copyright(C) 2021 Max-Planck-Society
# Author: Philipp Arras
from mpi4py import MPI
import nifty8 as ift
import pytest
from mpi4py import MPI
from ..common import list2fixture, setup_function, teardown_function
comm = list2fixture([MPI.COMM_WORLD, None])
pmp = pytest.mark.parametrize
comm = [MPI.COMM_WORLD]
if MPI.COMM_WORLD.Get_size() == 1:
comm += [None]
comm = list2fixture(comm)
def _get_sample_list(communicator):
def _get_sample_list(communicator, cls):
dom = ift.makeDomain({"a": ift.UnstructuredDomain(2), "b": ift.RGSpace(12)})
samples = [ift.from_random(dom) for _ in range(3)]
return ift.SampleList(samples, communicator), samples
if cls == "SampleList":
return ift.SampleList(samples, communicator), samples
elif cls == "ResidualSampleList":
mean = ift.from_random(dom)
neg = 3*[False]
return ift.ResidualSampleList(mean, samples, neg, communicator), [mean + ss for ss in samples]
raise NotImplementedError
def test_sample_list(comm):
sl, samples = _get_sample_list(comm)
dom = sl.domain
def _get_ops(sample_list):
dom = sample_list.domain
return [None,
ift.ScalingOperator(dom, 1.),
ift.ducktape(None, dom, "a") @ ift.ScalingOperator(dom, 1.).exp()]
assert comm == sl.comm
ops = [None, ift.ScalingOperator(dom, 1.),
ift.ducktape(None, dom, "a") @ ift.ScalingOperator(dom, 1.).exp()]
def test_sample_list(comm):
sl, samples = _get_sample_list(comm, "SampleList")
assert comm == sl.comm
for op in ops:
for op in _get_ops(sl):
sc = ift.StatCalculator()
if op is None:
[sc.add(ss) for ss in samples]
......@@ -64,13 +75,36 @@ def test_sample_list(comm):
assert len(samples) <= sl.n_samples()
def test_load_and_save(comm):
sl, _ = _get_sample_list(comm)
sl.save("sample_list")
sl1 = ift.SampleList.load("sample_list", comm)
@pmp("cls", ["ResidualSampleList", "SampleList"])
def test_load_and_save(comm, cls):
if comm is None and ift.utilities.get_MPI_params()[1] > 1:
pytest.skip()
sl, _ = _get_sample_list(comm, cls)
sl.save("sl")
sl1 = getattr(ift, cls).load("sl", comm)
for s0, s1 in zip(sl.local_iterator(), sl1.local_iterator()):
ift.extra.assert_equal(s0, s1)
for s0, s1 in zip(sl.local_iterator(), sl1.local_iterator()):
ift.extra.assert_equal(s0, s1)
@pmp("cls", ["ResidualSampleList", "SampleList"])
@pmp("mean", [False, True])
@pmp("std", [False, True])
@pmp("samples", [False, True])
def test_save_to_hdf5(comm, cls, mean, std, samples):
pytest.importorskip("h5py")
if comm is None and ift.utilities.get_MPI_params()[1] > 1:
pytest.skip()
sl, _ = _get_sample_list(comm, cls)
for op in _get_ops(sl):
if not mean and not std and not samples:
with pytest.raises(ValueError):
sl.save_to_hdf5("output.h5", op, mean=mean, std=std, samples=samples)
continue
sl.save_to_hdf5("output.h5", op, mean=mean, std=std, samples=samples, overwrite=True)
if comm is not None:
comm.Barrier()