Commit b8b3c138 authored by Theo Steininger's avatar Theo Steininger

Added domain_dtype and target_dtype attributes to FFTOperator. Removed dtype...

Added domain_dtype and target_dtype attributes to FFTOperator. Removed dtype from _times and _inverse_times.
parent 09ac8701
Pipeline #10125 passed with stages
in 16 minutes and 51 seconds
import numpy as np
import nifty.nifty_utilities as utilities
from nifty.spaces import RGSpace,\
GLSpace,\
......@@ -33,7 +35,8 @@ class FFTOperator(LinearOperator):
# ---Overwritten properties and methods---
def __init__(self, domain=(), target=None, module=None):
def __init__(self, domain=(), target=None, module=None,
domain_dtype=None, target_dtype=None):
self._domain = self._parse_domain(domain)
......@@ -69,7 +72,18 @@ class FFTOperator(LinearOperator):
self._backward_transformation = TransformationCache.create(
backward_class, self.target[0], self.domain[0], module=module)
def _times(self, x, spaces, dtype=None):
# Store the dtype information
if domain_dtype is None:
self.domain_dtype = None
else:
self.domain_dtype = np.dtype(domain_dtype)
if target_dtype is None:
self.target_dtype = None
else:
self.target_dtype = np.dtype(target_dtype)
def _times(self, x, spaces):
spaces = utilities.cast_axis_to_tuple(spaces, len(x.domain))
if spaces is None:
# this case means that x lives on only one space, which is
......@@ -87,12 +101,13 @@ class FFTOperator(LinearOperator):
result_domain = list(x.domain)
result_domain[spaces[0]] = self.target[0]
result_field = x.copy_empty(domain=result_domain, dtype=dtype)
result_field = x.copy_empty(domain=result_domain,
dtype=self.target_dtype)
result_field.set_val(new_val=new_val, copy=False)
return result_field
def _inverse_times(self, x, spaces, dtype=None):
def _inverse_times(self, x, spaces):
spaces = utilities.cast_axis_to_tuple(spaces, len(x.domain))
if spaces is None:
# this case means that x lives on only one space, which is
......@@ -110,7 +125,8 @@ class FFTOperator(LinearOperator):
result_domain = list(x.domain)
result_domain[spaces[0]] = self.domain[0]
result_field = x.copy_empty(domain=result_domain, dtype=dtype)
result_field = x.copy_empty(domain=result_domain,
dtype=self.domain_dtype)
result_field.set_val(new_val=new_val, copy=False)
return result_field
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment