Commit 5aa8a3b6 authored by Jakob Knollmueller's avatar Jakob Knollmueller
Browse files

Merge branch 'NIFTy_5' into demos

parents 88eaae59 cb66a789
from .operator_tests import consistency_check from .operator_tests import consistency_check
from .energy_tests import * from .energy_and_model_tests import *
...@@ -18,19 +18,42 @@ ...@@ -18,19 +18,42 @@
import numpy as np import numpy as np
from ..sugar import from_random from ..sugar import from_random
from ..minimization.energy import Energy
from ..models.model import Model
__all__ = ["check_value_gradient_consistency", __all__ = ["check_value_gradient_consistency",
"check_value_gradient_curvature_consistency"] "check_value_gradient_curvature_consistency"]
def _get_acceptable_model(M):
val = M.value
if not np.isfinite(val.sum()):
raise ValueError('Initial Model value must be finite')
dir = from_random("normal", M.position.domain)
dirder = M.gradient(dir)
dir *= val/(dirder).norm()*1e-5
# Find a step length that leads to a "reasonable" Model
for i in range(50):
try:
M2 = M.at(M.position+dir)
if np.isfinite(M2.value.sum()) and abs(M2.value.sum()) < 1e20:
break
except FloatingPointError:
pass
dir *= 0.5
else:
raise ValueError("could not find a reasonable initial step")
return M2
def _get_acceptable_energy(E): def _get_acceptable_energy(E):
val = E.value val = E.value
if not np.isfinite(val): if not np.isfinite(val):
raise ValueError raise ValueError('Initial Energy must be finite')
dir = from_random("normal", E.position.domain) dir = from_random("normal", E.position.domain)
dirder = E.gradient.vdot(dir) dirder = E.gradient.vdot(dir)
dir *= np.abs(val)/np.abs(dirder)*1e-5 dir *= np.abs(val)/np.abs(dirder)*1e-5
# find a step length that leads to a "reasonable" energy # Find a step length that leads to a "reasonable" energy
for i in range(50): for i in range(50):
try: try:
E2 = E.at(E.position+dir) E2 = E.at(E.position+dir)
...@@ -46,31 +69,45 @@ def _get_acceptable_energy(E): ...@@ -46,31 +69,45 @@ def _get_acceptable_energy(E):
def check_value_gradient_consistency(E, tol=1e-8, ntries=100): def check_value_gradient_consistency(E, tol=1e-8, ntries=100):
for _ in range(ntries): for _ in range(ntries):
if isinstance(E, Energy):
E2 = _get_acceptable_energy(E) E2 = _get_acceptable_energy(E)
else:
E2 = _get_acceptable_model(E)
val = E.value val = E.value
dir = E2.position - E.position dir = E2.position - E.position
# Enext = E2 Enext = E2
dirnorm = dir.norm() dirnorm = dir.norm()
for i in range(50): for i in range(50):
Emid = E.at(E.position + 0.5*dir) Emid = E.at(E.position + 0.5*dir)
if isinstance(E, Energy):
dirder = Emid.gradient.vdot(dir)/dirnorm dirder = Emid.gradient.vdot(dir)/dirnorm
else:
dirder = Emid.gradient(dir)/dirnorm
numgrad = (E2.value-val)/dirnorm
if isinstance(E, Model):
xtol = tol * dirder.norm() / np.sqrt(dirder.size)
if (abs(numgrad-dirder) < xtol).all():
break
else:
xtol = tol*Emid.gradient_norm xtol = tol*Emid.gradient_norm
if abs((E2.value-val)/dirnorm - dirder) < xtol: if abs(numgrad-dirder) < xtol:
break break
dir *= 0.5 dir *= 0.5
dirnorm *= 0.5 dirnorm *= 0.5
E2 = Emid E2 = Emid
else: else:
raise ValueError("gradient and value seem inconsistent") raise ValueError("gradient and value seem inconsistent")
# E = Enext E = Enext
def check_value_gradient_curvature_consistency(E, tol=1e-8, ntries=100): def check_value_gradient_curvature_consistency(E, tol=1e-8, ntries=100):
if isinstance(E, Model):
raise ValueError('Models have no curvature, thus it cannot be tested.')
for _ in range(ntries): for _ in range(ntries):
E2 = _get_acceptable_energy(E) E2 = _get_acceptable_energy(E)
val = E.value val = E.value
dir = E2.position - E.position dir = E2.position - E.position
# Enext = E2 Enext = E2
dirnorm = dir.norm() dirnorm = dir.norm()
for i in range(50): for i in range(50):
Emid = E.at(E.position + 0.5*dir) Emid = E.at(E.position + 0.5*dir)
...@@ -85,4 +122,4 @@ def check_value_gradient_curvature_consistency(E, tol=1e-8, ntries=100): ...@@ -85,4 +122,4 @@ def check_value_gradient_curvature_consistency(E, tol=1e-8, ntries=100):
E2 = Emid E2 = Emid
else: else:
raise ValueError("gradient, value and curvature seem inconsistent") raise ValueError("gradient, value and curvature seem inconsistent")
# E = Enext E = Enext
...@@ -16,6 +16,7 @@ ...@@ -16,6 +16,7 @@
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik # NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes. # and financially supported by the Studienstiftung des deutschen Volkes.
import numpy as np
from ..field import Field from ..field import Field
from ..multi import MultiField from ..multi import MultiField
from ..utilities import NiftyMetaBase, memo from ..utilities import NiftyMetaBase, memo
...@@ -128,30 +129,27 @@ class Energy(NiftyMetaBase()): ...@@ -128,30 +129,27 @@ class Energy(NiftyMetaBase()):
""" """
return None return None
def __mul__(self, factor):
from .energy_sum import EnergySum
if isinstance(factor, (float, int)):
return EnergySum.make([self], [factor])
return NotImplemented
def __rmul__(self, factor):
return self.__mul__(factor)
def __add__(self, other): def __add__(self, other):
if not isinstance(other, Energy): from .energy_sum import EnergySum
raise TypeError if isinstance(other, Energy):
return Add(self, other) return EnergySum.make([self, other])
return NotImplemented
def __sub__(self, other): def __sub__(self, other):
if not isinstance(other, Energy):
raise TypeError
return Add(self, (-1) * other)
def Add(energy1, energy2):
if (isinstance(energy1.position, MultiField) and
isinstance(energy2.position, MultiField)):
a = energy1.position._val
b = energy2.position._val
# Note: In python >3.5 one could do {**a, **b}
ab = a.copy()
ab.update(b)
position = MultiField(ab)
elif (isinstance(energy1.position, Field) and
isinstance(energy2.position, Field)):
position = energy1.position
else:
raise TypeError
from .energy_sum import EnergySum from .energy_sum import EnergySum
return EnergySum(position, [energy1, energy2]) if isinstance(other, Energy):
return EnergySum.make([self, other], [1., -1.])
return NotImplemented
def __neg__(self):
from .energy_sum import EnergySum
return EnergySum.make([self], [-1.])
...@@ -16,49 +16,65 @@ ...@@ -16,49 +16,65 @@
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik # NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes. # and financially supported by the Studienstiftung des deutschen Volkes.
from .energy import Energy
from ..utilities import memo from ..utilities import memo
from .energy import Energy
class EnergySum(Energy): class EnergySum(Energy):
def __init__(self, position, energies, minimizer_controller=None, def __init__(self, position, energies, factors):
preconditioner=None, precon_idx=None):
super(EnergySum, self).__init__(position=position) super(EnergySum, self).__init__(position=position)
self._energies = [energy.at(position) for energy in energies] self._energies = tuple(e.at(position) for e in energies)
self._min_controller = minimizer_controller self._factors = tuple(factors)
self._preconditioner = preconditioner
self._precon_idx = precon_idx @staticmethod
def make(energies, factors=None):
if factors is None:
factors = (1,)*len(energies)
# unpack energies
eout = []
fout = []
EnergySum._unpackEnergies(energies, factors, 1., eout, fout)
for e in eout[1:]:
if not e.position.isEquivalentTo(eout[0].position):
raise ValueError("position mismatch")
return EnergySum(eout[0].position, eout, fout)
@staticmethod
def _unpackEnergies(e_in, f_in, prefactor, e_out, f_out):
for e, f in zip(e_in, f_in):
if isinstance(e, EnergySum):
EnergySum._unpackEnergies(e._energies, e._factors, prefactor*f,
e_out, f_out)
else:
e_out.append(e)
f_out.append(prefactor*f)
def at(self, position): def at(self, position):
return self.__class__(position, self._energies, self._min_controller, return self.__class__(position, self._energies, self._factors)
self._preconditioner, self._precon_idx)
@property @property
@memo @memo
def value(self): def value(self):
res = self._energies[0].value res = self._energies[0].value * self._factors[0]
for e in self._energies[1:]: for e, f in zip(self._energies[1:], self._factors[1:]):
res += e.value res += e.value * f
return res return res
@property @property
@memo @memo
def gradient(self): def gradient(self):
res = self._energies[0].gradient.copy() res = self._energies[0].gradient.copy() if self._factors[0] == 1. \
for e in self._energies[1:]: else self._energies[0].gradient * self._factors[0]
res += e.gradient
for e, f in zip(self._energies[1:], self._factors[1:]):
res += e.gradient if f == 1. else f*e.gradient
return res.lock() return res.lock()
@property @property
@memo @memo
def curvature(self): def curvature(self):
res = self._energies[0].curvature res = self._energies[0].curvature if self._factors[0] == 1. \
for e in self._energies[1:]: else self._energies[0].curvature * self._factors[0]
res = res + e.curvature for e, f in zip(self._energies[1:], self._factors[1:]):
if self._min_controller is None: res = res + (e.curvature if f == 1. else e.curvature*f)
return res return res
precon = self._preconditioner
if precon is None and self._precon_idx is not None:
precon = self._energies[self._precon_idx].curvature
from ..operators.inversion_enabler import InversionEnabler
return InversionEnabler(res, self._min_controller, precon)
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2018 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes.
import unittest
from itertools import product
from test.common import expand
import nifty5 as ift
import numpy as np
from numpy.testing import assert_allclose, assert_raises
class EnergySum_Tests(unittest.TestCase):
@expand(product([ift.GLSpace(15),
ift.RGSpace(64, distances=.789),
ift.RGSpace([32, 32], distances=.789),
[ift.RGSpace(3), ift.RGSpace(12)],
{'a': ift.GLSpace(15), 'b': ift.RGSpace(64)}],
[4, 78, 23]))
def testSum(self, domain, seed):
np.random.seed(seed)
pos = ift.from_random("normal", domain)
A1 = ift.makeOp(ift.from_random("normal", domain))
b1 = ift.from_random("normal", domain)
E1 = ift.QuadraticEnergy(pos, A1, b1)
A2 = ift.makeOp(ift.from_random("normal", domain))
b2 = ift.from_random("normal", domain)
E2 = ift.QuadraticEnergy(pos, A2, b2)
assert_allclose((E1+E2).value, E1.value+E2.value)
assert_allclose((E1-E2).value, E1.value-E2.value)
assert_allclose((3.8*E1+E2*4).value, 3.8*E1.value+4*E2.value)
assert_allclose((-E1).value, -(E1.value))
with assert_raises(TypeError):
E1*E2
with assert_raises(TypeError):
E1*2j
with assert_raises(TypeError):
E1+2
with assert_raises(TypeError):
E1-"hello"
with assert_raises(ValueError):
E1+E2.at(2*pos)
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2018 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes.
import unittest
from itertools import product
from test.common import expand
import nifty5 as ift
import numpy as np
class Model_Tests(unittest.TestCase):
@expand(product([ift.GLSpace(15),
ift.RGSpace(64, distances=.789),
ift.RGSpace([32, 32], distances=.789)],
[4, 78, 23]))
def testMul(self, space, seed):
np.random.seed(seed)
S = ift.ScalingOperator(1., space)
s1 = S.draw_sample()
s2 = S.draw_sample()
s1_var = ift.Variable(ift.MultiField({'s1': s1}))['s1']
s2_var = ift.Variable(ift.MultiField({'s2': s2}))['s2']
ift.extra.check_value_gradient_consistency(s1_var*s2_var)
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment