Commit 496b9d07 authored by Martin Reinecke's avatar Martin Reinecke
Browse files

merge master and more features

parents 508800eb 846cc966
...@@ -117,13 +117,13 @@ if __name__ == "__main__": ...@@ -117,13 +117,13 @@ if __name__ == "__main__":
# Solving the problem analytically # Solving the problem analytically
m0 = D0.inverse_times(j) m0 = D0.inverse_times(j)
sample_variance = Field(sh.domain, val=0. + 0j) sample_variance = Field(sh.domain, val=0.)
sample_mean = Field(sh.domain, val=0. + 0j) sample_mean = Field(sh.domain, val=0.)
# sampling the uncertainty map # sampling the uncertainty map
n_samples = 1 n_samples = 10
for i in range(n_samples): for i in range(n_samples):
sample = sugar.generate_posterior_sample(m0, D0) sample = fft(sugar.generate_posterior_sample(0., D0))
sample_variance += sample**2 sample_variance += sample**2
sample_mean += sample sample_mean += sample
variance = sample_variance/n_samples - (sample_mean/n_samples) variance = (sample_variance - sample_mean**2)/n_samples
...@@ -24,7 +24,7 @@ from .field import Field ...@@ -24,7 +24,7 @@ from .field import Field
__all__ = ['cos', 'sin', 'cosh', 'sinh', 'tan', 'tanh', 'arccos', 'arcsin', __all__ = ['cos', 'sin', 'cosh', 'sinh', 'tan', 'tanh', 'arccos', 'arcsin',
'arccosh', 'arcsinh', 'arctan', 'arctanh', 'sqrt', 'exp', 'log', 'arccosh', 'arcsinh', 'arctan', 'arctanh', 'sqrt', 'exp', 'log',
'conjugate', 'clipped_exp', 'limited_exp'] 'conjugate', 'clipped_exp', 'limited_exp', 'limited_exp_deriv']
def _math_helper(x, function): def _math_helper(x, function):
...@@ -101,15 +101,28 @@ def clipped_exp(x): ...@@ -101,15 +101,28 @@ def clipped_exp(x):
def limited_exp(x): def limited_exp(x):
thr = 200 return _math_helper(x, _limited_exp_helper)
expthr = np.exp(thr)
return _math_helper(x, lambda z: _limited_exp_helper(z, thr, expthr)) def _limited_exp_helper(x):
thr = 200.
mask = x>thr
if np.count_nonzero(mask) == 0:
return np.exp(x)
result = ((1.-thr) + x)*np.exp(thr)
result[~mask] = np.exp(x[~mask])
return result
def _limited_exp_helper(x, thr, expthr): def limited_exp_deriv(x):
mask = (x > thr) return _math_helper(x, _limited_exp_deriv_helper)
result = np.exp(x)
result[mask] = ((1-thr) + x[mask])*expthr def _limited_exp_deriv_helper(x):
thr = 200.
mask = x>thr
if np.count_nonzero(mask) == 0:
return np.exp(x)
result = np.empty_like(x)
result[mask] = np.exp(thr)
result[~mask] = np.exp(x[~mask])
return result return result
......
...@@ -16,8 +16,8 @@ ...@@ -16,8 +16,8 @@
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik # NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes. # and financially supported by the Studienstiftung des deutschen Volkes.
from nifty.nifty_meta import NiftyMeta from ..nifty_meta import NiftyMeta
from nifty.energies.memoization import memo from .memoization import memo
from keepers import Loggable from keepers import Loggable
from future.utils import with_metaclass from future.utils import with_metaclass
......
...@@ -330,7 +330,7 @@ class Field(Loggable, Versionable, object): ...@@ -330,7 +330,7 @@ class Field(Loggable, Versionable, object):
Returns Returns
------- -------
out : Field out : Field
The output object. It's domain is a PowerSpace and it contains The output object. Its domain is a PowerSpace and it contains
the power spectrum of 'self's field. the power spectrum of 'self's field.
See Also See Also
...@@ -1123,7 +1123,7 @@ class Field(Loggable, Versionable, object): ...@@ -1123,7 +1123,7 @@ class Field(Loggable, Versionable, object):
else: else:
# create a diagonal operator which is capable of taking care of the # create a diagonal operator which is capable of taking care of the
# axes-matching # axes-matching
from nifty.operators.diagonal_operator import DiagonalOperator from .operators.diagonal_operator import DiagonalOperator
diagonal = y.val.conjugate() diagonal = y.val.conjugate()
diagonalOperator = DiagonalOperator(domain=y.domain, diagonalOperator = DiagonalOperator(domain=y.domain,
diagonal=diagonal, diagonal=diagonal,
......
...@@ -16,7 +16,7 @@ ...@@ -16,7 +16,7 @@
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik # NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes. # and financially supported by the Studienstiftung des deutschen Volkes.
from nifty.domain_object import DomainObject from ..domain_object import DomainObject
class FieldType(DomainObject): class FieldType(DomainObject):
......
from nifty.operators.endomorphic_operator import EndomorphicOperator from ...operators.endomorphic_operator import EndomorphicOperator
from nifty.operators.invertible_operator_mixin import InvertibleOperatorMixin from ...operators.invertible_operator_mixin import InvertibleOperatorMixin
from nifty.operators.diagonal_operator import DiagonalOperator from ...operators.diagonal_operator import DiagonalOperator
class CriticalPowerCurvature(InvertibleOperatorMixin, EndomorphicOperator): class CriticalPowerCurvature(InvertibleOperatorMixin, EndomorphicOperator):
......
from ...energies.energy import Energy
from ...operators.smoothness_operator import SmoothnessOperator
from . import CriticalPowerCurvature
from ...energies.memoization import memo
from nifty.energies.energy import Energy from ...sugar import generate_posterior_sample
from nifty.operators.smoothness_operator import SmoothnessOperator from ... import Field, exp
from nifty.library.critical_filter import CriticalPowerCurvature
from nifty.energies.memoization import memo
from nifty.sugar import generate_posterior_sample
from nifty import Field, exp
class CriticalPowerEnergy(Energy): class CriticalPowerEnergy(Energy):
......
from nifty.operators import EndomorphicOperator,\ from ...operators import EndomorphicOperator,\
InvertibleOperatorMixin InvertibleOperatorMixin
from nifty.energies.memoization import memo from ...energies.memoization import memo
from nifty.basic_arithmetics import clipped_exp from ...basic_arithmetics import clipped_exp
from nifty.sugar import create_composed_fft_operator from ...sugar import create_composed_fft_operator
class LogNormalWienerFilterCurvature(InvertibleOperatorMixin, class LogNormalWienerFilterCurvature(InvertibleOperatorMixin,
......
from nifty.energies.energy import Energy from ...energies.energy import Energy
from nifty.energies.memoization import memo from ...energies.memoization import memo
from nifty.library.log_normal_wiener_filter import \ from . import LogNormalWienerFilterCurvature
LogNormalWienerFilterCurvature from ...sugar import create_composed_fft_operator
from nifty.sugar import create_composed_fft_operator
class LogNormalWienerFilterEnergy(Energy): class LogNormalWienerFilterEnergy(Energy):
......
from nifty.operators import EndomorphicOperator,\ from ...operators import EndomorphicOperator,\
InvertibleOperatorMixin InvertibleOperatorMixin
......
from nifty.energies.energy import Energy from ...energies.energy import Energy
from nifty.energies.memoization import memo from ...energies.memoization import memo
from nifty.library.wiener_filter import WienerFilterCurvature from . import WienerFilterCurvature
class WienerFilterEnergy(Energy): class WienerFilterEnergy(Energy):
......
...@@ -20,13 +20,16 @@ from __future__ import print_function ...@@ -20,13 +20,16 @@ from __future__ import print_function
from .iteration_controller import IterationController from .iteration_controller import IterationController
class DefaultIterationController(IterationController): class DefaultIterationController(IterationController):
def __init__ (self, tol_gradnorm=None, tol_rel_gradnorm=None, def __init__ (self, tol_abs_gradnorm=None, tol_rel_gradnorm=None,
convergence_level=1, iteration_limit=None): convergence_level=1, iteration_limit=None, name=None,
verbose=None):
super(DefaultIterationController, self).__init__() super(DefaultIterationController, self).__init__()
self._tol_gradnorm = tol_gradnorm self._tol_abs_gradnorm = tol_abs_gradnorm
self._tol_rel_gradnorm = tol_rel_gradnorm self._tol_rel_gradnorm = tol_rel_gradnorm
self._convergence_level = convergence_level self._convergence_level = convergence_level
self._iteration_limit = iteration_limit self._iteration_limit = iteration_limit
self._name = name
self._verbose = verbose
def start(self, energy): def start(self, energy):
self._itcount = -1 self._itcount = -1
...@@ -38,16 +41,29 @@ class DefaultIterationController(IterationController): ...@@ -38,16 +41,29 @@ class DefaultIterationController(IterationController):
def check(self, energy): def check(self, energy):
self._itcount += 1 self._itcount += 1
print("iteration",self._itcount,"gradnorm",energy.gradient_norm,"level",self._ccount, energy.value)
if self._iteration_limit is not None: if self._tol_abs_gradnorm is not None:
if self._itcount >= self._iteration_limit: if energy.gradient_norm <= self._tol_abs_gradnorm:
return self.CONVERGED
if self._tol_gradnorm is not None:
if energy.gradient_norm <= self._tol_gradnorm:
self._ccount += 1 self._ccount += 1
if self._tol_rel_gradnorm is not None: if self._tol_rel_gradnorm is not None:
if energy.gradient_norm <= self._tol_rel_gradnorm_now: if energy.gradient_norm <= self._tol_rel_gradnorm_now:
self._ccount += 1 self._ccount += 1
# report
if self._verbose:
msg = ""
if self._name is not None:
msg += self._name+":"
msg += " Iteration #" + str(self._itcount)
msg += " gradnorm=" + str(energy.gradient_norm)
msg += " convergence level=" + str(self._ccount)
print (msg)
self.logger.info(msg)
# Are we done?
if self._iteration_limit is not None:
if self._itcount >= self._iteration_limit:
return self.CONVERGED
if self._ccount >= self._convergence_level: if self._ccount >= self._convergence_level:
return self.CONVERGED return self.CONVERGED
......
...@@ -18,6 +18,7 @@ ...@@ -18,6 +18,7 @@
from __future__ import division from __future__ import division
import abc import abc
import numpy as np import numpy as np
from .minimizer import Minimizer from .minimizer import Minimizer
......
...@@ -25,7 +25,8 @@ import numpy as np ...@@ -25,7 +25,8 @@ import numpy as np
from keepers import Loggable from keepers import Loggable
from future.utils import with_metaclass from future.utils import with_metaclass
class IterationController(with_metaclass(NiftyMeta, type('NewBase', (Loggable, object), {}))): class IterationController(with_metaclass(NiftyMeta, type('NewBase',
(Loggable, object), {}))):
"""The abstract base class for all iteration controllers. """The abstract base class for all iteration controllers.
An iteration controller is an object that monitors the progress of a An iteration controller is an object that monitors the progress of a
minimization iteration. At the begin of the minimization, its start() minimization iteration. At the begin of the minimization, its start()
......
...@@ -20,7 +20,7 @@ import abc ...@@ -20,7 +20,7 @@ import abc
from keepers import Loggable from keepers import Loggable
from nifty import LineEnergy from ...energies import LineEnergy
from future.utils import with_metaclass from future.utils import with_metaclass
...@@ -28,7 +28,7 @@ class LineSearch(with_metaclass(abc.ABCMeta, with_metaclass(abc.ABCMeta, type('N ...@@ -28,7 +28,7 @@ class LineSearch(with_metaclass(abc.ABCMeta, with_metaclass(abc.ABCMeta, type('N
"""Class for determining the optimal step size along some descent direction. """Class for determining the optimal step size along some descent direction.
Initialize the line search procedure which can be used by a specific line Initialize the line search procedure which can be used by a specific line
search method. Its finds the step size in a specific direction in the search method. It finds the step size in a specific direction in the
minimization process. minimization process.
Attributes Attributes
......
...@@ -96,7 +96,7 @@ def cast_axis_to_tuple(axis, length=None): ...@@ -96,7 +96,7 @@ def cast_axis_to_tuple(axis, length=None):
def parse_domain(domain): def parse_domain(domain):
from nifty.domain_object import DomainObject from .domain_object import DomainObject
if domain is None: if domain is None:
domain = () domain = ()
elif isinstance(domain, DomainObject): elif isinstance(domain, DomainObject):
......
...@@ -17,10 +17,9 @@ ...@@ -17,10 +17,9 @@
# and financially supported by the Studienstiftung des deutschen Volkes. # and financially supported by the Studienstiftung des deutschen Volkes.
from builtins import object from builtins import object
from nifty.minimization import ConjugateGradient from ...minimization import ConjugateGradient
from ...field import Field
from nifty.field import Field from ...energies import QuadraticEnergy
from nifty.energies import QuadraticEnergy
class InvertibleOperatorMixin(object): class InvertibleOperatorMixin(object):
......
...@@ -17,11 +17,11 @@ ...@@ -17,11 +17,11 @@
# and financially supported by the Studienstiftung des deutschen Volkes. # and financially supported by the Studienstiftung des deutschen Volkes.
import numpy as np import numpy as np
from nifty.field import Field from ...field import Field
from nifty.spaces.power_space import PowerSpace from ...spaces.power_space import PowerSpace
from nifty.operators.endomorphic_operator import EndomorphicOperator from ..endomorphic_operator import EndomorphicOperator
from nifty import sqrt from ... import sqrt
import nifty.nifty_utilities as utilities from ... import nifty_utilities as utilities
class LaplaceOperator(EndomorphicOperator): class LaplaceOperator(EndomorphicOperator):
......
...@@ -18,11 +18,11 @@ ...@@ -18,11 +18,11 @@
from builtins import str from builtins import str
import abc import abc
from nifty.nifty_meta import NiftyMeta from ...nifty_meta import NiftyMeta
from keepers import Loggable from keepers import Loggable
from nifty.field import Field from ...field import Field
import nifty.nifty_utilities as utilities from ... import nifty_utilities as utilities
from future.utils import with_metaclass from future.utils import with_metaclass
......
...@@ -19,9 +19,8 @@ ...@@ -19,9 +19,8 @@
from builtins import range from builtins import range
import numpy as np import numpy as np
from nifty.field import Field from ...field import Field
from ..endomorphic_operator import EndomorphicOperator
from nifty.operators.endomorphic_operator import EndomorphicOperator
class ProjectionOperator(EndomorphicOperator): class ProjectionOperator(EndomorphicOperator):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment