Commit b3d458af authored by Philipp Arras's avatar Philipp Arras
Browse files

Add simplify for GaussianEnergy

parent ba7cd26a
......@@ -223,9 +223,11 @@ class GaussianEnergy(EnergyOperator):
if inverse_covariance is None:
self._op = Squared2NormOperator(self._domain).scale(0.5)
self._met = ScalingOperator(self._domain, 1)
self._trivial_invcov = True
else:
self._op = QuadraticFormOperator(inverse_covariance)
self._met = inverse_covariance
self._trivial_invcov = False
if sampling_dtype is not None:
self._met = SamplingDtypeSetter(self._met, sampling_dtype)
......@@ -245,6 +247,39 @@ class GaussianEnergy(EnergyOperator):
return res.add_metric(self._met)
return res
def _simplify_for_constant_input_nontrivial(self, c_inp):
from .simplify_for_const import ConstantOperator
from ..multi_domain import MultiDomain
if not self._trivial_invcov:
raise NotImplementedError # FIXME
# No need to implement support for DomainTuple since this done by
# Operator.simplify_for_constant_input()
assert isinstance(self.domain, MultiDomain)
c_dom = {}
var_dom = {}
not_touched_dom = {}
for kk in self._domain.keys():
if kk in c_inp.domain.keys():
c_dom[kk] = self._domain[kk]
else:
var_dom[kk] = self._domain[kk]
for kk in set(c_inp.keys()) - set(self._domain.keys()):
not_touched_dom[kk] = c_inp.domain[kk]
var_dom = MultiDomain.make(var_dom)
c_dom = MultiDomain.make(c_dom)
not_touched_dom = MultiDomain.make(not_touched_dom)
c_mean = None if self._mean is None else self._mean.extract(c_dom)
var_mean = None if self._mean is None else self._mean.extract(var_dom)
c_op = ConstantOperator(c_dom,
GaussianEnergy(c_mean, None, c_inp.domain)(c_inp))
var_op = GaussianEnergy(var_mean, None, var_dom) #@ rest
newop = var_op + c_op
return c_inp.extract_part(not_touched_dom), newop
def __repr__(self):
dom = '()' if isinstance(self.domain, DomainTuple) else self.domain.keys()
return f'GaussianEnergy {dom}'
......
......@@ -349,6 +349,11 @@ class NullOperator(LinearOperator):
self._check_input(x, mode)
return self._nullfield(self._tgt(mode))
def __repr__(self):
dom = self.domain.keys() if isinstance(self.domain, MultiDomain) else '()'
tgt = self.target.keys() if isinstance(self.target, MultiDomain) else '()'
return f'{tgt} <- NullOperator <- {dom}'
class PartialExtractor(LinearOperator):
def __init__(self, domain, target):
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment