Commit 82fde86c authored by Philipp Arras's avatar Philipp Arras
Browse files

Cleanup not working code

parent 90076c0f
Pipeline #75661 passed with stages
in 13 minutes and 3 seconds
......@@ -247,39 +247,6 @@ class GaussianEnergy(EnergyOperator):
return res.add_metric(self._met)
return res
def _simplify_for_constant_input_nontrivial(self, c_inp):
from .simplify_for_const import ConstantOperator
from ..multi_domain import MultiDomain
if not self._trivial_invcov:
raise NotImplementedError # FIXME
# No need to implement support for DomainTuple since this done by
# Operator.simplify_for_constant_input()
assert isinstance(self.domain, MultiDomain)
c_dom = {}
var_dom = {}
not_touched_dom = {}
for kk in self._domain.keys():
if kk in c_inp.domain.keys():
c_dom[kk] = self._domain[kk]
else:
var_dom[kk] = self._domain[kk]
for kk in set(c_inp.keys()) - set(self._domain.keys()):
not_touched_dom[kk] = c_inp.domain[kk]
var_dom = MultiDomain.make(var_dom)
c_dom = MultiDomain.make(c_dom)
not_touched_dom = MultiDomain.make(not_touched_dom)
c_mean = None if self._mean is None else self._mean.extract(c_dom)
var_mean = None if self._mean is None else self._mean.extract(var_dom)
c_op = ConstantOperator(c_dom,
GaussianEnergy(c_mean, None, c_inp.domain)(c_inp))
var_op = GaussianEnergy(var_mean, None, var_dom) #@ rest
newop = var_op + c_op
return c_inp.extract_part(not_touched_dom), newop
def __repr__(self):
dom = '()' if isinstance(self.domain, DomainTuple) else self.domain.keys()
return f'GaussianEnergy {dom}'
......
......@@ -18,6 +18,7 @@
import numpy as np
from .. import pointwise
from ..logger import logger
from ..multi_domain import MultiDomain
from ..utilities import NiftyMeta, indent
......@@ -274,8 +275,12 @@ class Operator(metaclass=NiftyMeta):
from .simplify_for_const import ConstantEnergyOperator, ConstantOperator
if c_inp is None:
return None, self
if isinstance(self.domain, MultiDomain):
assert isinstance(c_inp.domain, MultiDomain)
if set(c_inp.keys()) > set(self.domain.keys()):
raise ValueError
if c_inp.domain is self.domain:
if isinstance(self, EnergyOperator):
op = ConstantEnergyOperator(self.domain, self(c_inp))
......@@ -283,34 +288,17 @@ class Operator(metaclass=NiftyMeta):
op = ConstantOperator(self.domain, self(c_inp))
op = ConstantOperator(self.domain, self(c_inp))
return op(c_inp), op
if isinstance(self.domain, MultiDomain) and \
set(c_inp.keys()) > set(self.domain.keys()):
raise NotImplementedError('This branch is not tested yet')
op = ConstantOperator(self.domain, self.force(c_inp))
from ..sugar import makeField
unaffected = makeField({kk: vv for kk, vv in c_inp.items() if kk not in self.domain})
for kk in unaffected:
assert kk not in self.domain
assert kk not in self.target
return op.force(c_inp), op
if not isinstance(c_inp.domain, MultiDomain):
raise RuntimeError
return self._simplify_for_constant_input_nontrivial(c_inp)
def _simplify_for_constant_input_nontrivial(self, c_inp):
from .simplify_for_const import SlowPartialConstantOperator
from ..multi_field import MultiField
try:
c_out = self.force(c_inp)
except KeyError:
c_out = None
if isinstance(c_out, MultiField):
dct = {}
for kk in set(c_inp.keys()) - set(self.domain.keys()):
if isinstance(self.target, MultiDomain) and kk in self.target.keys():
raise NotImplementedError
dct[kk] = c_inp[kk]
c_out = c_out.unite(MultiField.from_dict(dct))
return c_out, self @ SlowPartialConstantOperator(self.domain, c_inp.keys())
s = ('SlowPartialConstantOperator used. You might want to consider',
' implementing `_simplify_for_constant_input_nontrivial()` for',
' this operator.')
logger.warning(s)
return None, self @ SlowPartialConstantOperator(self.domain, c_inp.keys())
def ptw(self, op, *args, **kwargs):
return _OpChain.make((_FunctionApplier(self.target, op, *args, **kwargs), self))
......
......@@ -16,7 +16,6 @@
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.
from ..multi_domain import MultiDomain
from ..multi_field import MultiField
from .block_diagonal_operator import BlockDiagonalOperator
from .energy_operators import EnergyOperator
from .operator import Operator
......@@ -90,21 +89,17 @@ class SlowPartialConstantOperator(Operator):
from ..sugar import makeDomain
if not isinstance(domain, MultiDomain):
raise TypeError
self._keys = set(constant_keys) & set(domain.keys())
if len(self._keys) == 0:
if set(constant_keys) > set(domain.keys()) or len(constant_keys) == 0:
raise ValueError
self._keys = set(constant_keys) & set(domain.keys())
self._domain = self._target = makeDomain(domain)
def apply(self, x):
self._check_input(x)
if x.jac is None:
return x
jac = {}
for kk, dd in self._domain.items():
fac = 1
if kk in self._keys:
fac = 0
jac[kk] = ScalingOperator(dd, fac)
jac = {kk: ScalingOperator(dd, 0 if kk in self._keys else 1)
for kk, dd in self._domain.items()}
return x.prepend_jac(BlockDiagonalOperator(x.jac.domain, jac))
def __repr__(self):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment