Commit 9587d27e authored by Philipp Arras's avatar Philipp Arras
Browse files

Add preimage script

parent 75245a67
Pipeline #64468 passed with stages
in 9 minutes and 18 seconds
......@@ -24,15 +24,12 @@ from . import dobj, utilities
from .domain_tuple import DomainTuple
from .domains.power_space import PowerSpace
from .field import Field
from .linearization import Linearization
from .logger import logger
from import Energy
from .multi_domain import MultiDomain
from .multi_field import MultiField
from .operators.block_diagonal_operator import BlockDiagonalOperator
from .operators.diagonal_operator import DiagonalOperator
from .operators.distributors import PowerDistributor
from .operators.energy_operators import EnergyOperator
from .operators.operator import Operator
from .plot import Plot
......@@ -43,7 +40,8 @@ __all__ = ['PS_field', 'power_analyze', 'create_power_operator',
'sin', 'cos', 'tan', 'sinh', 'cosh', 'log10',
'absolute', 'one_over', 'clip', 'sinc',
'conjugate', 'get_signal_variance', 'makeOp', 'domain_union',
'get_default_codomain', 'single_plot', 'exec_time']
'get_default_codomain', 'single_plot', 'exec_time',
def PS_field(pspace, func):
......@@ -458,6 +456,9 @@ def single_plot(field, **kwargs):
def exec_time(obj, want_metric=True):
"""Times the execution time of an operator or an energy."""
from .linearization import Linearization
from import Energy
from .operators.energy_operators import EnergyOperator
if isinstance(obj, Energy):
t0 = time()*obj.position)
......@@ -503,3 +504,29 @@ def exec_time(obj, want_metric=True):
print('Metric apply:', time() - t0)
raise TypeError
def calculate_position(operator, output):
"""Finds approximate preimage of an operator for a given output."""
from .minimization.descent_minimizers import NewtonCG
from .minimization.iteration_controllers import GradientNormController
from .minimization.metric_gaussian_kl import MetricGaussianKL
from .operators.scaling_operator import ScalingOperator
from .operators.energy_operators import GaussianEnergy, StandardHamiltonian
if not isinstance(operator, Operator):
raise TypeError
if output.domain !=
raise TypeError
cov = 1e-3*output.to_global_data().max()**2
invcov = ScalingOperator(cov, output.domain).inverse
d = output + invcov.draw_sample(from_inverse=True)
lh = GaussianEnergy(d, invcov)(operator)
H = StandardHamiltonian(
lh, ic_samp=GradientNormController(iteration_limit=200))
pos = 0.1*from_random('normal', operator.domain)
minimizer = NewtonCG(GradientNormController(iteration_limit=10))
for ii in range(3):
kl = MetricGaussianKL(pos, H, 3, mirror_samples=True)
kl, _ = minimizer(kl)
pos = kl.position
return pos
......@@ -35,3 +35,26 @@ def test_get_signal_variance():
t[k == 0] = 1.
return t
assert_equal(ift.get_signal_variance(spec2, hspace), 1/9.)
def test_exec_time():
dom = ift.RGSpace(12, harmonic=True)
op = ift.HarmonicTransformOperator(dom)
op1 = op.exp()
lh = ift.GaussianEnergy( @ op1
ic = ift.GradientNormController(iteration_limit=2)
ham = ift.StandardHamiltonian(lh, ic_samp=ic)
kl = ift.MetricGaussianKL(ift.full(ham.domain, 0.), ham, 1)
ops = [op, op1, lh, ham, kl]
for oo in ops:
for wm in [True, False]:
ift.exec_time(oo, wm)
def test_calc_pos():
dom = ift.RGSpace(12, harmonic=True)
op = ift.HarmonicTransformOperator(dom).exp()
fld = op(0.1*ift.from_random('normal', op.domain))
pos = ift.calculate_position(op, fld)
ift.extra.assert_allclose(op(pos), fld, 1e-1, 1e-1)
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment