Commit a2ccbfa2 authored by Philipp Arras's avatar Philipp Arras
Browse files

_KeyModifier -> PrependKey

parent 1520a2fa
...@@ -526,22 +526,3 @@ def _tableentries(redchisq, scmean, ndof, keylen): ...@@ -526,22 +526,3 @@ def _tableentries(redchisq, scmean, ndof, keylen):
out += f"{ndof[kk]:>11}" out += f"{ndof[kk]:>11}"
out += "\n" out += "\n"
return out[:-1] return out[:-1]
class _KeyModifier(LinearOperator):
def __init__(self, domain, pre):
if not isinstance(domain, MultiDomain):
raise ValueError
from .sugar import makeDomain
self._domain = makeDomain(domain)
self._pre = str(pre)
target = {self._pre+k: domain[k] for k in domain.keys()}
self._target = makeDomain(MultiDomain.make(target))
self._capability = self.TIMES | self.ADJOINT_TIMES
def apply(self, x, mode):
self._check_input(x, mode)
if mode == self.TIMES:
res = {self._pre+k:x[k] for k in self._domain.keys()}
else:
res = {k:x[self._pre+k] for k in self._domain.keys()}
return MultiField.from_dict(res, domain=self._tgt(mode))
...@@ -519,17 +519,17 @@ class _OpSum(Operator): ...@@ -519,17 +519,17 @@ class _OpSum(Operator):
return res return res
def get_transformation(self): def get_transformation(self):
from .simple_linear_operators import PrependKey
tr1 = self._op1.get_transformation() tr1 = self._op1.get_transformation()
tr2 = self._op2.get_transformation() tr2 = self._op2.get_transformation()
if tr1 is None or tr2 is None: if tr1 is None or tr2 is None:
return None return None
from ..extra import _KeyModifier
dtype, trafo = {}, None dtype, trafo = {}, None
for i, lh in enumerate([self._op1, self._op2]): for i, lh in enumerate([self._op1, self._op2]):
dtp, tr = lh.get_transformation() dtp, tr = lh.get_transformation()
if isinstance(tr.target, MultiDomain): if isinstance(tr.target, MultiDomain):
dtype.update({str(i)+d:dtp[d] for d in dtp.keys()}) dtype.update({str(i)+d: dtp[d] for d in dtp.keys()})
tr = _KeyModifier(tr.target, str(i)) @ tr tr = PrependKey(tr.target, str(i)) @ tr
trafo = tr if trafo is None else trafo+tr trafo = tr if trafo is None else trafo+tr
else: else:
dtype[str(i)] = dtp dtype[str(i)] = dtp
......
...@@ -380,3 +380,30 @@ class PartialExtractor(LinearOperator): ...@@ -380,3 +380,30 @@ class PartialExtractor(LinearOperator):
def __repr__(self): def __repr__(self):
return f'{self.target.keys()} <- {self.domain.keys()}' return f'{self.target.keys()} <- {self.domain.keys()}'
class PrependKey(LinearOperator):
"""Prepend a string to all keys of a MultiDomain.
Parameters
----------
domain : MultiDomain
pre : str
"""
def __init__(self, domain, pre):
if not isinstance(domain, MultiDomain):
raise ValueError
from ..sugar import makeDomain
self._domain = makeDomain(domain)
self._pre = str(pre)
target = {self._pre+k: domain[k] for k in domain.keys()}
self._target = makeDomain(MultiDomain.make(target))
self._capability = self.TIMES | self.ADJOINT_TIMES
def apply(self, x, mode):
self._check_input(x, mode)
if mode == self.TIMES:
res = {self._pre+k:x[k] for k in self._domain.keys()}
else:
res = {k:x[self._pre+k] for k in self._domain.keys()}
return MultiField.from_dict(res, domain=self._tgt(mode))
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment