# linearization.py (3.7 KB)
 Martin Reinecke committed Jul 26, 2018 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 ``````from __future__ import absolute_import, division, print_function import numpy as np from .compat import * from .field import Field from .multi.multi_field import MultiField from .sugar import makeOp class Linearization(object): def __init__(self, val, jac, metric=None): self._val = val self._jac = jac self._metric = metric @property def domain(self): return self._jac.domain @property def target(self): return self._jac.target @property def val(self): return self._val @property def jac(self): return self._jac `````` Martin Reinecke committed Jul 26, 2018 33 34 35 36 37 38 `````` @property def gradient(self): """Only available if target is a scalar""" from .sugar import full return self._jac.adjoint_times(full(self._jac.target, 1.)) `````` Martin Reinecke committed Jul 26, 2018 39 40 `````` @property def metric(self): `````` Martin Reinecke committed Jul 26, 2018 41 `````` """Only available if target is a scalar""" `````` Martin Reinecke committed Jul 26, 2018 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 `````` return self._metric def __neg__(self): return Linearization(-self._val, self._jac*(-1), None if self._metric is None else self._metric*(-1)) def __add__(self, other): if isinstance(other, Linearization): from .operators.relaxed_sum_operator import RelaxedSumOperator met = None if self._metric is not None and other._metric is not None: met = RelaxedSumOperator((self._metric, other._metric)) return Linearization( self._val.unite(other._val), RelaxedSumOperator((self._jac, other._jac)), met) if isinstance(other, (int, float, complex, Field, MultiField)): return Linearization(self._val+other, self._jac, self._metric) def __radd__(self, other): return self.__add__(other) def __sub__(self, other): return self.__add__(-other) def __rsub__(self, other): return 
(-self).__add__(other) def __mul__(self, other): from .sugar import makeOp if isinstance(other, Linearization): d1 = makeOp(self._val) d2 = makeOp(other._val) return Linearization(self._val*other._val, d2*self._jac + d1*other._jac) if isinstance(other, (int, float, complex)): # if other == 0: # return ... return Linearization(self._val*other, self._jac*other) if isinstance(other, (Field, MultiField)): d2 = makeOp(other) `````` Martin Reinecke committed Jul 26, 2018 82 `````` return Linearization(self._val*other, d2*self._jac) `````` Martin Reinecke committed Jul 26, 2018 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 `````` raise TypeError def __rmul__(self, other): from .sugar import makeOp if isinstance(other, (int, float, complex)): return Linearization(self._val*other, self._jac*other) if isinstance(other, (Field, MultiField)): d1 = makeOp(other) return Linearization(self._val*other, d1*self._jac) def sum(self): from .sugar import full from .operators.vdot_operator import VdotOperator return Linearization(full((),self._val.sum()), VdotOperator(full(self._jac.target,1))*self._jac) def exp(self): tmp = self._val.exp() return Linearization(tmp, makeOp(tmp)*self._jac) def log(self): tmp = self._val.log() return Linearization(tmp, makeOp(1./self._val)*self._jac) def add_metric(self, metric): return Linearization(self._val, self._jac, metric) @staticmethod def make_var(field): from .operators.scaling_operator import ScalingOperator return Linearization(field, ScalingOperator(1., field.domain)) @staticmethod def make_const(field): from .operators.null_operator import NullOperator return Linearization(field, NullOperator({}, field.domain)) ``````