From 374675e9a35a80f2bd4f3ad8fc78cbf5fd93b6a2 Mon Sep 17 00:00:00 2001 From: Philipp Arras <parras@mpa-garching.mpg.de> Date: Fri, 25 Oct 2019 15:37:15 +0200 Subject: [PATCH] Support fields as amplitudes in correlated field --- nifty5/library/correlated_fields.py | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/nifty5/library/correlated_fields.py b/nifty5/library/correlated_fields.py index 6910e974e..f06e2a987 100644 --- a/nifty5/library/correlated_fields.py +++ b/nifty5/library/correlated_fields.py @@ -19,13 +19,14 @@ from functools import reduce from operator import mul from ..domain_tuple import DomainTuple +from ..field import Field from ..operators.adder import Adder from ..operators.contraction_operator import ContractionOperator from ..operators.distributors import PowerDistributor from ..operators.harmonic_operators import HarmonicTransformOperator from ..operators.operator import Operator from ..operators.simple_linear_operators import VdotOperator, ducktape -from ..sugar import full +from ..sugar import full, makeOp def CorrelatedField(target, amplitude_operator, name='xi', codomain=None): @@ -68,15 +69,23 @@ def CorrelatedField(target, amplitude_operator, name='xi', codomain=None): codomain = tgt[0].get_default_codomain() h_space = codomain ht = HarmonicTransformOperator(h_space, target=tgt[0]) - p_space = amplitude_operator.target[0] + if isinstance(amplitude_operator, Operator): + p_space = amplitude_operator.target[0] + elif isinstance(amplitude_operator, Field): + p_space = amplitude_operator.domain[0] + else: + raise TypeError power_distributor = PowerDistributor(h_space, p_space) A = power_distributor(amplitude_operator) vol = h_space.scalar_dvol**-0.5 + xi = ducktape(h_space, None, name) # When doubling the resolution of `tgt` the value of the highest k-mode # will scale with a square root. `vol` cancels this effect such that the # same power spectrum can be used for the spaces with the same volume, # different resolutions and the same object in them. - return ht(vol*A*ducktape(h_space, None, name)) + if isinstance(amplitude_operator, Field): + return ht(makeOp(A)@xi).scale(vol) + return ht(A*xi).scale(vol) def MfCorrelatedField(target, amplitudes, name='xi'): -- GitLab