From f7ee0e25e87e4ce1c673000f86f26085e2d336bd Mon Sep 17 00:00:00 2001
From: Reimar Leike <reimar@mpa-garhcing.mpg.de>
Date: Tue, 16 Jun 2020 10:41:36 +0200
Subject: [PATCH] Test complex samples

---
 .../test_sample_dtype_consistency.py          | 57 +++++++++++++++++++
 1 file changed, 57 insertions(+)
 create mode 100644 test/test_operators/test_sample_dtype_consistency.py

diff --git a/test/test_operators/test_sample_dtype_consistency.py b/test/test_operators/test_sample_dtype_consistency.py
new file mode 100644
index 000000000..d02ab4d16
--- /dev/null
+++ b/test/test_operators/test_sample_dtype_consistency.py
@@ -0,0 +1,57 @@
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program.  If not, see <http://www.gnu.org/licenses/>.
+#
+# Copyright(C) 2013-2020 Max-Planck-Society
+#
+# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.
+
+import numpy as np
+import pytest
+
+import nifty6 as ift
+
+from ..common import list2fixture, setup_function, teardown_function
+from scipy.stats import norm
+
+
+Nsamp = 20000
+np.random.seed(42)
+
+def _to_array(d):
+    if isinstance(d, np.ndarray):
+        return d
+    assert isinstance(d, dict)
+    return np.concatenate(list(d.values()))
+
+
+def test_GaussianEnergy():
+    sp = ift.UnstructuredDomain(Nsamp)
+    S = ift.ScalingOperator(sp, 1.)
+    samp = S.draw_sample_with_dtype(dtype=np.complex128)
+    real_std = np.std(samp.val.real)
+    imag_std = np.std(samp.val.imag)
+    np.testing.assert_allclose(real_std, imag_std, 
+            atol=5./np.sqrt(Nsamp))
+    sp1 = ift.UnstructuredDomain(1)
+    mean = ift.full(sp1, 0.)
+    real_icov = ift.ScalingOperator(sp1, real_std**(-1))
+    imag_icov = ift.ScalingOperator(sp1, imag_std**(-1))
+    real_energy = ift.GaussianEnergy(mean, inverse_covariance=real_icov)
+    imag_energy = ift.GaussianEnergy(mean, inverse_covariance=imag_icov)
+    icov = ift.ScalingOperator(sp1, 1.)
+    complex_energy = ift.GaussianEnergy(mean+0.j, inverse_covariance=icov)
+    for i in range(min(10, Nsamp)):
+        fld = ift.full(sp1, samp.val[i])
+        val1 = (real_energy(fld.real) + imag_energy(fld.imag)).val
+        val2 = complex_energy(fld).val
+        np.testing.assert_allclose(val1, val2, atol=10./np.sqrt(Nsamp))
-- 
GitLab