Commit b8b3c138 authored by Theo Steininger's avatar Theo Steininger
Browse files

Added domain_dtype and target_dtype attributes to FFTOperator. Removed dtype...

Added domain_dtype and target_dtype attributes to FFTOperator. Removed dtype from _times and _inverse_times.
parent 09ac8701
Pipeline #10125 passed with stages
in 16 minutes and 51 seconds
import numpy as np
import nifty.nifty_utilities as utilities import nifty.nifty_utilities as utilities
from nifty.spaces import RGSpace,\ from nifty.spaces import RGSpace,\
GLSpace,\ GLSpace,\
...@@ -33,7 +35,8 @@ class FFTOperator(LinearOperator): ...@@ -33,7 +35,8 @@ class FFTOperator(LinearOperator):
# ---Overwritten properties and methods--- # ---Overwritten properties and methods---
def __init__(self, domain=(), target=None, module=None): def __init__(self, domain=(), target=None, module=None,
domain_dtype=None, target_dtype=None):
self._domain = self._parse_domain(domain) self._domain = self._parse_domain(domain)
...@@ -69,7 +72,18 @@ class FFTOperator(LinearOperator): ...@@ -69,7 +72,18 @@ class FFTOperator(LinearOperator):
self._backward_transformation = TransformationCache.create( self._backward_transformation = TransformationCache.create(
backward_class, self.target[0], self.domain[0], module=module) backward_class, self.target[0], self.domain[0], module=module)
def _times(self, x, spaces, dtype=None): # Store the dtype information
if domain_dtype is None:
self.domain_dtype = None
else:
self.domain_dtype = np.dtype(domain_dtype)
if target_dtype is None:
self.target_dtype = None
else:
self.target_dtype = np.dtype(target_dtype)
def _times(self, x, spaces):
spaces = utilities.cast_axis_to_tuple(spaces, len(x.domain)) spaces = utilities.cast_axis_to_tuple(spaces, len(x.domain))
if spaces is None: if spaces is None:
# this case means that x lives on only one space, which is # this case means that x lives on only one space, which is
...@@ -87,12 +101,13 @@ class FFTOperator(LinearOperator): ...@@ -87,12 +101,13 @@ class FFTOperator(LinearOperator):
result_domain = list(x.domain) result_domain = list(x.domain)
result_domain[spaces[0]] = self.target[0] result_domain[spaces[0]] = self.target[0]
result_field = x.copy_empty(domain=result_domain, dtype=dtype) result_field = x.copy_empty(domain=result_domain,
dtype=self.target_dtype)
result_field.set_val(new_val=new_val, copy=False) result_field.set_val(new_val=new_val, copy=False)
return result_field return result_field
def _inverse_times(self, x, spaces, dtype=None): def _inverse_times(self, x, spaces):
spaces = utilities.cast_axis_to_tuple(spaces, len(x.domain)) spaces = utilities.cast_axis_to_tuple(spaces, len(x.domain))
if spaces is None: if spaces is None:
# this case means that x lives on only one space, which is # this case means that x lives on only one space, which is
...@@ -110,7 +125,8 @@ class FFTOperator(LinearOperator): ...@@ -110,7 +125,8 @@ class FFTOperator(LinearOperator):
result_domain = list(x.domain) result_domain = list(x.domain)
result_domain[spaces[0]] = self.domain[0] result_domain[spaces[0]] = self.domain[0]
result_field = x.copy_empty(domain=result_domain, dtype=dtype) result_field = x.copy_empty(domain=result_domain,
dtype=self.domain_dtype)
result_field.set_val(new_val=new_val, copy=False) result_field.set_val(new_val=new_val, copy=False)
return result_field return result_field
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment