Commit ba7cd26a authored by Philipp Arras's avatar Philipp Arras
Browse files

Fixups

parent c73b4467
...@@ -224,8 +224,8 @@ class SumOperator(LinearOperator): ...@@ -224,8 +224,8 @@ class SumOperator(LinearOperator):
fullop = op if fullop is None else fullop + op fullop = op if fullop is None else fullop + op
return None, fullop return None, fullop
from .operator import _ConstCollector from .simplify_for_const import ConstCollector
cc = _ConstCollector() cc = ConstCollector()
fullop = None fullop = None
for tf, to, n in zip(f, o, self._neg): for tf, to, n in zip(f, o, self._neg):
cc.add(tf, to.target) cc.add(tf, to.target)
......
...@@ -22,18 +22,18 @@ from ..common import setup_function, teardown_function ...@@ -22,18 +22,18 @@ from ..common import setup_function, teardown_function
def test_simplification(): def test_simplification():
from nifty6.operators.operator import _ConstantOperator from nifty6.operators.simplify_for_const import ConstantOperator
f1 = ift.Field.full(ift.RGSpace(10), 2.) f1 = ift.Field.full(ift.RGSpace(10), 2.)
op = ift.FFTOperator(f1.domain) op = ift.FFTOperator(f1.domain)
_, op2 = op.simplify_for_constant_input(f1) _, op2 = op.simplify_for_constant_input(f1)
assert_equal(isinstance(op2, _ConstantOperator), True) assert_equal(isinstance(op2, ConstantOperator), True)
assert_allclose(op(f1).val, op2(f1).val) assert_allclose(op(f1).val, op2(f1).val)
dom = {"a": ift.RGSpace(10)} dom = {"a": ift.RGSpace(10)}
f1 = ift.full(dom, 2.) f1 = ift.full(dom, 2.)
op = ift.FFTOperator(f1.domain["a"]).ducktape("a") op = ift.FFTOperator(f1.domain["a"]).ducktape("a")
_, op2 = op.simplify_for_constant_input(f1) _, op2 = op.simplify_for_constant_input(f1)
assert_equal(isinstance(op2, _ConstantOperator), True) assert_equal(isinstance(op2, ConstantOperator), True)
assert_allclose(op(f1).val, op2(f1).val) assert_allclose(op(f1).val, op2(f1).val)
dom = {"a": ift.RGSpace(10), "b": ift.RGSpace(5)} dom = {"a": ift.RGSpace(10), "b": ift.RGSpace(5)}
...@@ -45,7 +45,7 @@ def test_simplification(): ...@@ -45,7 +45,7 @@ def test_simplification():
op = (o1.ducktape("a").ducktape_left("a") + op = (o1.ducktape("a").ducktape_left("a") +
o2.ducktape("b").ducktape_left("b")) o2.ducktape("b").ducktape_left("b"))
_, op2 = op.simplify_for_constant_input(f2) _, op2 = op.simplify_for_constant_input(f2)
assert_equal(isinstance(op2._op1, _ConstantOperator), True) assert_equal(isinstance(op2._op1, ConstantOperator), True)
assert_allclose(op(f1)["a"].val, op2(f1)["a"].val) assert_allclose(op(f1)["a"].val, op2(f1)["a"].val)
assert_allclose(op(f1)["b"].val, op2(f1)["b"].val) assert_allclose(op(f1)["b"].val, op2(f1)["b"].val)
lin = ift.Linearization.make_var(ift.MultiField.full(op2.domain, 2.), True) lin = ift.Linearization.make_var(ift.MultiField.full(op2.domain, 2.), True)
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment