Commit da08f07e authored by Martin Reinecke's avatar Martin Reinecke

test for partial inference

parent 769812bf
......@@ -150,4 +150,5 @@ class Linearization(object):
@staticmethod
def make_const(field):
from .operators.simple_linear_operators import NullOperator
return Linearization(field, NullOperator({}, field.domain))
from .operators.scaling_operator import ScalingOperator
return Linearization(field, ScalingOperator(0., field.domain))
......@@ -3,23 +3,36 @@ from __future__ import absolute_import, division, print_function
from ..compat import *
from ..minimization.energy import Energy
from ..linearization import Linearization
from ..multi_field import MultiField
import numpy as np
class EnergyAdapter(Energy):
def __init__(self, position, op, controller=None, preconditioner=None):
def __init__(self, position, op, controller=None, preconditioner=None,
constants=[]):
super(EnergyAdapter, self).__init__(position)
self._op = op
self._val = self._grad = self._metric = None
self._controller = controller
self._preconditioner = preconditioner
self._constants = constants
def at(self, position):
return EnergyAdapter(position, self._op, self._controller,
self._preconditioner)
self._preconditioner, self._constants)
def _fill_all(self):
tmp = self._op(Linearization.make_var(self._position))
if len(self._constants) == 0:
tmp = self._op(Linearization.make_var(self._position))
else:
ctmp = MultiField.from_dict({key: val
for key, val in self._position.items()
if key in self._constants})
vtmp = MultiField.from_dict({key: val
for key, val in self._position.items()
if key not in self._constants})
lin = Linearization.make_var(vtmp) + Linearization.make_const(ctmp)
tmp = self._op(lin)
self._val = tmp.val.local_data[()]
self._grad = tmp.gradient
if self._controller is not None:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment