Commit ca8ec3cc authored by Philipp Arras's avatar Philipp Arras

Implement proper constant support 3/n

parent 2ffc1226
......@@ -184,7 +184,8 @@ class MetricGaussianKL(Energy):
_, ham_sampling = hamiltonian.simplify_for_constant_input(cstpos)
else:
ham_sampling = hamiltonian
met = ham_sampling(Linearization.make_var(mean.extract(ham_sampling.domain), True)).metric
lin = Linearization.make_var(mean.extract(ham_sampling.domain), True)
met = ham_sampling(lin).metric
if napprox >= 1:
met._approximation = makeOp(approximation2endo(met, napprox))
local_samples = []
......
......@@ -484,11 +484,9 @@ class StandardHamiltonian(EnergyOperator):
`<https://arxiv.org/abs/1812.04403>`_
"""
def __init__(self, lh, ic_samp=None, _c_inp=None, prior_dtype=np.float64):
def __init__(self, lh, ic_samp=None, prior_dtype=np.float64):
self._lh = lh
self._prior = GaussianEnergy(domain=lh.domain, sampling_dtype=prior_dtype)
if _c_inp is not None:
_, self._prior = self._prior.simplify_for_constant_input(_c_inp)
self._ic_samp = ic_samp
self._domain = lh.domain
......@@ -504,9 +502,9 @@ class StandardHamiltonian(EnergyOperator):
subs += '\nPrior:\n{}'.format(self._prior)
return 'StandardHamiltonian:\n' + utilities.indent(subs)
# def _simplify_for_constant_input_nontrivial(self, c_inp):
# out, lh1 = self._lh.simplify_for_constant_input(c_inp)
# return out, StandardHamiltonian(lh1, self._ic_samp, _c_inp=c_inp)
def _simplify_for_constant_input_nontrivial(self, c_inp):
out, lh1 = self._lh.simplify_for_constant_input(c_inp)
return out, StandardHamiltonian(lh1, self._ic_samp)
class AveragedEnergy(EnergyOperator):
......
......@@ -285,7 +285,7 @@ class Operator(metaclass=NiftyMeta):
# subdomain of self._domain
if isinstance(self.domain, MultiDomain):
assert isinstance(dom, MultiDomain)
if set(c_inp.keys()) > set(self.domain.keys()):
if not set(c_inp.keys()) <= set(self.domain.keys()):
raise ValueError
if dom is self.domain:
......
......@@ -15,8 +15,6 @@
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.
import numpy as np
from .diagonal_operator import DiagonalOperator
from .endomorphic_operator import EndomorphicOperator
from .linear_operator import LinearOperator
......@@ -56,11 +54,15 @@ class SandwichOperator(EndomorphicOperator):
old_cheese = cheese
cheese = old_cheese._cheese
bun = old_cheese._bun @ bun
if not isinstance(bun, LinearOperator):
raise TypeError("bun must be a linear operator")
if isinstance(bun, ScalingOperator):
return cheese.scale(bun._factor**2)
if cheese is not None and not isinstance(cheese, LinearOperator):
raise TypeError("cheese must be a linear operator or None")
if cheese is None:
# FIXME Sampling dtype not clear in this case
cheese = ScalingOperator(bun.target, 1.)
op = bun.adjoint(bun)
else:
......
......@@ -334,9 +334,9 @@ class NullOperator(LinearOperator):
@staticmethod
def _nullfield(dom):
if isinstance(dom, DomainTuple):
return Field(dom, 0)
return Field(dom, 0.)
else:
return MultiField.full(dom, 0)
return MultiField.full(dom, 0.)
def apply(self, x, mode):
self._check_input(x, mode)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment