Commit ca8ec3cc by Philipp Arras

### Implement proper constant support 3/n

parent 2ffc1226
 ... @@ -184,7 +184,8 @@ class MetricGaussianKL(Energy): ... @@ -184,7 +184,8 @@ class MetricGaussianKL(Energy): _, ham_sampling = hamiltonian.simplify_for_constant_input(cstpos) _, ham_sampling = hamiltonian.simplify_for_constant_input(cstpos) else: else: ham_sampling = hamiltonian ham_sampling = hamiltonian met = ham_sampling(Linearization.make_var(mean.extract(ham_sampling.domain), True)).metric lin = Linearization.make_var(mean.extract(ham_sampling.domain), True) met = ham_sampling(lin).metric if napprox >= 1: if napprox >= 1: met._approximation = makeOp(approximation2endo(met, napprox)) met._approximation = makeOp(approximation2endo(met, napprox)) local_samples = [] local_samples = [] ... ...
 ... @@ -484,11 +484,9 @@ class StandardHamiltonian(EnergyOperator): ... @@ -484,11 +484,9 @@ class StandardHamiltonian(EnergyOperator): ``_ ``_ """ """ def __init__(self, lh, ic_samp=None, _c_inp=None, prior_dtype=np.float64): def __init__(self, lh, ic_samp=None, prior_dtype=np.float64): self._lh = lh self._lh = lh self._prior = GaussianEnergy(domain=lh.domain, sampling_dtype=prior_dtype) self._prior = GaussianEnergy(domain=lh.domain, sampling_dtype=prior_dtype) if _c_inp is not None: _, self._prior = self._prior.simplify_for_constant_input(_c_inp) self._ic_samp = ic_samp self._ic_samp = ic_samp self._domain = lh.domain self._domain = lh.domain ... @@ -504,9 +502,9 @@ class StandardHamiltonian(EnergyOperator): ... @@ -504,9 +502,9 @@ class StandardHamiltonian(EnergyOperator): subs += '\nPrior:\n{}'.format(self._prior) subs += '\nPrior:\n{}'.format(self._prior) return 'StandardHamiltonian:\n' + utilities.indent(subs) return 'StandardHamiltonian:\n' + utilities.indent(subs) # def _simplify_for_constant_input_nontrivial(self, c_inp): def _simplify_for_constant_input_nontrivial(self, c_inp): # out, lh1 = self._lh.simplify_for_constant_input(c_inp) out, lh1 = self._lh.simplify_for_constant_input(c_inp) # return out, StandardHamiltonian(lh1, self._ic_samp, _c_inp=c_inp) return out, StandardHamiltonian(lh1, self._ic_samp) class AveragedEnergy(EnergyOperator): class AveragedEnergy(EnergyOperator): ... ...
 ... @@ -285,7 +285,7 @@ class Operator(metaclass=NiftyMeta): ... @@ -285,7 +285,7 @@ class Operator(metaclass=NiftyMeta): # subdomain of self._domain # subdomain of self._domain if isinstance(self.domain, MultiDomain): if isinstance(self.domain, MultiDomain): assert isinstance(dom, MultiDomain) assert isinstance(dom, MultiDomain) if set(c_inp.keys()) > set(self.domain.keys()): if not set(c_inp.keys()) <= set(self.domain.keys()): raise ValueError raise ValueError if dom is self.domain: if dom is self.domain: ... ...
 ... @@ -15,8 +15,6 @@ ... @@ -15,8 +15,6 @@ # # # NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik. # NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik. import numpy as np from .diagonal_operator import DiagonalOperator from .diagonal_operator import DiagonalOperator from .endomorphic_operator import EndomorphicOperator from .endomorphic_operator import EndomorphicOperator from .linear_operator import LinearOperator from .linear_operator import LinearOperator ... @@ -56,11 +54,15 @@ class SandwichOperator(EndomorphicOperator): ... @@ -56,11 +54,15 @@ class SandwichOperator(EndomorphicOperator): old_cheese = cheese old_cheese = cheese cheese = old_cheese._cheese cheese = old_cheese._cheese bun = old_cheese._bun @ bun bun = old_cheese._bun @ bun if not isinstance(bun, LinearOperator): if not isinstance(bun, LinearOperator): raise TypeError("bun must be a linear operator") raise TypeError("bun must be a linear operator") if isinstance(bun, ScalingOperator): return cheese.scale(bun._factor**2) if cheese is not None and not isinstance(cheese, LinearOperator): if cheese is not None and not isinstance(cheese, LinearOperator): raise TypeError("cheese must be a linear operator or None") raise TypeError("cheese must be a linear operator or None") if cheese is None: if cheese is None: # FIXME Sampling dtype not clear in this case cheese = ScalingOperator(bun.target, 1.) cheese = ScalingOperator(bun.target, 1.) op = bun.adjoint(bun) op = bun.adjoint(bun) else: else: ... ...
 ... @@ -334,9 +334,9 @@ class NullOperator(LinearOperator): ... @@ -334,9 +334,9 @@ class NullOperator(LinearOperator): @staticmethod @staticmethod def _nullfield(dom): def _nullfield(dom): if isinstance(dom, DomainTuple): if isinstance(dom, DomainTuple): return Field(dom, 0) return Field(dom, 0.) else: else: return MultiField.full(dom, 0) return MultiField.full(dom, 0.) def apply(self, x, mode): def apply(self, x, mode): self._check_input(x, mode) self._check_input(x, mode) ... ...
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment