Commit ca8ec3cc by Philipp Arras

### Implement proper constant support 3/n

parent 2ffc1226
 ... ... @@ -184,7 +184,8 @@ class MetricGaussianKL(Energy): _, ham_sampling = hamiltonian.simplify_for_constant_input(cstpos) else: ham_sampling = hamiltonian met = ham_sampling(Linearization.make_var(mean.extract(ham_sampling.domain), True)).metric lin = Linearization.make_var(mean.extract(ham_sampling.domain), True) met = ham_sampling(lin).metric if napprox >= 1: met._approximation = makeOp(approximation2endo(met, napprox)) local_samples = [] ... ...
 ... ... @@ -484,11 +484,9 @@ class StandardHamiltonian(EnergyOperator): ``_ """ def __init__(self, lh, ic_samp=None, _c_inp=None, prior_dtype=np.float64): def __init__(self, lh, ic_samp=None, prior_dtype=np.float64): self._lh = lh self._prior = GaussianEnergy(domain=lh.domain, sampling_dtype=prior_dtype) if _c_inp is not None: _, self._prior = self._prior.simplify_for_constant_input(_c_inp) self._ic_samp = ic_samp self._domain = lh.domain ... ... @@ -504,9 +502,9 @@ class StandardHamiltonian(EnergyOperator): subs += '\nPrior:\n{}'.format(self._prior) return 'StandardHamiltonian:\n' + utilities.indent(subs) # def _simplify_for_constant_input_nontrivial(self, c_inp): # out, lh1 = self._lh.simplify_for_constant_input(c_inp) # return out, StandardHamiltonian(lh1, self._ic_samp, _c_inp=c_inp) def _simplify_for_constant_input_nontrivial(self, c_inp): out, lh1 = self._lh.simplify_for_constant_input(c_inp) return out, StandardHamiltonian(lh1, self._ic_samp) class AveragedEnergy(EnergyOperator): ... ...
 ... ... @@ -285,7 +285,7 @@ class Operator(metaclass=NiftyMeta): # subdomain of self._domain if isinstance(self.domain, MultiDomain): assert isinstance(dom, MultiDomain) if set(c_inp.keys()) > set(self.domain.keys()): if not set(c_inp.keys()) <= set(self.domain.keys()): raise ValueError if dom is self.domain: ... ...
 ... ... @@ -15,8 +15,6 @@ # # NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik. import numpy as np from .diagonal_operator import DiagonalOperator from .endomorphic_operator import EndomorphicOperator from .linear_operator import LinearOperator ... ... @@ -56,11 +54,15 @@ class SandwichOperator(EndomorphicOperator): old_cheese = cheese cheese = old_cheese._cheese bun = old_cheese._bun @ bun if not isinstance(bun, LinearOperator): raise TypeError("bun must be a linear operator") if isinstance(bun, ScalingOperator): return cheese.scale(bun._factor**2) if cheese is not None and not isinstance(cheese, LinearOperator): raise TypeError("cheese must be a linear operator or None") if cheese is None: # FIXME Sampling dtype not clear in this case cheese = ScalingOperator(bun.target, 1.) op = bun.adjoint(bun) else: ... ...
 ... ... @@ -334,9 +334,9 @@ class NullOperator(LinearOperator): @staticmethod def _nullfield(dom): if isinstance(dom, DomainTuple): return Field(dom, 0) return Field(dom, 0.) else: return MultiField.full(dom, 0) return MultiField.full(dom, 0.) def apply(self, x, mode): self._check_input(x, mode) ... ...
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!