Skip to content
Snippets Groups Projects
Commit 5a2e38b6 authored by Philipp Arras's avatar Philipp Arras
Browse files

Implement proper constant support 8/n

parent e3328be0
Branches
Tags
1 merge request!545Proper constants
Pipeline #77011 passed
...@@ -58,7 +58,14 @@ class ChainOperator(LinearOperator): ...@@ -58,7 +58,14 @@ class ChainOperator(LinearOperator):
fct = 1. fct = 1.
opsnew = [] opsnew = []
lastdom = ops[-1].domain lastdom = ops[-1].domain
dtype = None
for op in ops: for op in ops:
from .sampling_enabler import SamplingDtypeSetter
if isinstance(op, SamplingDtypeSetter) and isinstance(op._op, ScalingOperator):
if dtype is not None:
raise NotImplementedError
dtype = op._dtype
op = op._op
if (isinstance(op, ScalingOperator) and op._factor.imag == 0): if (isinstance(op, ScalingOperator) and op._factor.imag == 0):
fct *= op._factor.real fct *= op._factor.real
else: else:
...@@ -72,7 +79,10 @@ class ChainOperator(LinearOperator): ...@@ -72,7 +79,10 @@ class ChainOperator(LinearOperator):
break break
if fct != 1 or len(opsnew) == 0: if fct != 1 or len(opsnew) == 0:
# have to add the scaling operator at the end # have to add the scaling operator at the end
opsnew.append(ScalingOperator(lastdom, fct)) op = ScalingOperator(lastdom, fct)
if dtype is not None:
op = SamplingDtypeSetter(op, dtype)
opsnew.append(op)
ops = opsnew ops = opsnew
# combine DiagonalOperators where possible # combine DiagonalOperators where possible
opsnew = [] opsnew = []
......
...@@ -35,7 +35,6 @@ def test_gaussian(field): ...@@ -35,7 +35,6 @@ def test_gaussian(field):
ift.extra.check_operator(energy, field) ift.extra.check_operator(energy, field)
@pytest.mark.skip()
def test_ScaledEnergy(field): def test_ScaledEnergy(field):
icov = ift.ScalingOperator(field.domain, 1.2) icov = ift.ScalingOperator(field.domain, 1.2)
energy = ift.GaussianEnergy(inverse_covariance=icov, sampling_dtype=np.float64) energy = ift.GaussianEnergy(inverse_covariance=icov, sampling_dtype=np.float64)
...@@ -48,6 +47,7 @@ def test_ScaledEnergy(field): ...@@ -48,6 +47,7 @@ def test_ScaledEnergy(field):
res2 = met2(field)/0.3 res2 = met2(field)/0.3
ift.extra.assert_allclose(res1, res2, 0, 1e-12) ift.extra.assert_allclose(res1, res2, 0, 1e-12)
met1.draw_sample() met1.draw_sample()
print(met2)
met2.draw_sample() met2.draw_sample()
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment