Commit 18bae195 authored by Martin Reinecke's avatar Martin Reinecke

don't allow Fields of objects

parent adfcaeb6
......@@ -387,10 +387,14 @@ def distaxis(arr):
def from_local_data(shape, arr, distaxis=0):
if arr.dtype.kind not in "fciub":
raise TypeError
return data_object(shape, arr, distaxis)
def from_global_data(arr, sum_up=False, distaxis=0):
if arr.dtype.kind not in "fciub":
raise TypeError
if sum_up:
arr = np_allreduce_sum(arr)
if distaxis == -1:
......
......@@ -97,10 +97,14 @@ def distaxis(arr):
def from_local_data(shape, arr, distaxis=-1):
if tuple(shape) != arr.shape:
raise ValueError
if arr.dtype.kind not in "fciub":
raise TypeError
return arr
def from_global_data(arr, sum_up=False, distaxis=-1):
if arr.dtype.kind not in "fciub":
raise TypeError
return arr
......
......@@ -203,14 +203,13 @@ def test_trivialities():
assert_equal(f1.clip(max=25).local_data, 25.)
assert_equal(f1.local_data, f1.real.local_data)
assert_equal(f1.local_data, (+f1).local_data)
print(f1)
print(str(f1))
f1 = ift.Field.full(s1, 27. + 3j)
assert_equal(f1.one_over().local_data, (1./f1).local_data)
assert_equal(f1.real.local_data, 27.)
assert_equal(f1.imag.local_data, 3.)
assert_equal(f1.sum(), f1.sum(0))
assert_equal(f1.conjugate().local_data, ift.Field.full(s1, 27. - 3j).local_data)
assert_equal(f1.conjugate().local_data,
ift.Field.full(s1, 27. - 3j).local_data)
f1 = ift.from_global_data(s1, np.arange(10))
# assert_equal(f1.min(), 0)
# assert_equal(f1.max(), 9)
......@@ -350,4 +349,5 @@ def test_from_random(rtype, dtype):
def test_field_of_objects():
arr = np.array(['x', 'y', 'z'])
sp = ift.RGSpace(3)
f = ift.Field.from_global_data(sp, arr)
with assert_raises(TypeError):
f = ift.Field.from_global_data(sp, arr)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment