Commit 2ec7955a by Philipp Arras

### Support multifield and complex output for find_position

parent f40a1834
Pipeline #75315 failed with stages
in 39 seconds
 ... ... @@ -11,7 +11,7 @@ # You should have received a copy of the GNU General Public License # along with this program. If not, see . # # Copyright(C) 2013-2019 Max-Planck-Society # Copyright(C) 2013-2020 Max-Planck-Society # # NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik. ... ... @@ -20,21 +20,24 @@ from time import time import numpy as np from .logger import logger from . import utilities from . import pointwise, utilities from .domain_tuple import DomainTuple from .domains.power_space import PowerSpace from .field import Field from .logger import logger from .minimization.descent_minimizers import NewtonCG from .minimization.iteration_controllers import GradientNormController from .minimization.metric_gaussian_kl import MetricGaussianKL from .multi_domain import MultiDomain from .multi_field import MultiField from .operators.block_diagonal_operator import BlockDiagonalOperator from .operators.diagonal_operator import DiagonalOperator from .operators.distributors import PowerDistributor from .operators.energy_operators import GaussianEnergy, StandardHamiltonian from .operators.operator import Operator from .operators.sampling_enabler import SamplingDtypeSetter from .operators.scaling_operator import ScalingOperator from .plot import Plot from . import pointwise __all__ = ['PS_field', 'power_analyze', 'create_power_operator', 'create_harmonic_smoothing_operator', 'from_random', ... ... @@ -491,31 +494,28 @@ def exec_time(obj, want_metric=True): def calculate_position(operator, output): """Finds approximate preimage of an operator for a given output.""" from .minimization.descent_minimizers import NewtonCG from .minimization.iteration_controllers import GradientNormController from .minimization.metric_gaussian_kl import MetricGaussianKL from .operators.scaling_operator import ScalingOperator from .operators.energy_operators import GaussianEnergy, StandardHamiltonian if not isinstance(operator, Operator): raise TypeError if output.domain != operator.target: raise TypeError if isinstance(output, MultiField): cov = 1e-3*max([vv.max() for vv in output.val.values()])**2 cov = 1e-3*max([np.max(np.abs(vv)) for vv in output.val.values()])**2 invcov = ScalingOperator(output.domain, cov).inverse dtype = list(set([ff.dtype for ff in output.values()])) if len(dtype) != 1: raise ValueError('Only MultiFields with one dtype supported.') dtype = dtype[0] else: cov = 1e-3*output.val.max()**2 cov = 1e-3*np.max(np.abs(output.val))**2 dtype = output.dtype invcov = ScalingOperator(output.domain, cov).inverse d = output + invcov.draw_sample(dtype, from_inverse=True) invcov = SamplingDtypeSetter(invcov, output.dtype) invcov = SamplingDtypeSetter(invcov, output.dtype) d = output + invcov.draw_sample(from_inverse=True) lh = GaussianEnergy(d, invcov) @ operator H = StandardHamiltonian( lh, ic_samp=GradientNormController(iteration_limit=200)) pos = 0.1*from_random('normal', operator.domain) pos = 0.1*from_random(operator.domain) minimizer = NewtonCG(GradientNormController(iteration_limit=10, name='findpos')) for ii in range(3): logger.info(f'Start iteration {ii+1}/3') ... ...
 ... ... @@ -52,9 +52,18 @@ def test_exec_time(): ift.exec_time(oo, wm) def test_calc_pos(): import pytest pmp = pytest.mark.parametrize @pmp('mf', [False, True]) @pmp('cplx', [False, True]) def test_calc_pos(mf, cplx): dom = ift.RGSpace(12, harmonic=True) op = ift.HarmonicTransformOperator(dom).ptw("exp") if mf: op = op.ducktape_left('foo') dom = ift.makeDomain({'': dom}) if cplx: op = op + 1j*op fld = op(0.1 * ift.from_random(op.domain, 'normal')) pos = ift.calculate_position(op, fld) ift.extra.assert_allclose(op(pos), fld, 1e-1, 1e-1)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!