Commit f2c30d01 authored by Reimar Leike's avatar Reimar Leike Committed by Philipp Arras
Browse files

Added check for dtypes in ScalingOperators to avoid inconsistencies in the adjoint

parent 11bc2c37
......@@ -63,6 +63,7 @@ class ScalingOperator(EndomorphicOperator):
self._check_input(x, mode)
fct = self._factor
self._check_dtype(type(fct), x)
if fct == 1.:
return x
if fct == 0.:
......@@ -113,5 +114,20 @@ class ScalingOperator(EndomorphicOperator):
res = res.add_metric(met)
return res
@staticmethod
def _check_dtype(ftype, x):
# if factor is complex, input has to be complex as well
complex_types = [np.complex128, np.complex64, complex]
if ftype not in complex_types:
return
from .. import Field
if isinstance(x, Field):
dtypes = [x.val.dtype]
else:
dtypes = [v.dtype for v in x.values()]
for t in dtypes:
if t not in complex_types:
raise ValueError("Real input to ScalingOperator with complex factor. Try casting to Complex first using an adjoint Realizer")
def __repr__(self):
return "ScalingOperator ({})".format(self._factor)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment