linearization.py 3.46 KB
 Martin Reinecke committed Jul 26, 2018 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 ``````from __future__ import absolute_import, division, print_function import numpy as np from .compat import * from .field import Field from .multi.multi_field import MultiField from .sugar import makeOp class Linearization(object): def __init__(self, val, jac, metric=None): self._val = val self._jac = jac self._metric = metric @property def domain(self): return self._jac.domain @property def target(self): return self._jac.target @property def val(self): return self._val @property def jac(self): return self._jac @property def metric(self): return self._metric def __neg__(self): return Linearization(-self._val, self._jac*(-1), None if self._metric is None else self._metric*(-1)) def __add__(self, other): if isinstance(other, Linearization): from .operators.relaxed_sum_operator import RelaxedSumOperator met = None if self._metric is not None and other._metric is not None: met = RelaxedSumOperator((self._metric, other._metric)) return Linearization( self._val.unite(other._val), RelaxedSumOperator((self._jac, other._jac)), met) if isinstance(other, (int, float, complex, Field, MultiField)): return Linearization(self._val+other, self._jac, self._metric) def __radd__(self, other): return self.__add__(other) def __sub__(self, other): return self.__add__(-other) def __rsub__(self, other): return (-self).__add__(other) def __mul__(self, other): from .sugar import makeOp if isinstance(other, Linearization): d1 = makeOp(self._val) d2 = makeOp(other._val) return Linearization(self._val*other._val, d2*self._jac + d1*other._jac) if isinstance(other, (int, float, complex)): # if other == 0: # return ... return Linearization(self._val*other, self._jac*other) if isinstance(other, (Field, MultiField)): d2 = makeOp(other) return Linearization(self._val*other, self._jac*d2) raise TypeError def __rmul__(self, other): from .sugar import makeOp if isinstance(other, (int, float, complex)): return Linearization(self._val*other, self._jac*other) if isinstance(other, (Field, MultiField)): d1 = makeOp(other) return Linearization(self._val*other, d1*self._jac) def sum(self): from .sugar import full from .operators.vdot_operator import VdotOperator return Linearization(full((),self._val.sum()), VdotOperator(full(self._jac.target,1))*self._jac) def exp(self): tmp = self._val.exp() return Linearization(tmp, makeOp(tmp)*self._jac) def log(self): tmp = self._val.log() return Linearization(tmp, makeOp(1./self._val)*self._jac) def add_metric(self, metric): return Linearization(self._val, self._jac, metric) @staticmethod def make_var(field): from .operators.scaling_operator import ScalingOperator return Linearization(field, ScalingOperator(1., field.domain)) @staticmethod def make_const(field): from .operators.null_operator import NullOperator return Linearization(field, NullOperator({}, field.domain)) ``````