Commit 5aa8a3b6 authored by Jakob Knollmueller's avatar Jakob Knollmueller
Browse files

Merge branch 'NIFTy_5' into demos

parents 88eaae59 cb66a789
from .operator_tests import consistency_check
from .energy_tests import *
from .energy_and_model_tests import *
......@@ -18,19 +18,42 @@
import numpy as np
from ..sugar import from_random
from ..minimization.energy import Energy
from ..models.model import Model
__all__ = ["check_value_gradient_consistency",
"check_value_gradient_curvature_consistency"]
def _get_acceptable_model(M):
val = M.value
if not np.isfinite(val.sum()):
raise ValueError('Initial Model value must be finite')
dir = from_random("normal", M.position.domain)
dirder = M.gradient(dir)
dir *= val/(dirder).norm()*1e-5
# Find a step length that leads to a "reasonable" Model
for i in range(50):
try:
M2 = M.at(M.position+dir)
if np.isfinite(M2.value.sum()) and abs(M2.value.sum()) < 1e20:
break
except FloatingPointError:
pass
dir *= 0.5
else:
raise ValueError("could not find a reasonable initial step")
return M2
def _get_acceptable_energy(E):
val = E.value
if not np.isfinite(val):
raise ValueError
raise ValueError('Initial Energy must be finite')
dir = from_random("normal", E.position.domain)
dirder = E.gradient.vdot(dir)
dir *= np.abs(val)/np.abs(dirder)*1e-5
# find a step length that leads to a "reasonable" energy
# Find a step length that leads to a "reasonable" energy
for i in range(50):
try:
E2 = E.at(E.position+dir)
......@@ -46,31 +69,45 @@ def _get_acceptable_energy(E):
def check_value_gradient_consistency(E, tol=1e-8, ntries=100):
for _ in range(ntries):
E2 = _get_acceptable_energy(E)
if isinstance(E, Energy):
E2 = _get_acceptable_energy(E)
else:
E2 = _get_acceptable_model(E)
val = E.value
dir = E2.position - E.position
# Enext = E2
Enext = E2
dirnorm = dir.norm()
for i in range(50):
Emid = E.at(E.position + 0.5*dir)
dirder = Emid.gradient.vdot(dir)/dirnorm
xtol = tol*Emid.gradient_norm
if abs((E2.value-val)/dirnorm - dirder) < xtol:
break
if isinstance(E, Energy):
dirder = Emid.gradient.vdot(dir)/dirnorm
else:
dirder = Emid.gradient(dir)/dirnorm
numgrad = (E2.value-val)/dirnorm
if isinstance(E, Model):
xtol = tol * dirder.norm() / np.sqrt(dirder.size)
if (abs(numgrad-dirder) < xtol).all():
break
else:
xtol = tol*Emid.gradient_norm
if abs(numgrad-dirder) < xtol:
break
dir *= 0.5
dirnorm *= 0.5
E2 = Emid
else:
raise ValueError("gradient and value seem inconsistent")
# E = Enext
E = Enext
def check_value_gradient_curvature_consistency(E, tol=1e-8, ntries=100):
if isinstance(E, Model):
raise ValueError('Models have no curvature, thus it cannot be tested.')
for _ in range(ntries):
E2 = _get_acceptable_energy(E)
val = E.value
dir = E2.position - E.position
# Enext = E2
Enext = E2
dirnorm = dir.norm()
for i in range(50):
Emid = E.at(E.position + 0.5*dir)
......@@ -85,4 +122,4 @@ def check_value_gradient_curvature_consistency(E, tol=1e-8, ntries=100):
E2 = Emid
else:
raise ValueError("gradient, value and curvature seem inconsistent")
# E = Enext
E = Enext
......@@ -16,6 +16,7 @@
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes.
import numpy as np
from ..field import Field
from ..multi import MultiField
from ..utilities import NiftyMetaBase, memo
......@@ -128,30 +129,27 @@ class Energy(NiftyMetaBase()):
"""
return None
def __mul__(self, factor):
from .energy_sum import EnergySum
if isinstance(factor, (float, int)):
return EnergySum.make([self], [factor])
return NotImplemented
def __rmul__(self, factor):
return self.__mul__(factor)
def __add__(self, other):
if not isinstance(other, Energy):
raise TypeError
return Add(self, other)
from .energy_sum import EnergySum
if isinstance(other, Energy):
return EnergySum.make([self, other])
return NotImplemented
def __sub__(self, other):
if not isinstance(other, Energy):
raise TypeError
return Add(self, (-1) * other)
def Add(energy1, energy2):
if (isinstance(energy1.position, MultiField) and
isinstance(energy2.position, MultiField)):
a = energy1.position._val
b = energy2.position._val
# Note: In python >3.5 one could do {**a, **b}
ab = a.copy()
ab.update(b)
position = MultiField(ab)
elif (isinstance(energy1.position, Field) and
isinstance(energy2.position, Field)):
position = energy1.position
else:
raise TypeError
from .energy_sum import EnergySum
return EnergySum(position, [energy1, energy2])
from .energy_sum import EnergySum
if isinstance(other, Energy):
return EnergySum.make([self, other], [1., -1.])
return NotImplemented
def __neg__(self):
from .energy_sum import EnergySum
return EnergySum.make([self], [-1.])
......@@ -16,49 +16,65 @@
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes.
from .energy import Energy
from ..utilities import memo
from .energy import Energy
class EnergySum(Energy):
def __init__(self, position, energies, minimizer_controller=None,
preconditioner=None, precon_idx=None):
def __init__(self, position, energies, factors):
super(EnergySum, self).__init__(position=position)
self._energies = [energy.at(position) for energy in energies]
self._min_controller = minimizer_controller
self._preconditioner = preconditioner
self._precon_idx = precon_idx
self._energies = tuple(e.at(position) for e in energies)
self._factors = tuple(factors)
@staticmethod
def make(energies, factors=None):
if factors is None:
factors = (1,)*len(energies)
# unpack energies
eout = []
fout = []
EnergySum._unpackEnergies(energies, factors, 1., eout, fout)
for e in eout[1:]:
if not e.position.isEquivalentTo(eout[0].position):
raise ValueError("position mismatch")
return EnergySum(eout[0].position, eout, fout)
@staticmethod
def _unpackEnergies(e_in, f_in, prefactor, e_out, f_out):
for e, f in zip(e_in, f_in):
if isinstance(e, EnergySum):
EnergySum._unpackEnergies(e._energies, e._factors, prefactor*f,
e_out, f_out)
else:
e_out.append(e)
f_out.append(prefactor*f)
def at(self, position):
return self.__class__(position, self._energies, self._min_controller,
self._preconditioner, self._precon_idx)
return self.__class__(position, self._energies, self._factors)
@property
@memo
def value(self):
res = self._energies[0].value
for e in self._energies[1:]:
res += e.value
res = self._energies[0].value * self._factors[0]
for e, f in zip(self._energies[1:], self._factors[1:]):
res += e.value * f
return res
@property
@memo
def gradient(self):
res = self._energies[0].gradient.copy()
for e in self._energies[1:]:
res += e.gradient
res = self._energies[0].gradient.copy() if self._factors[0] == 1. \
else self._energies[0].gradient * self._factors[0]
for e, f in zip(self._energies[1:], self._factors[1:]):
res += e.gradient if f == 1. else f*e.gradient
return res.lock()
@property
@memo
def curvature(self):
res = self._energies[0].curvature
for e in self._energies[1:]:
res = res + e.curvature
if self._min_controller is None:
return res
precon = self._preconditioner
if precon is None and self._precon_idx is not None:
precon = self._energies[self._precon_idx].curvature
from ..operators.inversion_enabler import InversionEnabler
return InversionEnabler(res, self._min_controller, precon)
res = self._energies[0].curvature if self._factors[0] == 1. \
else self._energies[0].curvature * self._factors[0]
for e, f in zip(self._energies[1:], self._factors[1:]):
res = res + (e.curvature if f == 1. else e.curvature*f)
return res
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2018 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes.
import unittest
from itertools import product
from test.common import expand
import nifty5 as ift
import numpy as np
from numpy.testing import assert_allclose, assert_raises
class EnergySum_Tests(unittest.TestCase):
@expand(product([ift.GLSpace(15),
ift.RGSpace(64, distances=.789),
ift.RGSpace([32, 32], distances=.789),
[ift.RGSpace(3), ift.RGSpace(12)],
{'a': ift.GLSpace(15), 'b': ift.RGSpace(64)}],
[4, 78, 23]))
def testSum(self, domain, seed):
np.random.seed(seed)
pos = ift.from_random("normal", domain)
A1 = ift.makeOp(ift.from_random("normal", domain))
b1 = ift.from_random("normal", domain)
E1 = ift.QuadraticEnergy(pos, A1, b1)
A2 = ift.makeOp(ift.from_random("normal", domain))
b2 = ift.from_random("normal", domain)
E2 = ift.QuadraticEnergy(pos, A2, b2)
assert_allclose((E1+E2).value, E1.value+E2.value)
assert_allclose((E1-E2).value, E1.value-E2.value)
assert_allclose((3.8*E1+E2*4).value, 3.8*E1.value+4*E2.value)
assert_allclose((-E1).value, -(E1.value))
with assert_raises(TypeError):
E1*E2
with assert_raises(TypeError):
E1*2j
with assert_raises(TypeError):
E1+2
with assert_raises(TypeError):
E1-"hello"
with assert_raises(ValueError):
E1+E2.at(2*pos)
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2018 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes.
import unittest
from itertools import product
from test.common import expand
import nifty5 as ift
import numpy as np
class Model_Tests(unittest.TestCase):
@expand(product([ift.GLSpace(15),
ift.RGSpace(64, distances=.789),
ift.RGSpace([32, 32], distances=.789)],
[4, 78, 23]))
def testMul(self, space, seed):
np.random.seed(seed)
S = ift.ScalingOperator(1., space)
s1 = S.draw_sample()
s2 = S.draw_sample()
s1_var = ift.Variable(ift.MultiField({'s1': s1}))['s1']
s2_var = ift.Variable(ift.MultiField({'s2': s2}))['s2']
ift.extra.check_value_gradient_consistency(s1_var*s2_var)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment