Commit 5a2e38b6 authored by Philipp Arras's avatar Philipp Arras
Browse files

Implement proper constant support 8/n

parent e3328be0
......@@ -58,7 +58,14 @@ class ChainOperator(LinearOperator):
fct = 1.
opsnew = []
lastdom = ops[-1].domain
dtype = None
for op in ops:
from .sampling_enabler import SamplingDtypeSetter
if isinstance(op, SamplingDtypeSetter) and isinstance(op._op, ScalingOperator):
if dtype is not None:
raise NotImplementedError
dtype = op._dtype
op = op._op
if (isinstance(op, ScalingOperator) and op._factor.imag == 0):
fct *= op._factor.real
else:
......@@ -72,7 +79,10 @@ class ChainOperator(LinearOperator):
break
if fct != 1 or len(opsnew) == 0:
# have to add the scaling operator at the end
opsnew.append(ScalingOperator(lastdom, fct))
op = ScalingOperator(lastdom, fct)
if dtype is not None:
op = SamplingDtypeSetter(op, dtype)
opsnew.append(op)
ops = opsnew
# combine DiagonalOperators where possible
opsnew = []
......
......@@ -35,7 +35,6 @@ def test_gaussian(field):
ift.extra.check_operator(energy, field)
@pytest.mark.skip()
def test_ScaledEnergy(field):
icov = ift.ScalingOperator(field.domain, 1.2)
energy = ift.GaussianEnergy(inverse_covariance=icov, sampling_dtype=np.float64)
......@@ -48,6 +47,7 @@ def test_ScaledEnergy(field):
res2 = met2(field)/0.3
ift.extra.assert_allclose(res1, res2, 0, 1e-12)
met1.draw_sample()
print(met2)
met2.draw_sample()
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment