Commit e9479053 authored by Lukas Platz's avatar Lukas Platz
Browse files

Merge branch 'NIFTy_6' into fix_docstring_warnings

parents add2880a a9fea605
Pipeline #75236 passed with stages
in 8 minutes and 24 seconds
...@@ -47,9 +47,9 @@ import numpy as np ...@@ -47,9 +47,9 @@ import numpy as np
dom = ift.UnstructuredDomain(5) dom = ift.UnstructuredDomain(5)
dtype = [np.float64, np.complex128][1] dtype = [np.float64, np.complex128][1]
invcov = ift.ScalingOperator(dom, 3) invcov = ift.ScalingOperator(dom, 3)
e = ift.GaussianEnergy(mean=ift.from_random('normal', dom, dtype=dtype), e = ift.GaussianEnergy(mean=ift.from_random(dom, 'normal', dtype=dtype),
inverse_covariance=invcov) inverse_covariance=invcov)
pos = ift.from_random('normal', dom, dtype=np.complex128) pos = ift.from_random(dom, 'normal', dtype=np.complex128)
lin = e(ift.Linearization.make_var(pos, want_metric=True)) lin = e(ift.Linearization.make_var(pos, want_metric=True))
met = lin.metric met = lin.metric
print(met) print(met)
...@@ -71,6 +71,13 @@ the generation of reproducible random numbers in the presence of MPI parallelism ...@@ -71,6 +71,13 @@ the generation of reproducible random numbers in the presence of MPI parallelism
and leads to cleaner code overall. Please see the documentation of and leads to cleaner code overall. Please see the documentation of
`nifty6.random` for details. `nifty6.random` for details.
Interface Change for from_random and OuterProduct
=================================================
The sugar.from_random, Field.from_random, MultiField.from_random now take domain
as the first argument and default to 'normal' for the second argument.
Likewise OuterProduct takes domain as the first argument and a field as the second.
Interface Change for non-linear Operators Interface Change for non-linear Operators
========================================= =========================================
......
...@@ -43,7 +43,7 @@ if __name__ == '__main__': ...@@ -43,7 +43,7 @@ if __name__ == '__main__':
harmonic_space = position_space.get_default_codomain() harmonic_space = position_space.get_default_codomain()
HT = ift.HarmonicTransformOperator(harmonic_space, position_space) HT = ift.HarmonicTransformOperator(harmonic_space, position_space)
position = ift.from_random('normal', harmonic_space) position = ift.from_random(harmonic_space, 'normal')
# Define power spectrum and amplitudes # Define power spectrum and amplitudes
def sqrtpspec(k): def sqrtpspec(k):
...@@ -58,13 +58,13 @@ if __name__ == '__main__': ...@@ -58,13 +58,13 @@ if __name__ == '__main__':
# Generate mock data # Generate mock data
p = R(sky) p = R(sky)
mock_position = ift.from_random('normal', harmonic_space) mock_position = ift.from_random(harmonic_space, 'normal')
tmp = p(mock_position).val.astype(np.float64) tmp = p(mock_position).val.astype(np.float64)
data = ift.random.current_rng().binomial(1, tmp) data = ift.random.current_rng().binomial(1, tmp)
data = ift.Field.from_raw(R.target, data) data = ift.Field.from_raw(R.target, data)
# Compute likelihood and Hamiltonian # Compute likelihood and Hamiltonian
position = ift.from_random('normal', harmonic_space) position = ift.from_random(harmonic_space, 'normal')
likelihood = ift.BernoulliEnergy(data) @ p likelihood = ift.BernoulliEnergy(data) @ p
ic_newton = ift.DeltaEnergyController( ic_newton = ift.DeltaEnergyController(
name='Newton', iteration_limit=100, tol_rel_deltaE=1e-8) name='Newton', iteration_limit=100, tol_rel_deltaE=1e-8)
......
...@@ -40,7 +40,7 @@ def make_checkerboard_mask(position_space): ...@@ -40,7 +40,7 @@ def make_checkerboard_mask(position_space):
def make_random_mask(): def make_random_mask():
# Random mask for spherical mode # Random mask for spherical mode
mask = ift.from_random('pm1', position_space) mask = ift.from_random(position_space, 'pm1')
mask = (mask + 1)/2 mask = (mask + 1)/2
return mask.val return mask.val
......
...@@ -90,7 +90,7 @@ if __name__ == '__main__': ...@@ -90,7 +90,7 @@ if __name__ == '__main__':
# Generate mock data and define likelihood operator # Generate mock data and define likelihood operator
d_space = R.target[0] d_space = R.target[0]
lamb = R(sky) lamb = R(sky)
mock_position = ift.from_random('normal', domain) mock_position = ift.from_random(domain, 'normal')
data = lamb(mock_position) data = lamb(mock_position)
data = ift.random.current_rng().poisson(data.val.astype(np.float64)) data = ift.random.current_rng().poisson(data.val.astype(np.float64))
data = ift.Field.from_raw(d_space, data) data = ift.Field.from_raw(d_space, data)
...@@ -103,7 +103,7 @@ if __name__ == '__main__': ...@@ -103,7 +103,7 @@ if __name__ == '__main__':
# Compute MAP solution by minimizing the information Hamiltonian # Compute MAP solution by minimizing the information Hamiltonian
H = ift.StandardHamiltonian(likelihood) H = ift.StandardHamiltonian(likelihood)
initial_position = ift.from_random('normal', domain) initial_position = ift.from_random(domain, 'normal')
H = ift.EnergyAdapter(initial_position, H, want_metric=True) H = ift.EnergyAdapter(initial_position, H, want_metric=True)
H, convergence = minimizer(H) H, convergence = minimizer(H)
......
...@@ -98,7 +98,7 @@ if __name__ == '__main__': ...@@ -98,7 +98,7 @@ if __name__ == '__main__':
N = ift.ScalingOperator(data_space, noise) N = ift.ScalingOperator(data_space, noise)
# Generate mock signal and data # Generate mock signal and data
mock_position = ift.from_random('normal', signal_response.domain) mock_position = ift.from_random(signal_response.domain, 'normal')
data = signal_response(mock_position) + N.draw_sample_with_dtype(dtype=np.float64) data = signal_response(mock_position) + N.draw_sample_with_dtype(dtype=np.float64)
# Minimization parameters # Minimization parameters
......
...@@ -97,7 +97,7 @@ if __name__ == '__main__': ...@@ -97,7 +97,7 @@ if __name__ == '__main__':
N = ift.ScalingOperator(data_space, noise) N = ift.ScalingOperator(data_space, noise)
# Generate mock signal and data # Generate mock signal and data
mock_position = ift.from_random('normal', signal_response.domain) mock_position = ift.from_random(signal_response.domain, 'normal')
data = signal_response(mock_position) + N.draw_sample_with_dtype(dtype=np.float64) data = signal_response(mock_position) + N.draw_sample_with_dtype(dtype=np.float64)
plot = ift.Plot() plot = ift.Plot()
...@@ -114,7 +114,9 @@ if __name__ == '__main__': ...@@ -114,7 +114,9 @@ if __name__ == '__main__':
ic_newton = ift.AbsDeltaEnergyController(name='Newton', ic_newton = ift.AbsDeltaEnergyController(name='Newton',
deltaE=0.01, deltaE=0.01,
iteration_limit=35) iteration_limit=35)
minimizer = ift.NewtonCG(ic_newton) ic_sampling.enable_logging()
ic_newton.enable_logging()
minimizer = ift.NewtonCG(ic_newton, activate_logging=True)
## number of samples used to estimate the KL ## number of samples used to estimate the KL
N_samples = 20 N_samples = 20
...@@ -143,10 +145,15 @@ if __name__ == '__main__': ...@@ -143,10 +145,15 @@ if __name__ == '__main__':
plot.add([A2.force(KL.position), plot.add([A2.force(KL.position),
A2.force(mock_position)], A2.force(mock_position)],
title="power2") title="power2")
plot.output(nx=2, plot.add((ic_newton.history, ic_sampling.history,
minimizer.inversion_history),
label=['KL', 'Sampling', 'Newton inversion'],
title='Cumulative energies', s=[None, None, 1],
alpha=[None, 0.2, None])
plot.output(nx=3,
ny=2, ny=2,
ysize=10, ysize=10,
xsize=10, xsize=15,
name=filename.format("loop_{:02d}".format(i))) name=filename.format("loop_{:02d}".format(i)))
# Done, draw posterior samples # Done, draw posterior samples
......
...@@ -26,7 +26,7 @@ from .operators.diagonal_operator import DiagonalOperator ...@@ -26,7 +26,7 @@ from .operators.diagonal_operator import DiagonalOperator
from .operators.distributors import DOFDistributor, PowerDistributor from .operators.distributors import DOFDistributor, PowerDistributor
from .operators.domain_tuple_field_inserter import DomainTupleFieldInserter from .operators.domain_tuple_field_inserter import DomainTupleFieldInserter
from .operators.einsum import LinearEinsum, MultiLinearEinsum from .operators.einsum import LinearEinsum, MultiLinearEinsum
from .operators.contraction_operator import ContractionOperator from .operators.contraction_operator import ContractionOperator, IntegrationOperator
from .operators.linear_interpolation import LinearInterpolator from .operators.linear_interpolation import LinearInterpolator
from .operators.endomorphic_operator import EndomorphicOperator from .operators.endomorphic_operator import EndomorphicOperator
from .operators.harmonic_operators import ( from .operators.harmonic_operators import (
......
...@@ -42,8 +42,8 @@ def _adjoint_implementation(op, domain_dtype, target_dtype, atol, rtol, ...@@ -42,8 +42,8 @@ def _adjoint_implementation(op, domain_dtype, target_dtype, atol, rtol,
needed_cap = op.TIMES | op.ADJOINT_TIMES needed_cap = op.TIMES | op.ADJOINT_TIMES
if (op.capability & needed_cap) != needed_cap: if (op.capability & needed_cap) != needed_cap:
return return
f1 = from_random("normal", op.domain, dtype=domain_dtype) f1 = from_random(op.domain, "normal", dtype=domain_dtype)
f2 = from_random("normal", op.target, dtype=target_dtype) f2 = from_random(op.target, "normal", dtype=target_dtype)
res1 = f1.s_vdot(op.adjoint_times(f2)) res1 = f1.s_vdot(op.adjoint_times(f2))
res2 = op.times(f1).s_vdot(f2) res2 = op.times(f1).s_vdot(f2)
if only_r_linear: if only_r_linear:
...@@ -55,11 +55,11 @@ def _inverse_implementation(op, domain_dtype, target_dtype, atol, rtol): ...@@ -55,11 +55,11 @@ def _inverse_implementation(op, domain_dtype, target_dtype, atol, rtol):
needed_cap = op.TIMES | op.INVERSE_TIMES needed_cap = op.TIMES | op.INVERSE_TIMES
if (op.capability & needed_cap) != needed_cap: if (op.capability & needed_cap) != needed_cap:
return return
foo = from_random("normal", op.target, dtype=target_dtype) foo = from_random(op.target, "normal", dtype=target_dtype)
res = op(op.inverse_times(foo)) res = op(op.inverse_times(foo))
assert_allclose(res, foo, atol=atol, rtol=rtol) assert_allclose(res, foo, atol=atol, rtol=rtol)
foo = from_random("normal", op.domain, dtype=domain_dtype) foo = from_random(op.domain, "normal", dtype=domain_dtype)
res = op.inverse_times(op(foo)) res = op.inverse_times(op(foo))
assert_allclose(res, foo, atol=atol, rtol=rtol) assert_allclose(res, foo, atol=atol, rtol=rtol)
...@@ -75,8 +75,8 @@ def _check_linearity(op, domain_dtype, atol, rtol): ...@@ -75,8 +75,8 @@ def _check_linearity(op, domain_dtype, atol, rtol):
needed_cap = op.TIMES needed_cap = op.TIMES
if (op.capability & needed_cap) != needed_cap: if (op.capability & needed_cap) != needed_cap:
return return
fld1 = from_random("normal", op.domain, dtype=domain_dtype) fld1 = from_random(op.domain, "normal", dtype=domain_dtype)
fld2 = from_random("normal", op.domain, dtype=domain_dtype) fld2 = from_random(op.domain, "normal", dtype=domain_dtype)
alpha = np.random.random() # FIXME: this can break badly with MPI! alpha = np.random.random() # FIXME: this can break badly with MPI!
val1 = op(alpha*fld1+fld2) val1 = op(alpha*fld1+fld2)
val2 = alpha*op(fld1)+op(fld2) val2 = alpha*op(fld1)+op(fld2)
...@@ -88,7 +88,7 @@ def _actual_domain_check_linear(op, domain_dtype=None, inp=None): ...@@ -88,7 +88,7 @@ def _actual_domain_check_linear(op, domain_dtype=None, inp=None):
if (op.capability & needed_cap) != needed_cap: if (op.capability & needed_cap) != needed_cap:
return return
if domain_dtype is not None: if domain_dtype is not None:
inp = from_random("normal", op.domain, dtype=domain_dtype) inp = from_random(op.domain, "normal", dtype=domain_dtype)
elif inp is None: elif inp is None:
raise ValueError('Need to specify either dtype or inp') raise ValueError('Need to specify either dtype or inp')
assert_(inp.domain is op.domain) assert_(inp.domain is op.domain)
...@@ -219,7 +219,7 @@ def consistency_check(op, domain_dtype=np.float64, target_dtype=np.float64, ...@@ -219,7 +219,7 @@ def consistency_check(op, domain_dtype=np.float64, target_dtype=np.float64,
def _get_acceptable_location(op, loc, lin): def _get_acceptable_location(op, loc, lin):
if not np.isfinite(lin.val.s_sum()): if not np.isfinite(lin.val.s_sum()):
raise ValueError('Initial value must be finite') raise ValueError('Initial value must be finite')
dir = from_random("normal", loc.domain) dir = from_random(loc.domain, "normal")
dirder = lin.jac(dir) dirder = lin.jac(dir)
if dirder.norm() == 0: if dirder.norm() == 0:
dir = dir * (lin.val.norm()*1e-5) dir = dir * (lin.val.norm()*1e-5)
......
...@@ -124,7 +124,7 @@ class Field(Operator): ...@@ -124,7 +124,7 @@ class Field(Operator):
return Field(DomainTuple.make(new_domain), self._val) return Field(DomainTuple.make(new_domain), self._val)
@staticmethod @staticmethod
def from_random(random_type, domain, dtype=np.float64, **kwargs): def from_random(domain, random_type='normal', dtype=np.float64, **kwargs):
"""Draws a random field with the given parameters. """Draws a random field with the given parameters.
Parameters Parameters
...@@ -283,7 +283,7 @@ class Field(Operator): ...@@ -283,7 +283,7 @@ class Field(Operator):
raise TypeError("The multiplier must be an instance of " + raise TypeError("The multiplier must be an instance of " +
"the Field class") "the Field class")
from .operators.outer_product_operator import OuterProduct from .operators.outer_product_operator import OuterProduct
return OuterProduct(self, x.domain)(x) return OuterProduct(x.domain, self)(x)
def vdot(self, x, spaces=None): def vdot(self, x, spaces=None):
"""Computes the dot product of 'self' with x. """Computes the dot product of 'self' with x.
......
...@@ -524,7 +524,7 @@ class CorrelatedFieldMaker: ...@@ -524,7 +524,7 @@ class CorrelatedFieldMaker:
for kk, op in lst: for kk, op in lst:
sc = StatCalculator() sc = StatCalculator()
for _ in range(prior_info): for _ in range(prior_info):
sc.add(op(from_random('normal', op.domain))) sc.add(op(from_random(op.domain, 'normal')))
mean = sc.mean.val mean = sc.mean.val
stddev = sc.var.ptw("sqrt").val stddev = sc.var.ptw("sqrt").val
for m, s in zip(mean.flatten(), stddev.flatten()): for m, s in zip(mean.flatten(), stddev.flatten()):
...@@ -539,7 +539,7 @@ class CorrelatedFieldMaker: ...@@ -539,7 +539,7 @@ class CorrelatedFieldMaker:
scm = 1. scm = 1.
for a in self._a: for a in self._a:
op = a.fluctuation_amplitude*self._azm.ptw("reciprocal") op = a.fluctuation_amplitude*self._azm.ptw("reciprocal")
res = np.array([op(from_random('normal', op.domain)).val res = np.array([op(from_random(op.domain, 'normal')).val
for _ in range(nsamples)]) for _ in range(nsamples)])
scm *= res**2 + 1. scm *= res**2 + 1.
return fluctuations_slice_mean/np.mean(np.sqrt(scm)) return fluctuations_slice_mean/np.mean(np.sqrt(scm))
......
...@@ -11,7 +11,7 @@ ...@@ -11,7 +11,7 @@
# You should have received a copy of the GNU General Public License # You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>. # along with this program. If not, see <http://www.gnu.org/licenses/>.
# #
# Copyright(C) 2013-2019 Max-Planck-Society # Copyright(C) 2013-2020 Max-Planck-Society
# #
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik. # NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.
...@@ -207,12 +207,13 @@ class Linearization(Operator): ...@@ -207,12 +207,13 @@ class Linearization(Operator):
return self.__mul__(other) return self.__mul__(other)
from .operators.outer_product_operator import OuterProduct from .operators.outer_product_operator import OuterProduct
if other.jac is None: if other.jac is None:
return self.new(OuterProduct(self._val, other.domain)(other), return self.new(OuterProduct(other.domain, self._val)(other),
OuterProduct(self._jac(self._val), other.domain)) OuterProduct(other.domain, self._jac(self._val)))
tmp_op = OuterProduct(other.target, self._val)
return self.new( return self.new(
OuterProduct(self._val, other.target)(other._val), tmp_op(other._val),
OuterProduct(self._jac(self._val), other.target)._myadd( OuterProduct(other.target, self._jac(self._val))._myadd(
OuterProduct(self._val, other.target)(other._jac), False)) tmp_op(other._jac), False))
def vdot(self, other): def vdot(self, other):
"""Computes the inner product of this Linearization with a Field or """Computes the inner product of this Linearization with a Field or
...@@ -270,10 +271,8 @@ class Linearization(Operator): ...@@ -270,10 +271,8 @@ class Linearization(Operator):
Linearization Linearization
the (partial) integral the (partial) integral
""" """
from .operators.contraction_operator import ContractionOperator from .operators.contraction_operator import IntegrationOperator
return self.new( return IntegrationOperator(self._target, spaces)(self)
self._val.integrate(spaces),
ContractionOperator(self._jac.target, spaces, 1)(self._jac))
def ptw(self, op, *args, **kwargs): def ptw(self, op, *args, **kwargs):
t1, t2 = self._val.ptw_with_deriv(op, *args, **kwargs) t1, t2 = self._val.ptw_with_deriv(op, *args, **kwargs)
......
...@@ -166,7 +166,8 @@ class NewtonCG(DescentMinimizer): ...@@ -166,7 +166,8 @@ class NewtonCG(DescentMinimizer):
""" """
def __init__(self, controller, napprox=0, line_searcher=None, name=None, def __init__(self, controller, napprox=0, line_searcher=None, name=None,
nreset=20, max_cg_iterations=200, energy_reduction_factor=0.1): nreset=20, max_cg_iterations=200, energy_reduction_factor=0.1,
activate_logging=False):
if line_searcher is None: if line_searcher is None:
line_searcher = LineSearch(preferred_initial_step_size=1.) line_searcher = LineSearch(preferred_initial_step_size=1.)
super(NewtonCG, self).__init__(controller=controller, super(NewtonCG, self).__init__(controller=controller,
...@@ -176,6 +177,8 @@ class NewtonCG(DescentMinimizer): ...@@ -176,6 +177,8 @@ class NewtonCG(DescentMinimizer):
self._nreset = nreset self._nreset = nreset
self._max_cg_iterations = max_cg_iterations self._max_cg_iterations = max_cg_iterations
self._alpha = energy_reduction_factor self._alpha = energy_reduction_factor
from .iteration_controllers import EnergyHistory
self._history = EnergyHistory() if activate_logging else None
def get_descent_direction(self, energy, old_value=None): def get_descent_direction(self, energy, old_value=None):
if old_value is None: if old_value is None:
...@@ -184,14 +187,22 @@ class NewtonCG(DescentMinimizer): ...@@ -184,14 +187,22 @@ class NewtonCG(DescentMinimizer):
ediff = self._alpha*(old_value-energy.value) ediff = self._alpha*(old_value-energy.value)
ic = AbsDeltaEnergyController( ic = AbsDeltaEnergyController(
ediff, iteration_limit=self._max_cg_iterations, name=self._name) ediff, iteration_limit=self._max_cg_iterations, name=self._name)
if self._history is not None:
ic.enable_logging()
e = QuadraticEnergy(0*energy.position, energy.metric, energy.gradient) e = QuadraticEnergy(0*energy.position, energy.metric, energy.gradient)
p = None p = None
if self._napprox > 1: if self._napprox > 1:
met = energy.metric met = energy.metric
p = makeOp(approximation2endo(met, self._napprox)).inverse p = makeOp(approximation2endo(met, self._napprox)).inverse
e, conv = ConjugateGradient(ic, nreset=self._nreset)(e, p) e, conv = ConjugateGradient(ic, nreset=self._nreset)(e, p)
if self._history is not None:
self._history += ic.history
return -e.position return -e.position
@property
def inversion_history(self):
return self._history
class L_BFGS(DescentMinimizer): class L_BFGS(DescentMinimizer):
def __init__(self, controller, line_searcher=LineSearch(), def __init__(self, controller, line_searcher=LineSearch(),
......
...@@ -11,10 +11,13 @@ ...@@ -11,10 +11,13 @@
# You should have received a copy of the GNU General Public License # You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>. # along with this program. If not, see <http://www.gnu.org/licenses/>.
# #
# Copyright(C) 2013-2019 Max-Planck-Society # Copyright(C) 2013-2020 Max-Planck-Society
# #
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik. # NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.
import functools
from time import time
import numpy as np import numpy as np
from ..logger import logger from ..logger import logger
...@@ -37,10 +40,17 @@ class IterationController(metaclass=NiftyMeta): ...@@ -37,10 +40,17 @@ class IterationController(metaclass=NiftyMeta):
class; the implementer has full flexibility to use whichever criteria are class; the implementer has full flexibility to use whichever criteria are
appropriate for a particular problem - as long as they can be computed from appropriate for a particular problem - as long as they can be computed from
the information passed to the controller during the iteration process. the information passed to the controller during the iteration process.
For analyzing minimization procedures IterationControllers can log energy
values together with the respective time stamps. In order to activate this
feature `activate_logging()` needs to be called.
""" """
CONVERGED, CONTINUE, ERROR = list(range(3)) CONVERGED, CONTINUE, ERROR = list(range(3))
def __init__(self):
self._history = None
def start(self, energy): def start(self, energy):
"""Starts the iteration. """Starts the iteration.
...@@ -69,6 +79,68 @@ class IterationController(metaclass=NiftyMeta): ...@@ -69,6 +79,68 @@ class IterationController(metaclass=NiftyMeta):
""" """
raise NotImplementedError raise NotImplementedError
def enable_logging(self):
"""Enables the logging functionality. If the log has been populated
before, it stays as it is."""
if self._history is None:
self._history = EnergyHistory()
def disable_logging(self):
"""Disables the logging functionality. If the log has been populated
before, it is dropped."""
self._history = None
@property
def history(self):
return self._history
class EnergyHistory(object):
def __init__(self):
self._lst = []
def append(self, x):
if len(x) != 2:
raise ValueError
self._lst.append((float(x[0]), float(x[1])))
def reset(self):
self._lst = []
@property
def time_stamps(self):
return [x for x, _ in self._lst]
@property
def energy_values(self):
return [x for _, x in self._lst]
def __add__(self, other):
if not isinstance(other, EnergyHistory):
return NotImplemented
res = EnergyHistory()
res._lst = self._lst + other._lst
return res
def __iadd__(self, other):
if not isinstance(other, EnergyHistory):
return NotImplemented
self._lst += other._lst
return self
def __len__(self):
return len(self._lst)
def append_history(func):
@functools.wraps(func)
def wrapper(*args, **kwargs):
hist = args[0].history
if isinstance(hist, EnergyHistory):
hist.append((time(), args[1].value))
return func(*args, **kwargs)
return wrapper
class GradientNormController(IterationController): class GradientNormController(IterationController):
"""An iteration controller checking (mainly) the L2 gradient norm. """An iteration controller checking (mainly) the L2 gradient norm.
...@@ -94,12 +166,14 @@ class GradientNormController(IterationController): ...@@ -94,12 +166,14 @@ class GradientNormController(IterationController):
def __init__(self, tol_abs_gradnorm=None, tol_rel_gradnorm=None, def __init__(self, tol_abs_gradnorm=None, tol_rel_gradnorm=None,
convergence_level=1, iteration_limit=None, name=None): convergence_level=1, iteration_limit=None, name=None):
super(GradientNormController, self).__init__()
self._tol_abs_gradnorm = tol_abs_gradnorm self._tol_abs_gradnorm = tol_abs_gradnorm
self._tol_rel_gradnorm = tol_rel_gradnorm self._tol_rel_gradnorm = tol_rel_gradnorm
self._convergence_level = convergence_level self._convergence_level = convergence_level
self._iteration_limit = iteration_limit self._iteration_limit = iteration_limit
self._name = name self._name = name