Commit a54c890a by Martin Reinecke

### fix partial inference with Newton minimizers

parent c749042e
 ... @@ -4,6 +4,8 @@ from ..compat import * ... @@ -4,6 +4,8 @@ from ..compat import * from ..minimization.energy import Energy from ..minimization.energy import Energy from ..linearization import Linearization from ..linearization import Linearization from ..multi_field import MultiField from ..multi_field import MultiField from ..operators.scaling_operator import ScalingOperator from ..operators.block_diagonal_operator import BlockDiagonalOperator import numpy as np import numpy as np ... @@ -25,14 +27,10 @@ class EnergyAdapter(Energy): ... @@ -25,14 +27,10 @@ class EnergyAdapter(Energy): if len(self._constants) == 0: if len(self._constants) == 0: tmp = self._op(Linearization.make_var(self._position)) tmp = self._op(Linearization.make_var(self._position)) else: else: ctmp = MultiField.from_dict({key: val ops = [ScalingOperator(0. if key in self._constants else 1., dom) for key, val in self._position.items() for key, dom in self._position.domain.items()] if key in self._constants}) bdop = BlockDiagonalOperator(self._position.domain, tuple(ops)) vtmp = MultiField.from_dict({key: val tmp = self._op(Linearization(self._position, bdop)) for key, val in self._position.items() if key not in self._constants}) lin = Linearization.make_var(vtmp) + Linearization.make_const(ctmp) tmp = self._op(lin) self._val = tmp.val.local_data[()] self._val = tmp.val.local_data[()] self._grad = tmp.gradient self._grad = tmp.gradient if self._controller is not None: if self._controller is not None: ... ...
 ... @@ -159,8 +159,8 @@ class DiagonalOperator(EndomorphicOperator): ... @@ -159,8 +159,8 @@ class DiagonalOperator(EndomorphicOperator): def process_sample(self, samp, from_inverse): def process_sample(self, samp, from_inverse): if (self._complex or (self._diagmin < 0.) or if (self._complex or (self._diagmin < 0.) or (self._diagmin == 0. and from_inverse)): (self._diagmin == 0. and from_inverse)): raise ValueError("operator not positive definite") raise ValueError("operator not positive definite") if from_inverse: if from_inverse: res = samp.local_data/np.sqrt(self._ldiag) res = samp.local_data/np.sqrt(self._ldiag) else: else: ... ...
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!