Skip to content
Snippets Groups Projects
optimize_kl.py 31.42 KiB
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2021 Max-Planck-Society
# Copyright(C) 2022 Max-Planck-Society, Philipp Arras
# Author: Philipp Arras
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.

import pickle
from functools import reduce
from os import makedirs
from os.path import isdir, isfile, join
from warnings import warn

import numpy as np

from ..domain_tuple import DomainTuple
from ..multi_domain import MultiDomain
from ..multi_field import MultiField
from ..operators.counting_operator import CountingOperator
from ..operators.energy_operators import StandardHamiltonian
from ..operators.operator import Operator
from ..plot import Plot
from ..random import pop_sseq, push_sseq, spawn_sseq
from ..sugar import from_random
from ..utilities import (check_MPI_equality, check_MPI_synced_random_state,
                         get_MPI_params_from_comm)
from .energy_adapter import EnergyAdapter
from .iteration_controllers import EnergyHistory, IterationController
from .kl_energies import SampledKLEnergy
from .minimizer import Minimizer
from .sample_list import (ResidualSampleList, SampleList, SampleListBase,
                          _barrier)

# Module-level state shared with the private helper functions of this module.
# `optimize_kl` overwrites both values whenever it is called with an
# `output_directory`; the helpers read them to build file names and to decide
# whether anything is written to disk.
_output_directory = None
_save_strategy = None


def optimize_kl(likelihood_energy,
                total_iterations,
                n_samples,
                kl_minimizer,
                sampling_iteration_controller,
                nonlinear_sampling_minimizer,
                constants=[],
                point_estimates=[],
                transitions=None,
                export_operator_outputs={},
                output_directory=None,
                initial_position=None,
                initial_index=0,
                comm=None,
                inspect_callback=None,
                terminate_callback=None,
                plot_energy_history=True,
                plot_minisanity_history=True,
                save_strategy="last",
                return_final_position=False,
                resume=False,
                sanity_checks=True,
                dry_run=False,
                fresh_stochasticity=True):
    """Provide potentially useful interface for standard KL minimization.

    The parameters `likelihood_energy`, `kl_minimizer`,
    `sampling_iteration_controller`, `nonlinear_sampling_minimizer`, `constants`,
    `point_estimates` and `comm` also accept a function as input that takes the
    index of the global iteration as input and returns the respective value that
    should be used for that iteration.

    High-level pseudo code of the algorithm, that is implemented by `optimize_kl`:

    .. code-block:: none

        for ii in range(initial_index, total_iterations):
            samples = Draw `n_samples` approximate samples(position,
                                                           likelihood_energy
                                                           sampling_iteration_controller,
                                                           nonlinear_sampling_minimizer,
                                                           point_estimates)
            position, samples = Optimize approximate KL(likelihood_energy, samples) with `kl_minimizer`(constants)
            Save intermediate results(samples)

    Parameters
    ----------
    likelihood_energy : Operator or callable
        Likelihood energy shall be used for inference. It is assumed that the
        definition space of this energy contains parameters that are a-priori
        standard normal distributed. It needs to have a MultiDomain as Domain.
    total_iterations : int
        Number of resampling loops.
    n_samples : int or callable
        Number of samples used to sample Kullback-Leibler divergence. 0
        corresponds to maximum-a-posteriori.
    kl_minimizer : Minimizer or callable
        Controls the minimizer for the KL optimization.
    sampling_iteration_controller : IterationController or None or callable
        Controls the conjugate gradient for inverting the posterior metric.  If
        `None`, approximate posterior samples cannot be drawn. It is only
        suited for maximum-a-posteriori solutions.
    nonlinear_sampling_minimizer : Minimizer or None or callable
        Controls the minimizer for the non-linear geometric sampling to
        approximate the KL. Can be either None (then the MGVI algorithm is used
        instead of geoVI) or a Minimizer.
    constants : list or callable
        List of parameter keys that are kept constant during optimization.
        Default is no constants.
    point_estimates : list or callable
        List of parameter keys for which no samples are drawn, but that are
        (possibly) optimized for, corresponding to point estimates of these.
        Default is to draw samples for the complete domain.
    transitions : callable or None
        Function that takes the global iteration index and returns a function
        that is applied at the beginning of each iteration to the current
        ResidualSampleList. This can be used if parts of the likelihood_energy
        are replaced and thereby have a different domain. If None is passed or
        returned in one iteration, no transition will be applied. Default is
        None.
    export_operator_outputs : dict
        Dictionary of operators that are exported during the minimization. The
        key contains a string that serves as identifier. The value of the
        dictionary is an operator.
    output_directory : str or None
        Directory in which all output files are saved. If None, no output is
        stored.  Default: None.
    initial_position : :class:`nifty8.field.Field`, :class:`nifty8.multi_field.MultiField` or None
        Position in the definition space of `likelihood_energy` from which the
        optimization is started. If `None`, it starts at a random, normal
        distributed position with standard deviation 0.1. Default: None.
    initial_index : int
        Initial index that is used to enumerate the output files. May be used
        if `optimize_kl` is called multiple times. Default: 0.
    comm : MPI communicator or None
        MPI communicator for distributing samples over MPI tasks. If `None`,
        the samples are not distributed. Default: None.
        To enable transferring large objects (>2 GiB) via MPI, wrap the `comm`
        object with `mpi4py.utils.pkl5.Intracomm` before passing it.
    inspect_callback : callable or None
        Function that is called after every global iteration. It can be either a
        function with one argument (then the latest sample list is passed), a
        function with two arguments (in which case the latest sample list and
        the global iteration index are passed). Default: None.
    terminate_callback : callable or None
        Function that is called after every global iteration and after
        `inspect_callback` if present.  It can be either None or a function
        that takes the global iteration index as input and returns a boolean.
        If the return value is true, the global loop in `optimize_kl` is
        terminated. Default: None.
    plot_energy_history : bool
        Determine if the KLEnergy values shall be plotted or not. Default: True.
    plot_minisanity_history : bool
        Determine if the reduced chi-square values computed by minisanity shall
        be plotted or not. Default: True.
    save_strategy : str
        If "last", the samples are saved after each global iteration and
        overwrite older saved samples so that only the latest ones are kept.
        If "all", all intermediate samples are written to disk and the filename
        contains the `global_iteration` so that nothing is overwritten.
        `save_strategy` is only applicable if `output_directory` is not None.
        Default: "last".
    return_final_position : bool
        Determine if the final position of the minimization shall be returned.
        May be useful to feed it as `initial_position` into another
        `optimize_kl` call. Default: False.
    resume : bool
        Resume partially run optimization. If `True` and `output_directory`
        contains `last_finished_iteration`, `initial_index` and
        `initial_position` are ignored and read from the output directory
        instead. If `last_finished_iteration` is not a file, the value of
        `initial_position` is used instead. Default: False.
    sanity_checks : bool
        Some sanity checks that are evaluated at the beginning. They are
        potentially expensive because all likelihoods have to be instantiated
        multiple times. Default: True.
    dry_run : bool
        Skips all expensive optimizations. Can be used to check that all domains
        fit together. Default: False.
    fresh_stochasticity : bool or callable
        Controls whether new randomness is used when drawing samples. If set to
        a boolean, the same value is used across all iterations (note that it
        cannot be False initially, since the first iteration requires fresh
        samples). Alternatively, a callable can be provided that takes the
        current iteration index as input and returns a boolean, allowing for
        dynamic control over when new stochasticity is introduced. Some users
        choose to set `fresh_stochasticity` to False toward the end of an
        `optimize_kl` run to reduce stochastic fluctuations and artificially
        increase convergence. Default: True.

    Returns
    -------
    sl : SampleList

    mean : :class:`nifty8.field.Field` or :class:`nifty8.multi_field.MultiField` (optional)

    Raises
    ------
    TypeError
        If `export_operator_outputs` is not a dict, `initial_index` is not an
        int, the target of the likelihood energy is not the scalar domain, its
        domain is not a MultiDomain, or one of the per-iteration objects has
        the wrong type (only checked if `sanity_checks` is True).
    ValueError
        If a reserved key is used in `export_operator_outputs`, the save
        strategy is unknown, `resume` is requested without
        `output_directory`, `initial_index >= total_iterations`, a callback
        has an unsupported number of arguments, or `fresh_stochasticity` is
        False for iteration 0.

    Note
    ----
    If `h5py` is available, the output of all plotted operators is saved as hdf5
    file as well.  If `astropy` is available, the mean and standard deviation of
    all plotted operators, that have a single 2d-RGSpace as target, are exported
    as FITS files.

    Note
    ----
    This function comes with some MPI support. Generally, with the help of MPI
    samples are distributed over tasks.
    """
    from ..sugar import full, makeDomain
    from ..utilities import myassert
    from .descent_minimizers import DescentMinimizer

    # NOTE: the mutable defaults (`constants`, `point_estimates`,
    # `export_operator_outputs`) are kept for interface compatibility. They
    # are only read, never modified, inside this function.
    if not isinstance(export_operator_outputs, dict):
        raise TypeError("`export_operator_outputs` needs to be a dict, "
                        f"got {type(export_operator_outputs)}")
    if "pickle" in export_operator_outputs:
        raise ValueError("The key `pickle` in `export_operator_outputs` is reserved.")
    if not isinstance(initial_index, int):
        raise TypeError(f"`initial_index` needs to be an int, got {type(initial_index)}")
    if save_strategy not in ["all", "last"]:
        raise ValueError(f"Save strategy '{save_strategy}' not supported.")
    if output_directory is None and resume:
        raise ValueError("Can only resume minimization if output_directory is not None")

    # From here on, all per-iteration inputs are functions of the global
    # iteration index.
    likelihood_energy = _make_callable(likelihood_energy)
    kl_minimizer = _make_callable(kl_minimizer)
    sampling_iteration_controller = _make_callable(sampling_iteration_controller)
    nonlinear_sampling_minimizer = _make_callable(nonlinear_sampling_minimizer)
    constants = _make_callable(constants)
    point_estimates = _make_callable(point_estimates)
    transitions = _make_callable(transitions)
    n_samples = _make_callable(n_samples)
    comm = _make_callable(comm)
    inspect_callback = _make_callable(inspect_callback)
    if terminate_callback is None:
        terminate_callback = _make_callable(False)
    fresh_stochasticity = _make_callable(fresh_stochasticity)

    if initial_position is None:
        # Empty field; missing keys are filled randomly by `_normal_initialize`
        # at the beginning of the first iteration.
        mean = full(makeDomain({}), 0.)
    else:
        mean = initial_position
        del(initial_position)
    sl = _single_value_sample_list(mean, comm=comm(initial_index))
    energy_history = EnergyHistory()

    # Sanity check of input
    if initial_index >= total_iterations:
        raise ValueError("Initial index is bigger than total iterations: "
                        f"{initial_index} >= {total_iterations}")
    if _number_of_arguments(transitions) != 1:
        raise ValueError(f"Transition takes 1 argument but {_number_of_arguments(transitions)} were given.")
    if _number_of_arguments(inspect_callback) not in [1, 2]:
        raise ValueError(f"Inspect callback takes either 1 or 2 arguments but {_number_of_arguments(inspect_callback)} "
                         f"were given.")
    if _number_of_arguments(terminate_callback) != 1:
        raise ValueError(f"Terminate callback takes 1 argument but {_number_of_arguments(terminate_callback)} were given.")
    if likelihood_energy(initial_index).target is not DomainTuple.scalar_domain():
        raise TypeError("Target of likelihood_energy needs to be the scalar domain")

    if sanity_checks:  # Potentially expensive sanity checks
        for iglobal in range(initial_index, total_iterations):
            for (obj, cls) in [(likelihood_energy, Operator), (kl_minimizer, DescentMinimizer),
                               (nonlinear_sampling_minimizer, (DescentMinimizer, type(None))),
                               (constants, (list, tuple)), (point_estimates, (list, tuple)),
                               (n_samples, int)]:
                if not isinstance(obj(iglobal), cls):
                    s = f"{obj(iglobal)} is not instance of {cls} but rather {type(obj(iglobal))}"
                    raise TypeError(s)

            if sampling_iteration_controller(iglobal) is None:
                myassert(n_samples(iglobal) == 0)
            else:
                myassert(isinstance(sampling_iteration_controller(iglobal), IterationController))
            myassert(likelihood_energy(iglobal).target is DomainTuple.scalar_domain())
            if comm(iglobal) is not None:
                try:
                    import mpi4py
                    myassert(isinstance(comm(iglobal), mpi4py.MPI.Intracomm))
                except ImportError:
                    pass
    # /Sanity check of input

    if output_directory is not None:
        global _output_directory
        global _save_strategy
        _output_directory = output_directory
        _save_strategy = save_strategy

        # Create all necessary subfolders
        if _MPI_master(comm(initial_index)):
            makedirs(output_directory, exist_ok=True)
            subfolders = ["pickle"] + list(export_operator_outputs.keys())
            if plot_energy_history:
                subfolders += ["energy_history"]
            if plot_minisanity_history:
                subfolders += ["minisanity_history"]
            for subfolder in subfolders:
                makedirs(join(output_directory, subfolder), exist_ok=True)

        # Resume
        lfile = join(output_directory, "last_finished_iteration")
        if resume and isfile(lfile):
            with open(lfile) as f:
                last_finished_index = int(f.read())
            initial_index = last_finished_index + 1
            fname = _file_name_by_strategy(last_finished_index)
            fname = reduce(join, [output_directory, "pickle", fname])
            if isfile(fname + ".mean.pickle"):
                mean = ResidualSampleList.load_mean(fname)
                sl = ResidualSampleList.load(fname, comm=comm(last_finished_index))
            else:
                sl = SampleList.load(fname)
                # no `comm` object passed to `load` since we expect to find only one sample
                # (a point estimate), which has to be available in all threads.
                myassert(sl.n_samples == 1)
                mean = sl.local_item(0)

            if initial_index == total_iterations:
                return (sl, mean) if return_final_position else sl

            _load_random_state()
            energy_history = _pickle_load_values(last_finished_index, 'energy_history')

        # No resume
        else:
            # The main loop has not started yet, so the communicator of the
            # first iteration is the relevant one here. (Using the loop
            # variable `iglobal` at this point is an error: it is unbound if
            # `sanity_checks` is False and stale otherwise.)
            check_MPI_synced_random_state(comm(initial_index))
            if _MPI_master(comm(initial_index)):
                _save_random_state()

    # Prepare stochasticity contexts: one seed sequence per global iteration.
    # Iterations without fresh stochasticity get a copy of the seed sequence
    # of the preceding iteration.
    sseqs = spawn_sseq(total_iterations)
    for iglobal in range(total_iterations):
        if not fresh_stochasticity(iglobal):
            if iglobal == 0:
                raise ValueError("fresh_stochasticity needs to be True for initial iteration")
            sseq_dup = np.random.SeedSequence(sseqs[iglobal-1].entropy,
                                              spawn_key=sseqs[iglobal-1].spawn_key,
                                              pool_size=sseqs[iglobal-1].pool_size)
            sseqs[iglobal] = sseq_dup

    for iglobal in range(initial_index, total_iterations):
        # Every `push_sseq` below is paired with a `pop_sseq` on each path
        # that leaves the loop body (regular end, dry run, termination).
        push_sseq(sseqs[iglobal])
        lh = likelihood_energy(iglobal)
        if not isinstance(lh.domain, MultiDomain):
            raise TypeError("Domain of likelihood_energy needs to be a MultiDomain, "
                            f"got\n{lh.domain}")

        # Transform mean
        t = transitions(iglobal)
        mean = mean if t is None else t(sl)
        mean = _normal_initialize(mean, lh.domain)
        # /Transform mean

        count = CountingOperator(lh.domain)
        ham = StandardHamiltonian(lh @ count, sampling_iteration_controller(iglobal))
        minimizer = kl_minimizer(iglobal)
        mean_iter = mean.extract(ham.domain)

        if dry_run:
            from ..logger import logger
            logger.info(f"Iteration {iglobal} checked")
            pop_sseq()
            continue

        # TODO Distributing the domain of the likelihood is not supported (yet)
        check_MPI_synced_random_state(comm(iglobal))
        check_MPI_equality(lh.domain, comm(iglobal))
        check_MPI_equality(mean.domain, comm(iglobal))
        check_MPI_equality(mean, comm(iglobal), hash=True)

        # free memory of old samples before drawing new ones
        sl = None

        if n_samples(iglobal) == 0:
            # Maximum-a-posteriori: minimize the Hamiltonian directly
            e = EnergyAdapter(mean_iter, ham, constants=constants(iglobal),
                              want_metric=_want_metric(minimizer))
            if comm(iglobal) is None:
                e, _ = minimizer(e)
                mean = MultiField.union([mean, e.position])
                sl = SampleList([mean])
                energy_history.append((iglobal, e.value))
            else:
                warn("Have detected MPI communicator for optimizing Hamiltonian. Will use only "
                     "the rank0 task and communicate the result afterwards to all other tasks.")
                if _MPI_master(comm(iglobal)):
                    e, _ = minimizer(e)
                    energy_history.append((iglobal, e.value))
                    mean = MultiField.union([mean, e.position])
                else:
                    mean = None
                _barrier(comm(iglobal))
                mean = comm(iglobal).bcast(mean, root=0)
                sl = _single_value_sample_list(mean, comm(iglobal))
        else:
            e = SampledKLEnergy(
                mean_iter,
                ham,
                n_samples(iglobal),
                nonlinear_sampling_minimizer(iglobal),
                comm=comm(iglobal),
                constants=constants(iglobal),
                point_estimates=point_estimates(iglobal))
            e, _ = minimizer(e)
            mean = MultiField.union([mean, e.position])
            sl = e.samples.at(mean)
            energy_history.append((iglobal, e.value))

        if output_directory is not None:
            _export_operators(iglobal, export_operator_outputs, sl, comm(iglobal))
            sl.save(join(output_directory, "pickle/") + _file_name_by_strategy(iglobal),
                    overwrite=True)

            if _MPI_master(comm(iglobal)):
                with open(join(output_directory, "last_finished_iteration"), "w") as f:
                    f.write(str(iglobal))
                _pickle_save_values(iglobal, 'energy_history', energy_history)
                if plot_energy_history:
                    _plot_energy_history(iglobal, energy_history)
        _barrier(comm(iglobal))

        _minisanity(lh, iglobal, sl, comm, plot_minisanity_history)
        _barrier(comm(iglobal))

        _counting_report(count, iglobal, comm)

        _handle_inspect_callback(inspect_callback, sl, iglobal)
        _barrier(comm(iglobal))

        if _handle_terminate_callback(terminate_callback, iglobal, comm):
            pop_sseq()
            break
        _barrier(comm(iglobal))

        lh = None
        pop_sseq()

    return (sl, mean) if return_final_position else sl


def _file_name(name, index, prefix=""):
    """Return the path of a PNG file inside a subfolder of the output directory.

    The file is located in `<_output_directory>/<name>/` and is called
    `<prefix><index>.png`, where `index` is zero-padded to at least three
    digits.
    """
    png_name = f"{prefix}{index:03d}.png"
    return join(_output_directory, name, png_name)


def _file_name_by_strategy(iglobal, save_strategy='global_strategy'):
    """Return the base file name that belongs to a global iteration.

    Parameters
    ----------
    iglobal : int
        Index of the global iteration.
    save_strategy : str
        Either "all", "last" or 'global_strategy'. The latter selects the
        module-wide `_save_strategy`.

    Returns
    -------
    str
        `iteration_<iglobal>` for the strategy "all" and `last` for the
        strategy "last".

    Raises
    ------
    RuntimeError
        If the resulting strategy is neither "all" nor "last".
    """
    strategy = _save_strategy if save_strategy == 'global_strategy' else save_strategy
    if strategy == "last":
        return "last"
    if strategy == "all":
        return f"iteration_{iglobal}"
    raise RuntimeError


def _save_random_state():
    """Write the current NIFTy random state to the pickle folder.

    The bytes returned by `getState` are stored in
    `<_output_directory>/pickle/nifty_random_state`.
    """
    from ..random import getState

    target = join(_output_directory, "pickle/nifty_random_state")
    with open(target, "wb") as fh:
        fh.write(getState())


def _load_random_state():
    """Restore the NIFTy random state from the pickle folder.

    Reads `<_output_directory>/pickle/nifty_random_state` and passes its
    content to `setState`.
    """
    from ..random import setState

    source = join(_output_directory, "pickle/nifty_random_state")
    with open(source, "rb") as fh:
        setState(fh.read())


def _pickle_save_values(index, name, val):
    """Pickle `val` into the pickle folder of the output directory.

    The file is called `<name>_` followed by the strategy-dependent file name
    of the global iteration `index`.
    """
    path = join(_output_directory, f"pickle/{name}_") + _file_name_by_strategy(index)
    with open(path, "wb") as fh:
        pickle.dump(val, fh)


def _pickle_load_values(index, name):
    """Load an object that has been stored with `_pickle_save_values`.

    The file `<name>_` followed by the strategy-dependent file name of the
    global iteration `index` is read from the pickle folder of the output
    directory.
    """
    path = join(_output_directory, f"pickle/{name}_") + _file_name_by_strategy(index)
    # NOTE(review): unpickling can execute arbitrary code. Only point the
    # output directory to locations whose content is trusted.
    with open(path, "rb") as fh:
        return pickle.load(fh)


def _export_operators(index, export_operator_outputs, sample_list, comm):
    """Export the operator outputs of a sample list to hdf5 and FITS files.

    Operators whose domain is not a subdomain of the domain of `sample_list`
    are skipped. If more than one sample is present, mean and standard
    deviation are exported in addition to the samples. hdf5 files are only
    written if `h5py` can be imported, FITS files only if `astropy` can be
    imported.

    Parameters
    ----------
    index : int
        Index of the global iteration; determines the file names.
    export_operator_outputs : dict
        Maps the name of the output subfolder to the operator to be exported.
    sample_list : SampleListBase
        Samples to which the operators are applied.
    comm : MPI communicator or None
        Not used by this function.

    Raises
    ------
    TypeError
        If `export_operator_outputs` is not a dict or `sample_list` is not a
        SampleListBase.
    RuntimeError
        If the output directory does not exist.
    """
    if not isinstance(export_operator_outputs, dict):
        raise TypeError
    if not isdir(_output_directory):
        raise RuntimeError(f"{_output_directory} does not exist")
    if not isinstance(sample_list, SampleListBase):
        raise TypeError

    # Both export backends are optional dependencies
    try:
        import h5py  # noqa: F401
    except ImportError:
        have_h5py = False
    else:
        have_h5py = True

    try:
        import astropy  # noqa: F401
    except ImportError:
        have_astropy = False
    else:
        have_astropy = True

    for name, op in export_operator_outputs.items():
        if not _is_subdomain(op.domain, sample_list.domain):
            continue

        target_dir = join(_output_directory, name)
        # Mean and standard deviation are only exported for several samples
        moments = True if sample_list.n_samples > 1 else False
        cfg = {"samples": True, "mean": moments, "std": moments}

        if have_h5py:
            hdf5_name = join(target_dir, _file_name_by_strategy(index) + ".hdf5")
            sample_list.save_to_hdf5(hdf5_name, op=op, overwrite=True, **cfg)

        if have_astropy:
            # A ValueError during the FITS export is ignored on purpose
            # (presumably raised for targets that cannot be stored as FITS --
            # TODO confirm against `save_to_fits`).
            try:
                fits_base = join(target_dir, _file_name_by_strategy(index))
                sample_list.save_to_fits(fits_base, op=op, overwrite=True, **cfg)
            except ValueError:
                pass


def _plot_energy_history(index, energy_history):
    """Plot the energy values and their changes between global iterations.

    Two PNG files are written into the `energy_history` subfolder of the
    output directory: the energy history itself and, if `index > 0`, the
    energy differences with respect to the previous entry.

    Parameters
    ----------
    index : int
        Index of the global iteration; determines file names and figure width.
    energy_history : EnergyHistory
        Its `energy_values` and `time_stamps` are plotted.
    """
    import matplotlib.pyplot as plt

    # The placeholder is filled with the kind of plot below
    template = join(_output_directory, 'energy_history',
                    '{}_' + _file_name_by_strategy(index) + '.png')

    energies = np.array(energy_history.energy_values)

    # energy value plot; a log scale is only possible for positive energies
    scale = 'log' if (energies > 0.).all() else 'linear'
    plot = Plot()
    plot.add(energy_history, skip_timestamp_conversion=True, xlabel='iteration',
             ylabel=r'E', yscale=scale)
    plot.output(title='energy history',
                name=template.format('energy_history'),
                xsize=max(0.8 + 0.175 * index, 6.4),
                ysize=4.8,
                dpi=125)

    if index <= 0:
        return

    # energy change plot
    width = max(0.8 + 0.175 * (index - 1), 6.4)
    plt.figure(figsize=(width, 4.8))
    stamps = np.array(energy_history.time_stamps[1:])
    delta = energies[1:] - energies[:-1]
    increased = (delta > 0)
    decreased = (delta < 0)
    plt.plot(stamps[increased], delta[increased], '^', color='red', label='positive')
    # Decreases are shown with flipped sign such that they fit on the log axis
    plt.plot(stamps[decreased], -delta[decreased], 'v', color='green', label='negative')
    plt.yscale('log')
    plt.title('energy change w.r.t. previous step')
    plt.xlabel('iteration')
    plt.ylabel(r'$\Delta\,$E')
    plt.legend()
    plt.savefig(template.format('energy_change_history'),
                bbox_inches="tight", dpi=125)
    plt.close()


def _append_key(s, key):
    """Return `s` followed by `key` in parentheses; `s` itself if `key` is empty."""
    return s if key == "" else f"{s} ({key})"


def _minisanity(likelihood_energy, iglobal, sl, comm, plot_minisanity_history):
    """Run minisanity, report the result and update the minisanity history.

    The textual report is sent to the logger and appended to
    `minisanity.txt`. On the master task, and only if an output directory is
    set, the returned values are appended to the pickled minisanity history
    and optionally plotted.

    Parameters
    ----------
    likelihood_energy : Operator
        Likelihood energy of the current global iteration.
    iglobal : int
        Index of the global iteration.
    sl : SampleListBase
        Samples of the current global iteration.
    comm : callable
        Maps the global iteration index to the MPI communicator (or None).
    plot_minisanity_history : bool
        Whether the history shall be plotted.
    """
    from ..extra import minisanity

    s, ms_val = minisanity(likelihood_energy, sl, terminal_colors=False,
                           return_values=True)
    check_MPI_equality(ms_val, comm(iglobal))
    _report_to_logger_and_file(s, "minisanity.txt", iglobal, comm, True, True,
                               True)

    if not (_MPI_master(comm(iglobal)) and _output_directory is not None):
        return

    # load/create minisanity_history object
    value_type_keys = ['redchisq', 'scmean']
    category_keys = ['data_residuals', 'latent_variables']

    mh = None
    if iglobal != 0:
        try:
            mh = _pickle_load_values(iglobal - 1, 'minisanity_history')
        except FileNotFoundError:
            # Happens if the optimization starts at `initial_index > 0` with an
            # output directory that contains no history yet. Start a new
            # history instead of aborting the run.
            warn(f"No minisanity history found for iteration {iglobal - 1}. "
                 "Starting a new minisanity history.")
    if mh is None:
        mh = {tk: {ck: {} for ck in category_keys} for tk in value_type_keys}

    # The keys of the 'redchisq' entries are used for all value types
    key_set_mh = {}
    key_set_msval = {}
    for ck in category_keys:
        key_set_mh[ck] = set(mh['redchisq'][ck].keys())
        key_set_msval[ck] = set(ms_val['redchisq'][ck].keys())

    for tk in value_type_keys:
        for ck in category_keys:
            # all keys not yet in minisanity history
            for ek in key_set_msval[ck] - key_set_mh[ck]:
                v = ms_val[tk][ck][ek]
                mh[tk][ck][ek] = {
                    'index': [iglobal],
                    'mean': [v['mean']],
                    'std': [v['std']],
                }
            # all keys already present in minisanity history
            for ek in key_set_msval[ck] & key_set_mh[ck]:
                v = ms_val[tk][ck][ek]
                mh[tk][ck][ek]['index'].append(iglobal)
                mh[tk][ck][ek]['mean'].append(v['mean'])
                mh[tk][ck][ek]['std'].append(v['std'])

    _pickle_save_values(iglobal, 'minisanity_history', mh)

    if plot_minisanity_history:
        _plot_minisanity_history(iglobal, mh)


def _plot_minisanity_history(index, minisanity_history):
    """Plot the history of the reduced chi-square values.

    One line per data residual and per latent variable is drawn. The figure
    is written as PNG into the `minisanity_history` subfolder of the output
    directory.

    Parameters
    ----------
    index : int
        Index of the global iteration; determines the file name.
    minisanity_history : dict
        Only the entry 'redchisq' is read. Its categories 'data_residuals'
        and 'latent_variables' map keys to dicts with the lists 'mean' and
        'index'.
    """
    import matplotlib.pyplot as plt
    from matplotlib.cm import plasma

    redchisq = minisanity_history['redchisq']
    residuals = redchisq['data_residuals']
    latents = redchisq['latent_variables']

    # Data residuals are labelled by a running number, latent variables by key
    labels = [f'residuals: {ii}' for ii in range(len(residuals))]
    labels += [f'latent: {kk}' for kk in latents.keys()]

    entries = [residuals[kk] for kk in residuals.keys()]
    entries += [latents[kk] for kk in latents.keys()]
    n_tot = len(entries)

    vals = [np.array(entry['mean']) for entry in entries]
    idxs = [np.array(entry['index']) for entry in entries]

    colors = [plasma(x) for x in np.linspace(0, 0.95, n_tot)]
    # Cycle through three line styles to keep neighbouring lines apart
    linestyles = [('-', ':', '--')[ii % 3] for ii in range(n_tot)]

    first_idx = np.min([np.min(ii) for ii in idxs])
    last_idx = np.max([np.max(ii) for ii in idxs])
    delta_idx = last_idx - first_idx
    plt.figure(figsize=(max(0.8 + 0.2 * delta_idx, 6.4), 4.8))

    for xx, yy, label, color, linestyle in zip(idxs, vals, labels, colors, linestyles):
        plt.plot(xx, yy, label=label, color=color, marker='.',
                 linestyle=linestyle)

    # Reference line at one; the x-limits are restored afterwards
    xlim = plt.xlim()
    plt.hlines(1., xlim[0], xlim[1], linestyle='dashed', color='black', linewidth=2, zorder=-1)
    plt.xlim(*xlim)
    plt.yscale('log')

    plt.title(r'reduced $\chi^2$ values')
    plt.xlabel('iteration')
    plt.ylabel(r'red. $\chi^2$')
    plt.legend(bbox_to_anchor=(1.04, 0.5),
               loc='center left',
               borderaxespad=0,
               ncol=int(np.ceil(n_tot / 20)))
    png_name = 'minisanity_history_' + _file_name_by_strategy(index) + '.png'
    plt.savefig(join(_output_directory, 'minisanity_history', png_name),
                bbox_inches="tight", dpi=250)
    plt.close()


def _counting_report(count, iglobal, comm):
    """Report the statistics of the counting operator of all tasks.

    The report is always appended to `counting_report.txt` (if an output
    directory is set) and additionally sent to the logger if no output
    directory is set.
    """
    report = count.report()
    to_logger = _output_directory is None
    _report_to_logger_and_file(report, "counting_report.txt", iglobal, comm,
                               to_logger, True, False)


def _report_to_logger_and_file(report, file_name, iglobal, comm, to_logger,
                               to_file, only_master):
    """Send a report to the logger and/or append it to a file.

    Parameters
    ----------
    report : str
        Text to be reported.
    file_name : str
        Name of the file inside the output directory.
    iglobal : int
        Index of the global iteration.
    comm : callable
        Maps the global iteration index to the MPI communicator (or None).
    to_logger : bool
        Whether the report is sent to the logger.
    to_file : bool
        Whether the report is appended to the file. Nothing is written if no
        output directory is set.
    only_master : bool
        If False, the reports of all tasks are collected and each of them is
        preceded by a `Task <number>` line.
    """
    from datetime import datetime

    from ..logger import logger
    from ..utilities import allreduce_sum

    header = f"Finished index: {iglobal}\nCurrent datetime: {datetime.now()}\n"

    if not only_master:
        collected = allreduce_sum([[report]], comm(iglobal))
        report = "\n".join([f"Task {ii}\n{rr}" for ii, rr in enumerate(collected)])

    # Only the master task produces output
    if not _MPI_master(comm(iglobal)):
        return

    if to_logger:
        logger.info(report)
    if to_file and _output_directory is not None:
        with open(join(_output_directory, file_name), "a", encoding="utf-8") as fh:
            fh.write(header + report + "\n\n")


def _handle_inspect_callback(inspect_callback, sl, iglobal):
    """Call `inspect_callback` with the arguments that fit its signature.

    Parameters
    ----------
    inspect_callback : callable
        Function that takes either the sample list or the sample list and
        the global iteration index.
    sl : SampleListBase
        Latest sample list.
    iglobal : int
        Index of the global iteration.

    Raises
    ------
    ValueError
        If `inspect_callback` takes neither one nor two arguments.
    """
    n_args = _number_of_arguments(inspect_callback)
    if n_args == 1:
        inp = (sl, )
    elif n_args == 2:
        inp = (sl, iglobal)
    else:
        # Fail with a clear message instead of an unbound local variable
        raise ValueError(f"Inspect callback takes either 1 or 2 arguments but {n_args} "
                         "were given.")
    inspect_callback(*inp)


def _handle_terminate_callback(terminate_callback, iglobal, comm):
    """Evaluate `terminate_callback` for the global iteration `iglobal`.

    The result is passed to `check_MPI_equality` together with the
    communicator of this iteration and then returned.
    """
    decision = terminate_callback(iglobal)
    check_MPI_equality(decision, comm(iglobal))
    return decision


def _MPI_master(comm):
    """Return the third entry of `get_MPI_params_from_comm(comm)`.

    It is used throughout this module as the flag that marks the master task.
    """
    mpi_params = get_MPI_params_from_comm(comm)
    return mpi_params[2]


def _make_callable(obj):
    """Turn `obj` into a function of one argument.

    Callables are returned unchanged unless they are instances of Minimizer,
    IterationController or Operator. Everything else is wrapped into a
    function that ignores its argument and returns `obj`.
    """
    wrapped_types = (Minimizer, IterationController, Operator)
    if callable(obj) and not isinstance(obj, wrapped_types):
        return obj
    return lambda x: obj


def _want_metric(mini):
    """Return False for SteepestDescent, L_BFGS and VL_BFGS, True otherwise."""
    from .descent_minimizers import L_BFGS, VL_BFGS, SteepestDescent

    # TODO Make this a property of the minimizer?
    return not isinstance(mini, (SteepestDescent, L_BFGS, VL_BFGS))


def _number_of_arguments(func):
    """Return the number of parameters in the signature of `func`."""
    import inspect

    parameters = inspect.signature(func).parameters
    return len(parameters)


def _is_subdomain(sub_domain, total_domain):
    """Check whether `sub_domain` is contained in `total_domain`.

    A DomainTuple is compared for equality with `total_domain`. For a
    MultiDomain every key needs to be present in `total_domain` with an equal
    value.

    Raises
    ------
    TypeError
        If `sub_domain` is neither a MultiDomain nor a DomainTuple.
    """
    if isinstance(sub_domain, DomainTuple):
        return sub_domain == total_domain
    if not isinstance(sub_domain, MultiDomain):
        raise TypeError
    for key, dom in sub_domain.items():
        if key not in total_domain.keys():
            return False
        if not (dom == total_domain[key]):
            return False
    return True


def _normal_initialize(mf, domain, std=0.1):
    """Extend `mf` to `domain` with random values for the missing keys.

    Parameters
    ----------
    mf : MultiField
        Field whose domain needs to be contained in `domain`.
    domain : MultiDomain
        Domain of the returned field.
    std : float
        Passed as `std` to `from_random` for the missing keys. Default: 0.1.

    Raises
    ------
    RuntimeError
        If the union of `domain` and the domain of `mf` differs from `domain`.
    """
    from ..sugar import makeDomain
    from ..utilities import myassert

    if MultiDomain.union([domain, mf.domain]) != domain:
        raise RuntimeError(f"Domain of MultiField and final domain are not compatible\n"
                           f"MultiField domain:\n{mf.domain}\n"
                           f"Final domain:\n{domain}")
    # Only the keys that are not part of `mf` are drawn randomly
    missing_keys = set(domain.keys()) - set(mf.domain.keys())
    missing_domain = makeDomain({key: domain[key] for key in missing_keys})
    extended = mf.unite(from_random(missing_domain, std=std))
    myassert(extended.domain is domain)
    return extended


def _single_value_sample_list(fld, comm):
    """Return a SampleList that holds `fld` on the master task only.

    `fld` is first passed to `check_MPI_equality`. All other tasks get a
    SampleList without local samples but with the same domain.
    """
    check_MPI_equality(fld, comm, hash=True)
    local_samples = [fld] if _MPI_master(comm) else []
    return SampleList(local_samples, comm=comm, domain=fld.domain)