Commit e4dcb089 authored by Theo Steininger's avatar Theo Steininger
Browse files

Fixed dtype bug in RG-FFT:

If the dtype of the codomain != dtype of result field, information is lost due to conversion complex -> float.
parent 2cc5aff0
Pipeline #9527 passed with stages
in 40 minutes and 28 seconds
...@@ -286,8 +286,9 @@ class FFTW(Transform): ...@@ -286,8 +286,9 @@ class FFTW(Transform):
try: try:
# Create return object and insert results inplace # Create return object and insert results inplace
result_dtype = np.result_type(np.complex, self.codomain.dtype)
return_val = val.copy_empty(global_shape=val.shape, return_val = val.copy_empty(global_shape=val.shape,
dtype=self.codomain.dtype) dtype=result_dtype)
return_val.set_local_data(data=local_result, copy=False) return_val.set_local_data(data=local_result, copy=False)
except(AttributeError): except(AttributeError):
return_val = local_result return_val = local_result
...@@ -315,8 +316,9 @@ class FFTW(Transform): ...@@ -315,8 +316,9 @@ class FFTW(Transform):
) )
local_offset_Q = bool(local_offset_list[val.distributor.comm.rank] % 2) local_offset_Q = bool(local_offset_list[val.distributor.comm.rank] % 2)
result_dtype = np.result_type(np.complex, self.codomain.dtype)
return_val = val.copy_empty(global_shape=val.shape, return_val = val.copy_empty(global_shape=val.shape,
dtype=self.codomain.dtype) dtype=result_dtype)
# Extract local data # Extract local data
local_val = val.get_local_data(copy=False) local_val = val.get_local_data(copy=False)
...@@ -337,7 +339,7 @@ class FFTW(Transform): ...@@ -337,7 +339,7 @@ class FFTW(Transform):
if temp_val is None: if temp_val is None:
temp_val = np.empty_like( temp_val = np.empty_like(
local_val, local_val,
dtype=self.codomain.dtype dtype=result_dtype
) )
inp = local_val[slice_list] inp = local_val[slice_list]
...@@ -620,11 +622,12 @@ class GFFT(Transform): ...@@ -620,11 +622,12 @@ class GFFT(Transform):
else: else:
return_val[slice_list] = inp return_val[slice_list] = inp
result_dtype = np.result_type(np.complex, self.codomain.dtype)
if isinstance(val, distributed_data_object): if isinstance(val, distributed_data_object):
new_val = val.copy_empty(dtype=self.codomain.dtype) new_val = val.copy_empty(dtype=result_dtype)
new_val.set_full_data(return_val, copy=False) new_val.set_full_data(return_val, copy=False)
return_val = new_val return_val = new_val
else: else:
return_val = return_val.astype(self.codomain.dtype, copy=False) return_val = return_val.astype(result_type, copy=False)
return return_val return return_val
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment