Commit c73b4467 authored by Philipp Arras's avatar Philipp Arras
Browse files

Add SlowPartialConstOperator

parent 5d607ffc
......@@ -245,6 +245,10 @@ class GaussianEnergy(EnergyOperator):
return res.add_metric(self._met)
return res
def __repr__(self):
dom = '()' if isinstance(self.domain, DomainTuple) else self.domain.keys()
return f'GaussianEnergy {dom}'
class PoissonianEnergy(EnergyOperator):
"""Computes likelihood Hamiltonians of expected count field constrained by
......
......@@ -297,7 +297,20 @@ class Operator(metaclass=NiftyMeta):
def _simplify_for_constant_input_nontrivial(self, c_inp):
from .simplify_for_const import SlowPartialConstantOperator
return None, SlowPartialConstantOperator(self, c_inp)
from ..multi_field import MultiField
try:
c_out = self.force(c_inp)
except KeyError:
c_out = None
if isinstance(c_out, MultiField):
dct = {}
for kk in set(c_inp.keys()) - set(self.domain.keys()):
if isinstance(self.target, MultiDomain) and kk in self.target.keys():
raise NotImplementedError
dct[kk] = c_inp[kk]
c_out = c_out.unite(MultiField.from_dict(dct))
return c_out, self @ SlowPartialConstantOperator(self.domain, c_inp.keys())
def ptw(self, op, *args, **kwargs):
return _OpChain.make((_FunctionApplier(self.target, op, *args, **kwargs), self))
......
......@@ -16,8 +16,11 @@
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.
from ..multi_domain import MultiDomain
from ..multi_field import MultiField
from .block_diagonal_operator import BlockDiagonalOperator
from .energy_operators import EnergyOperator
from .operator import Operator
from .scaling_operator import ScalingOperator
from .simple_linear_operators import NullOperator
......@@ -82,10 +85,30 @@ class ConstantOperator(Operator):
return f'{tgt} <- ConstantOperator <- {dom}'
class SlowPartialConstOperator(Operator):
pass
class SlowPartialConstantOperator(Operator):
def __init__(self, domain, constant_keys):
from ..sugar import makeDomain
if not isinstance(domain, MultiDomain):
raise TypeError
self._keys = set(constant_keys) & set(domain.keys())
if len(self._keys) == 0:
raise ValueError
self._domain = self._target = makeDomain(domain)
def apply(self, x):
self._check_input(x)
if x.jac is None:
return x
jac = {}
for kk, dd in self._domain.items():
fac = 1
if kk in self._keys:
fac = 0
jac[kk] = ScalingOperator(dd, fac)
return x.prepend_jac(BlockDiagonalOperator(x.jac.domain, jac))
def __repr__(self):
return f'SlowPartialConstantOperator ({self._keys})'
class ConstantEnergyOperator(EnergyOperator):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment