Commit e3328be0 authored by Philipp Arras's avatar Philipp Arras
Browse files

Implement proper constant support 7/n

parent 092bf7fd
...@@ -275,6 +275,7 @@ class Operator(metaclass=NiftyMeta): ...@@ -275,6 +275,7 @@ class Operator(metaclass=NiftyMeta):
from .simplify_for_const import ConstantEnergyOperator, ConstantOperator from .simplify_for_const import ConstantEnergyOperator, ConstantOperator
from ..multi_field import MultiField from ..multi_field import MultiField
from ..domain_tuple import DomainTuple from ..domain_tuple import DomainTuple
from ..sugar import makeDomain
if c_inp is None or (isinstance(c_inp, MultiField) and len(c_inp.keys()) == 0): if c_inp is None or (isinstance(c_inp, MultiField) and len(c_inp.keys()) == 0):
return None, self return None, self
dom = c_inp.domain dom = c_inp.domain
...@@ -298,7 +299,17 @@ class Operator(metaclass=NiftyMeta): ...@@ -298,7 +299,17 @@ class Operator(metaclass=NiftyMeta):
return None, op return None, op
if not isinstance(dom, MultiDomain): if not isinstance(dom, MultiDomain):
raise RuntimeError raise RuntimeError
return self._simplify_for_constant_input_nontrivial(c_inp) c_out, op = self._simplify_for_constant_input_nontrivial(c_inp)
vardom = makeDomain({kk: vv for kk, vv in self.domain.items()
if kk not in c_inp.keys()})
assert op.domain is vardom
assert op.target is self.target
assert isinstance(op, Operator)
if c_out is not None:
assert isinstance(c_out, MultiField)
assert len(set(c_out.keys()) & self.domain.keys()) == 0
assert set(c_out.keys()) <= set(c_inp.keys())
return c_out, op
def _simplify_for_constant_input_nontrivial(self, c_inp): def _simplify_for_constant_input_nontrivial(self, c_inp):
from .simplify_for_const import InsertionOperator from .simplify_for_const import InsertionOperator
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment