From 18bae195a99d80372d453404f67e6088e6645dc2 Mon Sep 17 00:00:00 2001 From: Martin Reinecke <martin@mpa-garching.mpg.de> Date: Thu, 24 Jan 2019 10:34:20 +0100 Subject: [PATCH] don't allow Fields of objects --- nifty5/data_objects/distributed_do.py | 4 ++++ nifty5/data_objects/numpy_do.py | 4 ++++ test/test_field.py | 8 ++++---- 3 files changed, 12 insertions(+), 4 deletions(-) diff --git a/nifty5/data_objects/distributed_do.py b/nifty5/data_objects/distributed_do.py index 0992df390..3bebd7fc5 100644 --- a/nifty5/data_objects/distributed_do.py +++ b/nifty5/data_objects/distributed_do.py @@ -387,10 +387,14 @@ def distaxis(arr): def from_local_data(shape, arr, distaxis=0): + if arr.dtype.kind not in "fciub": + raise TypeError return data_object(shape, arr, distaxis) def from_global_data(arr, sum_up=False, distaxis=0): + if arr.dtype.kind not in "fciub": + raise TypeError if sum_up: arr = np_allreduce_sum(arr) if distaxis == -1: diff --git a/nifty5/data_objects/numpy_do.py b/nifty5/data_objects/numpy_do.py index e582e3602..edd5476b8 100644 --- a/nifty5/data_objects/numpy_do.py +++ b/nifty5/data_objects/numpy_do.py @@ -97,10 +97,14 @@ def distaxis(arr): def from_local_data(shape, arr, distaxis=-1): if tuple(shape) != arr.shape: raise ValueError + if arr.dtype.kind not in "fciub": + raise TypeError return arr def from_global_data(arr, sum_up=False, distaxis=-1): + if arr.dtype.kind not in "fciub": + raise TypeError return arr diff --git a/test/test_field.py b/test/test_field.py index db9d41def..faf1c1ac7 100644 --- a/test/test_field.py +++ b/test/test_field.py @@ -203,14 +203,13 @@ def test_trivialities(): assert_equal(f1.clip(max=25).local_data, 25.) assert_equal(f1.local_data, f1.real.local_data) assert_equal(f1.local_data, (+f1).local_data) - print(f1) - print(str(f1)) f1 = ift.Field.full(s1, 27. + 3j) assert_equal(f1.one_over().local_data, (1./f1).local_data) assert_equal(f1.real.local_data, 27.) assert_equal(f1.imag.local_data, 3.) assert_equal(f1.sum(), f1.sum(0)) - assert_equal(f1.conjugate().local_data, ift.Field.full(s1, 27. - 3j).local_data) + assert_equal(f1.conjugate().local_data, + ift.Field.full(s1, 27. - 3j).local_data) f1 = ift.from_global_data(s1, np.arange(10)) # assert_equal(f1.min(), 0) # assert_equal(f1.max(), 9) @@ -350,4 +349,5 @@ def test_from_random(rtype, dtype): def test_field_of_objects(): arr = np.array(['x', 'y', 'z']) sp = ift.RGSpace(3) - f = ift.Field.from_global_data(sp, arr) + with assert_raises(TypeError): + f = ift.Field.from_global_data(sp, arr) -- GitLab