diff --git a/nifty5/sugar.py b/nifty5/sugar.py index dc3a33d4edb51fe6c1d1aca2284b919f62c2dddf..bf0fe22a3c45a1fad447a7110e323c67980eeeb0 100644 --- a/nifty5/sugar.py +++ b/nifty5/sugar.py @@ -24,15 +24,12 @@ from . import dobj, utilities from .domain_tuple import DomainTuple from .domains.power_space import PowerSpace from .field import Field -from .linearization import Linearization from .logger import logger -from .minimization.energy import Energy from .multi_domain import MultiDomain from .multi_field import MultiField from .operators.block_diagonal_operator import BlockDiagonalOperator from .operators.diagonal_operator import DiagonalOperator from .operators.distributors import PowerDistributor -from .operators.energy_operators import EnergyOperator from .operators.operator import Operator from .plot import Plot @@ -43,7 +40,8 @@ __all__ = ['PS_field', 'power_analyze', 'create_power_operator', 'sin', 'cos', 'tan', 'sinh', 'cosh', 'log10', 'absolute', 'one_over', 'clip', 'sinc', 'conjugate', 'get_signal_variance', 'makeOp', 'domain_union', - 'get_default_codomain', 'single_plot', 'exec_time'] + 'get_default_codomain', 'single_plot', 'exec_time', + 'calculate_position'] def PS_field(pspace, func): @@ -458,6 +456,9 @@ def single_plot(field, **kwargs): def exec_time(obj, want_metric=True): """Times the execution time of an operator or an energy.""" + from .linearization import Linearization + from .minimization.energy import Energy + from .operators.energy_operators import EnergyOperator if isinstance(obj, Energy): t0 = time() obj.at(0.99*obj.position) @@ -503,3 +504,29 @@ def exec_time(obj, want_metric=True): print('Metric apply:', time() - t0) else: raise TypeError + + +def calculate_position(operator, output): + """Finds approximate preimage of an operator for a given output.""" + from .minimization.descent_minimizers import NewtonCG + from .minimization.iteration_controllers import GradientNormController + from .minimization.metric_gaussian_kl import MetricGaussianKL + from .operators.scaling_operator import ScalingOperator + from .operators.energy_operators import GaussianEnergy, StandardHamiltonian + if not isinstance(operator, Operator): + raise TypeError + if output.domain != operator.target: + raise TypeError + cov = 1e-3*output.to_global_data().max()**2 + invcov = ScalingOperator(cov, output.domain).inverse + d = output + invcov.draw_sample(from_inverse=True) + lh = GaussianEnergy(d, invcov)(operator) + H = StandardHamiltonian( + lh, ic_samp=GradientNormController(iteration_limit=200)) + pos = 0.1*from_random('normal', operator.domain) + minimizer = NewtonCG(GradientNormController(iteration_limit=10)) + for ii in range(3): + kl = MetricGaussianKL(pos, H, 3, mirror_samples=True) + kl, _ = minimizer(kl) + pos = kl.position + return pos diff --git a/test/test_sugar.py b/test/test_sugar.py index e354723b1470d91808d5932d8d1f700967abf651..dd4c990922a76ce2c0cfcf1db0fb0446ff232824 100644 --- a/test/test_sugar.py +++ b/test/test_sugar.py @@ -35,3 +35,26 @@ def test_get_signal_variance(): t[k == 0] = 1. return t assert_equal(ift.get_signal_variance(spec2, hspace), 1/9.) + + +def test_exec_time(): + dom = ift.RGSpace(12, harmonic=True) + op = ift.HarmonicTransformOperator(dom) + op1 = op.exp() + lh = ift.GaussianEnergy(domain=op.target) @ op1 + ic = ift.GradientNormController(iteration_limit=2) + ham = ift.StandardHamiltonian(lh, ic_samp=ic) + kl = ift.MetricGaussianKL(ift.full(ham.domain, 0.), ham, 1) + ops = [op, op1, lh, ham, kl] + for oo in ops: + for wm in [True, False]: + ift.exec_time(oo, wm) + + +def test_calc_pos(): + np.random.seed(42) + dom = ift.RGSpace(12, harmonic=True) + op = ift.HarmonicTransformOperator(dom).exp() + fld = op(0.1*ift.from_random('normal', op.domain)) + pos = ift.calculate_position(op, fld) + ift.extra.assert_allclose(op(pos), fld, 1e-1, 1e-1)