Commit 2ec7955a authored by Philipp Arras's avatar Philipp Arras
Browse files

Support multifield and complex output for find_position

parent f40a1834
Pipeline #75315 failed with stages
in 39 seconds
...@@ -11,7 +11,7 @@ ...@@ -11,7 +11,7 @@
# You should have received a copy of the GNU General Public License # You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>. # along with this program. If not, see <http://www.gnu.org/licenses/>.
# #
# Copyright(C) 2013-2019 Max-Planck-Society # Copyright(C) 2013-2020 Max-Planck-Society
# #
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik. # NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.
...@@ -20,21 +20,24 @@ from time import time ...@@ -20,21 +20,24 @@ from time import time
import numpy as np import numpy as np
from .logger import logger from . import pointwise, utilities
from . import utilities
from .domain_tuple import DomainTuple from .domain_tuple import DomainTuple
from .domains.power_space import PowerSpace from .domains.power_space import PowerSpace
from .field import Field from .field import Field
from .logger import logger
from .minimization.descent_minimizers import NewtonCG
from .minimization.iteration_controllers import GradientNormController
from .minimization.metric_gaussian_kl import MetricGaussianKL
from .multi_domain import MultiDomain from .multi_domain import MultiDomain
from .multi_field import MultiField from .multi_field import MultiField
from .operators.block_diagonal_operator import BlockDiagonalOperator from .operators.block_diagonal_operator import BlockDiagonalOperator
from .operators.diagonal_operator import DiagonalOperator from .operators.diagonal_operator import DiagonalOperator
from .operators.distributors import PowerDistributor from .operators.distributors import PowerDistributor
from .operators.energy_operators import GaussianEnergy, StandardHamiltonian
from .operators.operator import Operator from .operators.operator import Operator
from .operators.sampling_enabler import SamplingDtypeSetter
from .operators.scaling_operator import ScalingOperator from .operators.scaling_operator import ScalingOperator
from .plot import Plot from .plot import Plot
from . import pointwise
__all__ = ['PS_field', 'power_analyze', 'create_power_operator', __all__ = ['PS_field', 'power_analyze', 'create_power_operator',
'create_harmonic_smoothing_operator', 'from_random', 'create_harmonic_smoothing_operator', 'from_random',
...@@ -491,31 +494,28 @@ def exec_time(obj, want_metric=True): ...@@ -491,31 +494,28 @@ def exec_time(obj, want_metric=True):
def calculate_position(operator, output): def calculate_position(operator, output):
"""Finds approximate preimage of an operator for a given output.""" """Finds approximate preimage of an operator for a given output."""
from .minimization.descent_minimizers import NewtonCG
from .minimization.iteration_controllers import GradientNormController
from .minimization.metric_gaussian_kl import MetricGaussianKL
from .operators.scaling_operator import ScalingOperator
from .operators.energy_operators import GaussianEnergy, StandardHamiltonian
if not isinstance(operator, Operator): if not isinstance(operator, Operator):
raise TypeError raise TypeError
if output.domain != operator.target: if output.domain != operator.target:
raise TypeError raise TypeError
if isinstance(output, MultiField): if isinstance(output, MultiField):
cov = 1e-3*max([vv.max() for vv in output.val.values()])**2 cov = 1e-3*max([np.max(np.abs(vv)) for vv in output.val.values()])**2
invcov = ScalingOperator(output.domain, cov).inverse invcov = ScalingOperator(output.domain, cov).inverse
dtype = list(set([ff.dtype for ff in output.values()])) dtype = list(set([ff.dtype for ff in output.values()]))
if len(dtype) != 1: if len(dtype) != 1:
raise ValueError('Only MultiFields with one dtype supported.') raise ValueError('Only MultiFields with one dtype supported.')
dtype = dtype[0] dtype = dtype[0]
else: else:
cov = 1e-3*output.val.max()**2 cov = 1e-3*np.max(np.abs(output.val))**2
dtype = output.dtype dtype = output.dtype
invcov = ScalingOperator(output.domain, cov).inverse invcov = ScalingOperator(output.domain, cov).inverse
d = output + invcov.draw_sample(dtype, from_inverse=True) invcov = SamplingDtypeSetter(invcov, output.dtype)
invcov = SamplingDtypeSetter(invcov, output.dtype)
d = output + invcov.draw_sample(from_inverse=True)
lh = GaussianEnergy(d, invcov) @ operator lh = GaussianEnergy(d, invcov) @ operator
H = StandardHamiltonian( H = StandardHamiltonian(
lh, ic_samp=GradientNormController(iteration_limit=200)) lh, ic_samp=GradientNormController(iteration_limit=200))
pos = 0.1*from_random('normal', operator.domain) pos = 0.1*from_random(operator.domain)
minimizer = NewtonCG(GradientNormController(iteration_limit=10, name='findpos')) minimizer = NewtonCG(GradientNormController(iteration_limit=10, name='findpos'))
for ii in range(3): for ii in range(3):
logger.info(f'Start iteration {ii+1}/3') logger.info(f'Start iteration {ii+1}/3')
......
...@@ -52,9 +52,18 @@ def test_exec_time(): ...@@ -52,9 +52,18 @@ def test_exec_time():
ift.exec_time(oo, wm) ift.exec_time(oo, wm)
def test_calc_pos(): import pytest
pmp = pytest.mark.parametrize
@pmp('mf', [False, True])
@pmp('cplx', [False, True])
def test_calc_pos(mf, cplx):
dom = ift.RGSpace(12, harmonic=True) dom = ift.RGSpace(12, harmonic=True)
op = ift.HarmonicTransformOperator(dom).ptw("exp") op = ift.HarmonicTransformOperator(dom).ptw("exp")
if mf:
op = op.ducktape_left('foo')
dom = ift.makeDomain({'': dom})
if cplx:
op = op + 1j*op
fld = op(0.1 * ift.from_random(op.domain, 'normal')) fld = op(0.1 * ift.from_random(op.domain, 'normal'))
pos = ift.calculate_position(op, fld) pos = ift.calculate_position(op, fld)
ift.extra.assert_allclose(op(pos), fld, 1e-1, 1e-1) ift.extra.assert_allclose(op(pos), fld, 1e-1, 1e-1)
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment