Commit c4ffab5b authored by Philipp Arras's avatar Philipp Arras
Browse files

Start implementing simple minimization

parent 933109a0
......@@ -16,61 +16,93 @@
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.
from .extra import minisanity
from .minimization.descent_minimizers import NewtonCG
from .utilities import assert_rngs_synchronized
from .minimization.energy_adapter import EnergyAdapter
from .minimization.iteration_controllers import (
AbsDeltaEnergyController,
GradientNormController,
IterationController,
)
from .minimization.metric_gaussian_kl import MetricGaussianKL
from .operators.energy_operators import StandardHamiltonian
from .sugar import from_random
def minimize_mgkl(
lh,
name_outputdirectory,
nsamples,
nglobaliter,
niter_newton,
lst_of_plottable_operators=None,
likelihood,
n_samples,
global_iterations,
newton_convergence,
sampling_convergence=(0.1, 300),
plottable_operators=None,
output_directory=None,
initial_position=None,
sampling_convergence=0.1,
sampling_niter=300,
comm=None,
):
"""Provide simplified interface for MGVI minimization.
Parameters
----------
n_samples : int
Number of samples used to sample Kullback-Leibler divergence.
newton_convergence : IterationController or (float, int)
Either IterationController or maxmimum number of iterations or convergence criterium and maxmimum number of iterations.
sampling_convergence : IterationController or (float, int)
FIXME
Returns
-------
latent_mean : Field or MultiField
kl : Energy
"""
# AbsDeltaEngeryController(0.05) -> guter default
# Man muss: <op(samples)> und nicht op(<samples>).
# initial_opisition = 0.1 * from_random
raise NotImplementedError
return latent_mean
mean = initial_position
if mean is None:
mean = 0.1 * from_random(likelihood.domain)
ham = StandardHamiltonian(likelihood, _get_controller(sampling_convergence))
minimizer = NewtonCG(_get_controller(newton_convergence))
# TODO Energy histories
if n_samples == 0:
if comm is not None:
raise ValueError("MPI-parallel evaluation of Hamiltonian not available.")
e = EnergyAdapter(mean, ham)
for _ in range(global_iterations):
e, _ = minimizer(e)
# FIXME minisanity(likelihood, e.position)
plot_operator_list(output_directory, plottable_operators, e.position)
else:
for _ in range(global_iterations):
assert_rngs_synchronized(comm)
e = MetricGaussianKL.make(mean, ham, n_samples, True, comm=comm)
e, _ = minimizer(e)
mean = e.position
# FIXME minisanity(likelihood, mean, samples)
plot_operator_list(
output_directory, plottable_operators, e.position, e.samples,
)
# Save latent mean, residual samples (only of last iteration)
# Save posterior samples for plotted operators (only of last iteration)
return e
# nsamples = 0 -> MAP
# plots -> mean und sd
# outputdirectory
# -> plots (prefix: unixtime)
# -> latent mean, residual samples (only of last iteration)
# -> posterior samples for plotted operators (only of last iteration)
# with mpi support
def minimize_map(
lh,
name_outputdirectory,
nglobaliter,
niter_newton,
lst_of_plottable_operators=None,
initial_position=None,
):
"""Provide simplified interface for maximum-a-posterior minimization.
Parameters
----------
def plot_operator_list(output_directory, plottable_operators, position, samples=None):
print("WARN: Not Implemented yet")
# Man muss: <op(samples)> und nicht op(<samples>).
# plots -> mean und sd
# outputdirectory
# -> plots (prefix: unixtime)
Returns
-------
position : Field or MultiField
"""
raise NotImplementedError
return position
def _get_controller(criterium):
if hasattr(criterium, len) and len(criterium) == 2:
if criterium[0] is None:
criterium = GradientNormController(iteration_limit=criterium)
else:
criterium = AbsDeltaEnergyController(
criterium[0], iteration_limit=criterium[1]
)
if not isinstance(criterium, IterationController):
raise TypeError
return criterium
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment