Commit 78b77518 authored by Martin Reinecke's avatar Martin Reinecke

merge NIFTY_4

parents 0e8e4be1 82c170d1
Pipeline #30101 passed with stages
in 1 minute and 25 seconds
...@@ -687,7 +687,7 @@ ...@@ -687,7 +687,7 @@
}, },
"outputs": [], "outputs": [],
"source": [ "source": [
"precise = (np.abs(s_data-m_data) < uncertainty )\n", "precise = (np.abs(s_data-m_data) < uncertainty)\n",
"print(\"Error within uncertainty map bounds: \" + str(np.sum(precise) * 100 / N_pixels**2) + \"%\")\n", "print(\"Error within uncertainty map bounds: \" + str(np.sum(precise) * 100 / N_pixels**2) + \"%\")\n",
"\n", "\n",
"plt.figure(figsize=(15,10))\n", "plt.figure(figsize=(15,10))\n",
......
...@@ -10,8 +10,8 @@ if __name__ == "__main__": ...@@ -10,8 +10,8 @@ if __name__ == "__main__":
p_spec = lambda k: (1. / (k*correlation_length + 1) ** 4) p_spec = lambda k: (1. / (k*correlation_length + 1) ** 4)
nonlinearity = Tanh() nonlinearity = Tanh()
#nonlinearity = Linear() # nonlinearity = Linear()
#nonlinearity = Exponential() # nonlinearity = Exponential()
# Set up position space # Set up position space
s_space = ift.RGSpace(1024) s_space = ift.RGSpace(1024)
......
from .version import __version__ from .version import __version__
from . import dobj from . import dobj
from .domains import * from .domains import *
from .domain_tuple import DomainTuple from .domain_tuple import DomainTuple
from .operators import *
from .field import Field from .field import Field
from .operators import *
from .probing.utils import probe_with_posterior_samples, probe_diagonal, \ from .probing.utils import probe_with_posterior_samples, probe_diagonal, \
StatCalculator StatCalculator
from .minimization import * from .minimization import *
from .sugar import * from .sugar import *
...@@ -28,5 +22,5 @@ from .multi import * ...@@ -28,5 +22,5 @@ from .multi import *
__all__ = ["__version__", "dobj", "DomainTuple"] + \ __all__ = ["__version__", "dobj", "DomainTuple"] + \
domains.__all__ + operators.__all__ + minimization.__all__ + \ domains.__all__ + operators.__all__ + minimization.__all__ + \
["DomainTuple", "Field", "sqrt", "exp", "log"] + \ ["Field"] + sugar.__all__ + \
multi.__all__ multi.__all__
...@@ -17,7 +17,6 @@ ...@@ -17,7 +17,6 @@
# and financially supported by the Studienstiftung des deutschen Volkes. # and financially supported by the Studienstiftung des deutschen Volkes.
import numpy as np import numpy as np
from ..field import Field
from ..sugar import from_random from ..sugar import from_random
__all__ = ["check_value_gradient_consistency", __all__ = ["check_value_gradient_consistency",
...@@ -50,15 +49,16 @@ def check_value_gradient_consistency(E, tol=1e-6, ntries=100): ...@@ -50,15 +49,16 @@ def check_value_gradient_consistency(E, tol=1e-6, ntries=100):
E2 = _get_acceptable_energy(E) E2 = _get_acceptable_energy(E)
val = E.value val = E.value
dir = E2.position - E.position dir = E2.position - E.position
Enext = E2 # Enext = E2
dirnorm = dir.norm() dirnorm = dir.norm()
dirder = E.gradient.vdot(dir)/dirnorm
for i in range(50): for i in range(50):
Emid = E.at(E.position + 0.5*dir)
dirder = Emid.gradient.vdot(dir)/dirnorm
if abs((E2.value-val)/dirnorm-dirder) < tol: if abs((E2.value-val)/dirnorm-dirder) < tol:
break break
dir *= 0.5 dir *= 0.5
dirnorm *= 0.5 dirnorm *= 0.5
E2 = E2.at(E.position+dir) E2 = Emid
else: else:
raise ValueError("gradient and value seem inconsistent") raise ValueError("gradient and value seem inconsistent")
# E = Enext # E = Enext
...@@ -67,19 +67,20 @@ def check_value_gradient_consistency(E, tol=1e-6, ntries=100): ...@@ -67,19 +67,20 @@ def check_value_gradient_consistency(E, tol=1e-6, ntries=100):
def check_value_gradient_curvature_consistency(E, tol=1e-6, ntries=100): def check_value_gradient_curvature_consistency(E, tol=1e-6, ntries=100):
for _ in range(ntries): for _ in range(ntries):
E2 = _get_acceptable_energy(E) E2 = _get_acceptable_energy(E)
val = E.value
dir = E2.position - E.position dir = E2.position - E.position
Enext = E2 # Enext = E2
dirnorm = dir.norm() dirnorm = dir.norm()
dirder = E.gradient.vdot(dir)/dirnorm
dgrad = E.curvature(dir)/dirnorm
for i in range(50): for i in range(50):
gdiff = E2.gradient - E.gradient Emid = E.at(E.position + 0.5*dir)
if abs((E2.value-E.value)/dirnorm-dirder) < tol and \ dirder = Emid.gradient.vdot(dir)/dirnorm
dgrad = Emid.curvature(dir)/dirnorm
if abs((E2.value-val)/dirnorm-dirder) < tol and \
(abs((E2.gradient-E.gradient)/dirnorm-dgrad) < tol).all(): (abs((E2.gradient-E.gradient)/dirnorm-dgrad) < tol).all():
break break
dir *= 0.5 dir *= 0.5
dirnorm *= 0.5 dirnorm *= 0.5
E2 = E2.at(E.position+dir) E2 = Emid
else: else:
raise ValueError("gradient, value and curvature seem inconsistent") raise ValueError("gradient, value and curvature seem inconsistent")
# E = Enext # E = Enext
...@@ -721,6 +721,15 @@ class Field(object): ...@@ -721,6 +721,15 @@ class Field(object):
self._domain.__str__() + \ self._domain.__str__() + \
"\n- val = " + repr(self.val) "\n- val = " + repr(self.val)
def equivalent(self, other):
if self is other:
return True
if not isinstance(other, Field):
return False
if self._domain != other._domain:
return False
return (self._val == other._val).all()
for op in ["__add__", "__radd__", "__iadd__", for op in ["__add__", "__radd__", "__iadd__",
"__sub__", "__rsub__", "__isub__", "__sub__", "__rsub__", "__isub__",
"__mul__", "__rmul__", "__imul__", "__mul__", "__rmul__", "__imul__",
......
...@@ -15,10 +15,11 @@ from .energy import Energy ...@@ -15,10 +15,11 @@ from .energy import Energy
from .quadratic_energy import QuadraticEnergy from .quadratic_energy import QuadraticEnergy
from .line_energy import LineEnergy from .line_energy import LineEnergy
from .yango import Yango from .yango import Yango
from .energy_sum import EnergySum
__all__ = ["LineSearch", "LineSearchStrongWolfe", __all__ = ["LineSearch", "LineSearchStrongWolfe",
"IterationController", "GradientNormController", "IterationController", "GradientNormController",
"Minimizer", "ConjugateGradient", "NonlinearCG", "DescentMinimizer", "Minimizer", "ConjugateGradient", "NonlinearCG", "DescentMinimizer",
"SteepestDescent", "VL_BFGS", "RelaxedNewton", "ScipyMinimizer", "SteepestDescent", "VL_BFGS", "RelaxedNewton", "ScipyMinimizer",
"NewtonCG", "L_BFGS_B", "ScipyCG", "Energy", "QuadraticEnergy", "NewtonCG", "L_BFGS_B", "ScipyCG", "Energy", "QuadraticEnergy",
"LineEnergy", "L_BFGS", "Yango"] "LineEnergy", "L_BFGS", "EnergySum", "Yango"]
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2018 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes.
from .energy import Energy
from ..utilities import memo
class EnergySum(Energy):
def __init__(self, position, energies, minimizer_controller=None,
preconditioner=None, precon_idx=None):
super(EnergySum, self).__init__(position=position)
self._energies = [energy.at(position) for energy in energies]
self._min_controller = minimizer_controller
self._preconditioner = preconditioner
self._precon_idx = precon_idx
def at(self, position):
return self.__class__(position, self._energies, self._min_controller,
self._preconditioner, self._precon_idx)
@property
@memo
def value(self):
res = self._energies[0].value
for e in self._energies[1:]:
res += e.value
return res
@property
@memo
def gradient(self):
res = self._energies[0].gradient.copy()
for e in self._energies[1:]:
res += e.gradient
return res.lock()
@property
@memo
def curvature(self):
res = self._energies[0].curvature
for e in self._energies[1:]:
res = res + e.curvature
if self._min_controller is None:
return res
precon = self._preconditioner
if precon is None and self._precon_idx is not None:
precon = self._energies[self._precon_idx].curvature
from ..operators.inversion_enabler import InversionEnabler
from .conjugate_gradient import ConjugateGradient
return InversionEnabler(
res, ConjugateGradient(self._min_controller), precon)
...@@ -97,5 +97,5 @@ class LineEnergy(object): ...@@ -97,5 +97,5 @@ class LineEnergy(object):
if abs(res.imag) / max(abs(res.real), 1.) > 1e-12: if abs(res.imag) / max(abs(res.real), 1.) > 1e-12:
from ..logger import logger from ..logger import logger
logger.warning("directional derivative has non-negligible " logger.warning("directional derivative has non-negligible "
"imaginary part:", res) "imaginary part: {}".format(res))
return res.real return res.real
...@@ -98,7 +98,7 @@ class ScipyMinimizer(Minimizer): ...@@ -98,7 +98,7 @@ class ScipyMinimizer(Minimizer):
r = opt.minimize(hlp.fun, x, method=self._method, jac=hlp.jac, r = opt.minimize(hlp.fun, x, method=self._method, jac=hlp.jac,
hessp=hessp, options=self._options, bounds=bounds) hessp=hessp, options=self._options, bounds=bounds)
if not r.success: if not r.success:
logger.error("Problem in Scipy minimization:", r.message) logger.error("Problem in Scipy minimization: {}".format(r.message))
return hlp._energy, IterationController.ERROR return hlp._energy, IterationController.ERROR
return hlp._energy, IterationController.CONVERGED return hlp._energy, IterationController.CONVERGED
......
...@@ -88,7 +88,7 @@ class Yango(Minimizer): ...@@ -88,7 +88,7 @@ class Yango(Minimizer):
return energy, controller.ERROR return energy, controller.ERROR
# Try 1D Newton Step # Try 1D Newton Step
energy, success = self._line_searcher.perform_line_search( energy, success = self._line_searcher.perform_line_search(
energy, (rr/rAr)*r, f_k_minus_1) energy, (rr/rAr)*r, f_k_minus_1)
else: else:
a = (rAr*rp - rAp*rr)/det a = (rAr*rp - rAp*rr)/det
b = (pAp*rr - pAr*rp)/det b = (pAp*rr - pAr*rp)/det
......
...@@ -29,6 +29,8 @@ class MultiField(object): ...@@ -29,6 +29,8 @@ class MultiField(object):
val : dict val : dict
""" """
self._val = val self._val = val
self._domain = MultiDomain.make(
{key: val.domain for key, val in self._val.items()})
def __getitem__(self, key): def __getitem__(self, key):
return self._val[key] return self._val[key]
...@@ -44,8 +46,7 @@ class MultiField(object): ...@@ -44,8 +46,7 @@ class MultiField(object):
@property @property
def domain(self): def domain(self):
return MultiDomain.make( return self._domain
{key: val.domain for key, val in self._val.items()})
@property @property
def dtype(self): def dtype(self):
...@@ -71,7 +72,7 @@ class MultiField(object): ...@@ -71,7 +72,7 @@ class MultiField(object):
return self return self
def _check_domain(self, other): def _check_domain(self, other):
if other.domain != self.domain: if other._domain != self._domain:
raise ValueError("domains are incompatible.") raise ValueError("domains are incompatible.")
def vdot(self, x): def vdot(self, x):
...@@ -147,6 +148,17 @@ class MultiField(object): ...@@ -147,6 +148,17 @@ class MultiField(object):
return MultiField({key: sub_field.conjugate() return MultiField({key: sub_field.conjugate()
for key, sub_field in self.items()}) for key, sub_field in self.items()})
def equivalent(self, other):
if self is other:
return True
if not isinstance(other, MultiField):
return False
if self._domain != other._domain:
return False
for key, val in self._val.items():
if not val.equivalent(other[key]):
return False
return True
for op in ["__add__", "__radd__", "__iadd__", for op in ["__add__", "__radd__", "__iadd__",
"__sub__", "__rsub__", "__isub__", "__sub__", "__rsub__", "__isub__",
......
...@@ -130,16 +130,22 @@ class LinearOperator(NiftyMetaBase()): ...@@ -130,16 +130,22 @@ class LinearOperator(NiftyMetaBase()):
def __mul__(self, other): def __mul__(self, other):
from .chain_operator import ChainOperator from .chain_operator import ChainOperator
if np.isscalar(other) and other == 1.:
return self
other = self._toOperator(other, self.domain) other = self._toOperator(other, self.domain)
return ChainOperator.make([self, other]) return ChainOperator.make([self, other])
def __rmul__(self, other): def __rmul__(self, other):
from .chain_operator import ChainOperator from .chain_operator import ChainOperator
if np.isscalar(other) and other == 1.:
return self
other = self._toOperator(other, self.target) other = self._toOperator(other, self.target)
return ChainOperator.make([other, self]) return ChainOperator.make([other, self])
def __add__(self, other): def __add__(self, other):
from .sum_operator import SumOperator from .sum_operator import SumOperator
if np.isscalar(other) and other == 0.:
return self
other = self._toOperator(other, self.domain) other = self._toOperator(other, self.domain)
return SumOperator.make([self, other], [False, False]) return SumOperator.make([self, other], [False, False])
...@@ -148,6 +154,8 @@ class LinearOperator(NiftyMetaBase()): ...@@ -148,6 +154,8 @@ class LinearOperator(NiftyMetaBase()):
def __sub__(self, other): def __sub__(self, other):
from .sum_operator import SumOperator from .sum_operator import SumOperator
if np.isscalar(other) and other == 0.:
return self
other = self._toOperator(other, self.domain) other = self._toOperator(other, self.domain)
return SumOperator.make([self, other], [False, True]) return SumOperator.make([self, other], [False, True])
......
...@@ -70,6 +70,7 @@ def get_signal_variance(spec, space): ...@@ -70,6 +70,7 @@ def get_signal_variance(spec, space):
k_field = dist(field) k_field = dist(field)
return k_field.weight(2).sum() return k_field.weight(2).sum()
def _single_power_analyze(field, idx, binbounds): def _single_power_analyze(field, idx, binbounds):
power_domain = PowerSpace(field.domain[idx], binbounds) power_domain = PowerSpace(field.domain[idx], binbounds)
pd = PowerDistributor(field.domain, power_domain, idx) pd = PowerDistributor(field.domain, power_domain, idx)
......
...@@ -25,6 +25,7 @@ from test.common import expand ...@@ -25,6 +25,7 @@ from test.common import expand
dom = ift.makeDomain({"d1": ift.RGSpace(10)}) dom = ift.makeDomain({"d1": ift.RGSpace(10)})
class Test_Functionality(unittest.TestCase): class Test_Functionality(unittest.TestCase):
def test_vdot(self): def test_vdot(self):
f1 = ift.from_random("normal", domain=dom, dtype=np.complex128) f1 = ift.from_random("normal", domain=dom, dtype=np.complex128)
...@@ -53,7 +54,8 @@ class Test_Functionality(unittest.TestCase): ...@@ -53,7 +54,8 @@ class Test_Functionality(unittest.TestCase):
assert_equal(val.local_data, f2[key].local_data) assert_equal(val.local_data, f2[key].local_data)
def test_blockdiagonal(self): def test_blockdiagonal(self):
op = ift.BlockDiagonalOperator({"d1": ift.ScalingOperator(20., dom["d1"])}) op = ift.BlockDiagonalOperator({"d1":
ift.ScalingOperator(20., dom["d1"])})
op2 = op*op op2 = op*op
ift.extra.consistency_check(op2) ift.extra.consistency_check(op2)
assert_equal(type(op2), ift.BlockDiagonalOperator) assert_equal(type(op2), ift.BlockDiagonalOperator)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment