Commit 9abb8a34 authored by Philipp Arras's avatar Philipp Arras
Browse files

Merge branch 'energy_histories' into 'NIFTy_6'

Add energy logging

See merge request !470
parents d659da88 3a8d202d
Pipeline #75221 passed with stages
in 8 minutes and 26 seconds
......@@ -114,7 +114,9 @@ if __name__ == '__main__':
ic_newton = ift.AbsDeltaEnergyController(name='Newton',
deltaE=0.01,
iteration_limit=35)
minimizer = ift.NewtonCG(ic_newton)
ic_sampling.enable_logging()
ic_newton.enable_logging()
minimizer = ift.NewtonCG(ic_newton, activate_logging=True)
## number of samples used to estimate the KL
N_samples = 20
......@@ -143,10 +145,15 @@ if __name__ == '__main__':
plot.add([A2.force(KL.position),
A2.force(mock_position)],
title="power2")
plot.output(nx=2,
plot.add((ic_newton.history, ic_sampling.history,
minimizer.inversion_history),
label=['KL', 'Sampling', 'Newton inversion'],
title='Cumulative energies', s=[None, None, 1],
alpha=[None, 0.2, None])
plot.output(nx=3,
ny=2,
ysize=10,
xsize=10,
xsize=15,
name=filename.format("loop_{:02d}".format(i)))
# Done, draw posterior samples
......
......@@ -166,7 +166,8 @@ class NewtonCG(DescentMinimizer):
"""
def __init__(self, controller, napprox=0, line_searcher=None, name=None,
nreset=20, max_cg_iterations=200, energy_reduction_factor=0.1):
nreset=20, max_cg_iterations=200, energy_reduction_factor=0.1,
activate_logging=False):
if line_searcher is None:
line_searcher = LineSearch(preferred_initial_step_size=1.)
super(NewtonCG, self).__init__(controller=controller,
......@@ -176,6 +177,8 @@ class NewtonCG(DescentMinimizer):
self._nreset = nreset
self._max_cg_iterations = max_cg_iterations
self._alpha = energy_reduction_factor
from .iteration_controllers import EnergyHistory
self._history = EnergyHistory() if activate_logging else None
def get_descent_direction(self, energy, old_value=None):
if old_value is None:
......@@ -184,14 +187,22 @@ class NewtonCG(DescentMinimizer):
ediff = self._alpha*(old_value-energy.value)
ic = AbsDeltaEnergyController(
ediff, iteration_limit=self._max_cg_iterations, name=self._name)
if self._history is not None:
ic.enable_logging()
e = QuadraticEnergy(0*energy.position, energy.metric, energy.gradient)
p = None
if self._napprox > 1:
met = energy.metric
p = makeOp(approximation2endo(met, self._napprox)).inverse
e, conv = ConjugateGradient(ic, nreset=self._nreset)(e, p)
if self._history is not None:
self._history += ic.history
return -e.position
@property
def inversion_history(self):
return self._history
class L_BFGS(DescentMinimizer):
def __init__(self, controller, line_searcher=LineSearch(),
......
......@@ -11,10 +11,13 @@
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2019 Max-Planck-Society
# Copyright(C) 2013-2020 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.
import functools
from time import time
import numpy as np
from ..logger import logger
......@@ -37,10 +40,17 @@ class IterationController(metaclass=NiftyMeta):
class; the implementer has full flexibility to use whichever criteria are
appropriate for a particular problem - as long as they can be computed from
the information passed to the controller during the iteration process.
For analyzing minimization procedures IterationControllers can log energy
values together with the respective time stamps. In order to activate this
feature `activate_logging()` needs to be called.
"""
CONVERGED, CONTINUE, ERROR = list(range(3))
def __init__(self):
self._history = None
def start(self, energy):
"""Starts the iteration.
......@@ -69,6 +79,68 @@ class IterationController(metaclass=NiftyMeta):
"""
raise NotImplementedError
def enable_logging(self):
"""Enables the logging functionality. If the log has been populated
before, it stays as it is."""
if self._history is None:
self._history = EnergyHistory()
def disable_logging(self):
"""Disables the logging functionality. If the log has been populated
before, it is dropped."""
self._history = None
@property
def history(self):
return self._history
class EnergyHistory(object):
def __init__(self):
self._lst = []
def append(self, x):
if len(x) != 2:
raise ValueError
self._lst.append((float(x[0]), float(x[1])))
def reset(self):
self._lst = []
@property
def time_stamps(self):
return [x for x, _ in self._lst]
@property
def energy_values(self):
return [x for _, x in self._lst]
def __add__(self, other):
if not isinstance(other, EnergyHistory):
return NotImplemented
res = EnergyHistory()
res._lst = self._lst + other._lst
return res
def __iadd__(self, other):
if not isinstance(other, EnergyHistory):
return NotImplemented
self._lst += other._lst
return self
def __len__(self):
return len(self._lst)
def append_history(func):
@functools.wraps(func)
def wrapper(*args, **kwargs):
hist = args[0].history
if isinstance(hist, EnergyHistory):
hist.append((time(), args[1].value))
return func(*args, **kwargs)
return wrapper
class GradientNormController(IterationController):
"""An iteration controller checking (mainly) the L2 gradient norm.
......@@ -94,12 +166,14 @@ class GradientNormController(IterationController):
def __init__(self, tol_abs_gradnorm=None, tol_rel_gradnorm=None,
convergence_level=1, iteration_limit=None, name=None):
super(GradientNormController, self).__init__()
self._tol_abs_gradnorm = tol_abs_gradnorm
self._tol_rel_gradnorm = tol_rel_gradnorm
self._convergence_level = convergence_level
self._iteration_limit = iteration_limit
self._name = name
@append_history
def start(self, energy):
self._itcount = -1
self._ccount = 0
......@@ -108,6 +182,7 @@ class GradientNormController(IterationController):
* energy.gradient_norm
return self.check(energy)
@append_history
def check(self, energy):
self._itcount += 1
......@@ -163,16 +238,19 @@ class GradInfNormController(IterationController):
def __init__(self, tol, convergence_level=1, iteration_limit=None,
name=None):
super(GradInfNormController, self).__init__()
self._tol = tol
self._convergence_level = convergence_level
self._iteration_limit = iteration_limit
self._name = name
@append_history
def start(self, energy):
self._itcount = -1
self._ccount = 0
return self.check(energy)
@append_history
def check(self, energy):
self._itcount += 1
......@@ -224,17 +302,20 @@ class DeltaEnergyController(IterationController):
def __init__(self, tol_rel_deltaE, convergence_level=1,
iteration_limit=None, name=None):
super(DeltaEnergyController, self).__init__()
self._tol_rel_deltaE = tol_rel_deltaE
self._convergence_level = convergence_level
self._iteration_limit = iteration_limit
self._name = name
@append_history
def start(self, energy):
self._itcount = -1
self._ccount = 0
self._Eold = 0.
return self.check(energy)
@append_history
def check(self, energy):
self._itcount += 1
......@@ -290,17 +371,20 @@ class AbsDeltaEnergyController(IterationController):
def __init__(self, deltaE, convergence_level=1, iteration_limit=None,
name=None):
super(AbsDeltaEnergyController, self).__init__()
self._deltaE = deltaE
self._convergence_level = convergence_level
self._iteration_limit = iteration_limit
self._name = name
@append_history
def start(self, energy):
self._itcount = -1
self._ccount = 0
self._Eold = 0.
return self.check(energy)
@append_history
def check(self, energy):
self._itcount += 1
......
......@@ -16,14 +16,17 @@
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.
import os
from datetime import datetime as dt
import numpy as np
from matplotlib.dates import DateFormatter, date2num
from .domains.gl_space import GLSpace
from .domains.hp_space import HPSpace
from .domains.power_space import PowerSpace
from .domains.rg_space import RGSpace
from .field import Field
from .minimization.iteration_controllers import EnergyHistory
# relevant properties:
# - x/y size
......@@ -261,6 +264,44 @@ def _register_cmaps():
plt.register_cmap(cmap=LinearSegmentedColormap("Plus Minus", pm_cmap))
def _plot_history(f, ax, **kwargs):
import matplotlib.pyplot as plt
for i, fld in enumerate(f):
if not isinstance(fld, EnergyHistory):
raise TypeError
label = kwargs.pop("label", None)
if not isinstance(label, list):
label = [label] * len(f)
alpha = kwargs.pop("alpha", None)
if not isinstance(alpha, list):
alpha = [alpha] * len(f)
color = kwargs.pop("color", None)
if not isinstance(color, list):
color = [color] * len(f)
size = kwargs.pop("s", None)
if not isinstance(size, list):
size = [size] * len(f)
ax.set_title(kwargs.pop("title", ""))
ax.set_xlabel(kwargs.pop("xlabel", ""))
ax.set_ylabel(kwargs.pop("ylabel", ""))
plt.xscale(kwargs.pop("xscale", "linear"))
plt.yscale(kwargs.pop("yscale", "linear"))
mi, ma = np.inf, -np.inf
for i, fld in enumerate(f):
xcoord = date2num([dt.fromtimestamp(ts) for ts in fld.time_stamps])
ycoord = fld.energy_values
ax.scatter(xcoord, ycoord, label=label[i], alpha=alpha[i],
color=color[i], s=size[i])
mi, ma = min([min(xcoord), mi]), max([max(xcoord), ma])
delta = (ma-mi)*0.05
ax.set_xlim((mi-delta, ma+delta))
xfmt = DateFormatter('%H:%M')
ax.xaxis.set_major_formatter(xfmt)
_limit_xy(**kwargs)
if label != ([None]*len(f)):
plt.legend()
def _plot1D(f, ax, **kwargs):
import matplotlib.pyplot as plt
......@@ -334,10 +375,10 @@ def _plot2D(f, ax, **kwargs):
x_space = 0
if len(dom) == 2:
f_space = kwargs.pop("freq_space_idx", 1)
if not f_space in [0, 1]:
if f_space not in [0, 1]:
raise ValueError("Invalid frequency space index")
if (not isinstance(dom[f_space], RGSpace)) \
or len(dom[f_space].shape) != 1:
or len(dom[f_space].shape) != 1:
raise TypeError("Need 1D RGSpace as frequency space domain")
x_space = 1 - f_space
......@@ -412,8 +453,8 @@ def _plot2D(f, ax, **kwargs):
if have_rgb:
plt.imshow(res, origin="lower")
else:
plt.imshow(res, vmin=kwargs.get("zmin"), vmax=kwargs.get("zmax"), norm=norm.get('norm'),
cmap=cmap, origin="lower")
plt.imshow(res, vmin=kwargs.get("zmin"), vmax=kwargs.get("zmax"),
norm=norm.get('norm'), cmap=cmap, origin="lower")
plt.colorbar(orientation="horizontal")
return
raise ValueError("Field type not(yet) supported")
......@@ -421,11 +462,14 @@ def _plot2D(f, ax, **kwargs):
def _plot(f, ax, **kwargs):
_register_cmaps()
if isinstance(f, Field):
if isinstance(f, Field) or isinstance(f, EnergyHistory):
f = [f]
f = list(f)
if len(f) == 0:
raise ValueError("need something to plot")
if isinstance(f[0], EnergyHistory):
_plot_history(f, ax, **kwargs)
return
if not isinstance(f[0], Field):
raise TypeError("incorrect data type")
dom1 = f[0].domain
......
......@@ -15,9 +15,11 @@
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.
import nifty6 as ift
from numpy.testing import assert_, assert_allclose
import pytest
from numpy.testing import assert_, assert_allclose
import nifty6 as ift
from .common import setup_function, teardown_function
pmp = pytest.mark.parametrize
......@@ -37,6 +39,7 @@ def test_kl(constants, point_estimates, mirror_samples, mf):
import numpy as np
lh = ift.GaussianEnergy(domain=op.target, sampling_dtype=np.float64) @ op
ic = ift.GradientNormController(iteration_limit=5)
ic.enable_logging()
h = ift.StandardHamiltonian(lh, ic_samp=ic)
mean0 = ift.from_random('normal', h.domain)
......@@ -48,6 +51,14 @@ def test_kl(constants, point_estimates, mirror_samples, mf):
point_estimates=point_estimates,
mirror_samples=mirror_samples,
napprox=0)
assert_(len(ic.history) > 0)
assert_(len(ic.history) == len(ic.history.time_stamps))
assert_(len(ic.history) == len(ic.history.energy_values))
ic.history.reset()
assert_(len(ic.history) == 0)
assert_(len(ic.history) == len(ic.history.time_stamps))
assert_(len(ic.history) == len(ic.history.energy_values))
locsamp = kl._local_samples
klpure = ift.MetricGaussianKL(mean0,
h,
......@@ -83,8 +94,10 @@ def test_kl(constants, point_estimates, mirror_samples, mf):
# Test constants (after some minimization)
cg = ift.GradientNormController(iteration_limit=5)
minimizer = ift.NewtonCG(cg)
minimizer = ift.NewtonCG(cg, activate_logging=True)
kl, _ = minimizer(kl)
if len(constants) != 2:
assert_(len(minimizer.inversion_history) > 0)
diff = (mean0 - kl.position).to_dict()
for kk in constants:
assert_allclose(diff[kk].val, 0*diff[kk].val)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment