From 374675e9a35a80f2bd4f3ad8fc78cbf5fd93b6a2 Mon Sep 17 00:00:00 2001
From: Philipp Arras <parras@mpa-garching.mpg.de>
Date: Fri, 25 Oct 2019 15:37:15 +0200
Subject: [PATCH] Support fields as amplitudes in correlated field

---
 nifty5/library/correlated_fields.py | 15 ++++++++++++---
 1 file changed, 12 insertions(+), 3 deletions(-)

diff --git a/nifty5/library/correlated_fields.py b/nifty5/library/correlated_fields.py
index 6910e974e..f06e2a987 100644
--- a/nifty5/library/correlated_fields.py
+++ b/nifty5/library/correlated_fields.py
@@ -19,13 +19,14 @@ from functools import reduce
 from operator import mul
 
 from ..domain_tuple import DomainTuple
+from ..field import Field
 from ..operators.adder import Adder
 from ..operators.contraction_operator import ContractionOperator
 from ..operators.distributors import PowerDistributor
 from ..operators.harmonic_operators import HarmonicTransformOperator
 from ..operators.operator import Operator
 from ..operators.simple_linear_operators import VdotOperator, ducktape
-from ..sugar import full
+from ..sugar import full, makeOp
 
 
 def CorrelatedField(target, amplitude_operator, name='xi', codomain=None):
@@ -68,15 +69,23 @@ def CorrelatedField(target, amplitude_operator, name='xi', codomain=None):
         codomain = tgt[0].get_default_codomain()
     h_space = codomain
     ht = HarmonicTransformOperator(h_space, target=tgt[0])
-    p_space = amplitude_operator.target[0]
+    if isinstance(amplitude_operator, Operator):
+        p_space = amplitude_operator.target[0]
+    elif isinstance(amplitude_operator, Field):
+        p_space = amplitude_operator.domain[0]
+    else:
+        raise TypeError
     power_distributor = PowerDistributor(h_space, p_space)
     A = power_distributor(amplitude_operator)
     vol = h_space.scalar_dvol**-0.5
+    xi = ducktape(h_space, None, name)
     # When doubling the resolution of `tgt` the value of the highest k-mode
     # will scale with a square root. `vol` cancels this effect such that the
     # same power spectrum can be used for the spaces with the same volume,
     # different resolutions and the same object in them.
-    return ht(vol*A*ducktape(h_space, None, name))
+    if isinstance(amplitude_operator, Field):
+        return ht(makeOp(A)@xi).scale(vol)
+    return ht(A*xi).scale(vol)
 
 
 def MfCorrelatedField(target, amplitudes, name='xi'):
-- 
GitLab