Commit a2ccbfa2 authored by Philipp Arras's avatar Philipp Arras
Browse files

_KeyModifier -> PrependKey

parent 1520a2fa
......@@ -526,22 +526,3 @@ def _tableentries(redchisq, scmean, ndof, keylen):
out += f"{ndof[kk]:>11}"
out += "\n"
return out[:-1]
class _KeyModifier(LinearOperator):
def __init__(self, domain, pre):
if not isinstance(domain, MultiDomain):
raise ValueError
from .sugar import makeDomain
self._domain = makeDomain(domain)
self._pre = str(pre)
target = {self._pre+k: domain[k] for k in domain.keys()}
self._target = makeDomain(MultiDomain.make(target))
self._capability = self.TIMES | self.ADJOINT_TIMES
def apply(self, x, mode):
self._check_input(x, mode)
if mode == self.TIMES:
res = {self._pre+k:x[k] for k in self._domain.keys()}
else:
res = {k:x[self._pre+k] for k in self._domain.keys()}
return MultiField.from_dict(res, domain=self._tgt(mode))
......@@ -519,17 +519,17 @@ class _OpSum(Operator):
return res
def get_transformation(self):
from .simple_linear_operators import PrependKey
tr1 = self._op1.get_transformation()
tr2 = self._op2.get_transformation()
if tr1 is None or tr2 is None:
return None
from ..extra import _KeyModifier
dtype, trafo = {}, None
for i, lh in enumerate([self._op1, self._op2]):
dtp, tr = lh.get_transformation()
if isinstance(tr.target, MultiDomain):
dtype.update({str(i)+d:dtp[d] for d in dtp.keys()})
tr = _KeyModifier(tr.target, str(i)) @ tr
dtype.update({str(i)+d: dtp[d] for d in dtp.keys()})
tr = PrependKey(tr.target, str(i)) @ tr
trafo = tr if trafo is None else trafo+tr
else:
dtype[str(i)] = dtp
......
......@@ -380,3 +380,30 @@ class PartialExtractor(LinearOperator):
def __repr__(self):
return f'{self.target.keys()} <- {self.domain.keys()}'
class PrependKey(LinearOperator):
"""Prepend a string to all keys of a MultiDomain.
Parameters
----------
domain : MultiDomain
pre : str
"""
def __init__(self, domain, pre):
if not isinstance(domain, MultiDomain):
raise ValueError
from ..sugar import makeDomain
self._domain = makeDomain(domain)
self._pre = str(pre)
target = {self._pre+k: domain[k] for k in domain.keys()}
self._target = makeDomain(MultiDomain.make(target))
self._capability = self.TIMES | self.ADJOINT_TIMES
def apply(self, x, mode):
self._check_input(x, mode)
if mode == self.TIMES:
res = {self._pre+k:x[k] for k in self._domain.keys()}
else:
res = {k:x[self._pre+k] for k in self._domain.keys()}
return MultiField.from_dict(res, domain=self._tgt(mode))
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment