Commit a54c890a authored by Martin Reinecke's avatar Martin Reinecke

fix partial inference with Newton minimizers

parent c749042e
...@@ -4,6 +4,8 @@ from ..compat import * ...@@ -4,6 +4,8 @@ from ..compat import *
from ..minimization.energy import Energy from ..minimization.energy import Energy
from ..linearization import Linearization from ..linearization import Linearization
from ..multi_field import MultiField from ..multi_field import MultiField
from ..operators.scaling_operator import ScalingOperator
from ..operators.block_diagonal_operator import BlockDiagonalOperator
import numpy as np import numpy as np
...@@ -25,14 +27,10 @@ class EnergyAdapter(Energy): ...@@ -25,14 +27,10 @@ class EnergyAdapter(Energy):
if len(self._constants) == 0: if len(self._constants) == 0:
tmp = self._op(Linearization.make_var(self._position)) tmp = self._op(Linearization.make_var(self._position))
else: else:
ctmp = MultiField.from_dict({key: val ops = [ScalingOperator(0. if key in self._constants else 1., dom)
for key, val in self._position.items() for key, dom in self._position.domain.items()]
if key in self._constants}) bdop = BlockDiagonalOperator(self._position.domain, tuple(ops))
vtmp = MultiField.from_dict({key: val tmp = self._op(Linearization(self._position, bdop))
for key, val in self._position.items()
if key not in self._constants})
lin = Linearization.make_var(vtmp) + Linearization.make_const(ctmp)
tmp = self._op(lin)
self._val = tmp.val.local_data[()] self._val = tmp.val.local_data[()]
self._grad = tmp.gradient self._grad = tmp.gradient
if self._controller is not None: if self._controller is not None:
......
...@@ -159,8 +159,8 @@ class DiagonalOperator(EndomorphicOperator): ...@@ -159,8 +159,8 @@ class DiagonalOperator(EndomorphicOperator):
def process_sample(self, samp, from_inverse): def process_sample(self, samp, from_inverse):
if (self._complex or (self._diagmin < 0.) or if (self._complex or (self._diagmin < 0.) or
(self._diagmin == 0. and from_inverse)): (self._diagmin == 0. and from_inverse)):
raise ValueError("operator not positive definite") raise ValueError("operator not positive definite")
if from_inverse: if from_inverse:
res = samp.local_data/np.sqrt(self._ldiag) res = samp.local_data/np.sqrt(self._ldiag)
else: else:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment