diff --git a/nifty5/data_objects/distributed_do.py b/nifty5/data_objects/distributed_do.py index 0992df390356bde45e537e78ce51cf1228ae8128..3bebd7fc58774f637b60467a937629fc8758b92a 100644 --- a/nifty5/data_objects/distributed_do.py +++ b/nifty5/data_objects/distributed_do.py @@ -387,10 +387,14 @@ def distaxis(arr): def from_local_data(shape, arr, distaxis=0): + if arr.dtype.kind not in "fciub": + raise TypeError return data_object(shape, arr, distaxis) def from_global_data(arr, sum_up=False, distaxis=0): + if arr.dtype.kind not in "fciub": + raise TypeError if sum_up: arr = np_allreduce_sum(arr) if distaxis == -1: diff --git a/nifty5/data_objects/numpy_do.py b/nifty5/data_objects/numpy_do.py index e582e36023bb914a0195d5e6deb9379a62a68548..edd5476b8d2eef85d5cf26ac893580a83d5ba308 100644 --- a/nifty5/data_objects/numpy_do.py +++ b/nifty5/data_objects/numpy_do.py @@ -97,10 +97,14 @@ def distaxis(arr): def from_local_data(shape, arr, distaxis=-1): if tuple(shape) != arr.shape: raise ValueError + if arr.dtype.kind not in "fciub": + raise TypeError return arr def from_global_data(arr, sum_up=False, distaxis=-1): + if arr.dtype.kind not in "fciub": + raise TypeError return arr diff --git a/test/test_field.py b/test/test_field.py index db9d41defb0bcea8086eb43acedd67ff15780307..faf1c1ac732f67efd2daad6dbd5291b1122b5f5a 100644 --- a/test/test_field.py +++ b/test/test_field.py @@ -203,14 +203,13 @@ def test_trivialities(): assert_equal(f1.clip(max=25).local_data, 25.) assert_equal(f1.local_data, f1.real.local_data) assert_equal(f1.local_data, (+f1).local_data) - print(f1) - print(str(f1)) f1 = ift.Field.full(s1, 27. + 3j) assert_equal(f1.one_over().local_data, (1./f1).local_data) assert_equal(f1.real.local_data, 27.) assert_equal(f1.imag.local_data, 3.) assert_equal(f1.sum(), f1.sum(0)) - assert_equal(f1.conjugate().local_data, ift.Field.full(s1, 27. - 3j).local_data) + assert_equal(f1.conjugate().local_data, + ift.Field.full(s1, 27. - 3j).local_data) f1 = ift.from_global_data(s1, np.arange(10)) # assert_equal(f1.min(), 0) # assert_equal(f1.max(), 9) @@ -350,4 +349,5 @@ def test_from_random(rtype, dtype): def test_field_of_objects(): arr = np.array(['x', 'y', 'z']) sp = ift.RGSpace(3) - f = ift.Field.from_global_data(sp, arr) + with assert_raises(TypeError): + f = ift.Field.from_global_data(sp, arr)