# linearization.py (4.29 KB)
# Taken from a repository blame view; all visible lines were committed by
# Martin Reinecke.
from __future__ import absolute_import, division, print_function

import numpy as np

from .compat import *
from .field import Field
from .multi.multi_field import MultiField
from .sugar import makeOp


class Linearization(object):
    """A value bundled with its Jacobian and an optional metric.

    `val` is the value of some function at a point. `jac` is a linear
    operator whose `domain` is the input space and whose `target` is the
    output space (presumably `jac.target == val.domain`; not checked here).
    The arithmetic and pointwise methods below propagate `jac` via the chain
    rule. `metric` is optional. Negation, addition and multiplication by a
    plain scalar carry it along. All other operations drop it (set it to
    None).

    Parameters
    ----------
    val : Field or MultiField
        The value at the linearization point.
    jac : linear operator
        The Jacobian. It must provide `domain`, `target`, `adjoint_times`
        and support composition/scaling via `*`.
    metric : linear operator or None
        Optional metric. Per the accessor docstring it is only meaningful
        when the target is a scalar.
    """

    def __init__(self, val, jac, metric=None):
        self._val = val
        self._jac = jac
        self._metric = metric

    @property
    def domain(self):
        """Domain (input space) of the Jacobian."""
        return self._jac.domain

    @property
    def target(self):
        """Target (output space) of the Jacobian."""
        return self._jac.target

    @property
    def val(self):
        """The value at the linearization point."""
        return self._val

    @property
    def jac(self):
        """The Jacobian operator."""
        return self._jac

    @property
    def gradient(self):
        """Adjoint Jacobian applied to a field of ones on the target.

        Only available if target is a scalar.
        """
        from .sugar import full
        return self._jac.adjoint_times(full(self._jac.target, 1.))

    @property
    def metric(self):
        """The stored metric, or None.

        Only available if target is a scalar.
        """
        return self._metric

    def __getitem__(self, name):
        """Extract component `name` of a MultiField-valued linearization.

        The new Jacobian is the old one followed by a FieldAdapter built on
        the target. The FieldAdapter is assumed to project the target
        MultiDomain onto its `name` entry. Composing with `self._jac` keeps
        the chain rule valid when `jac` is not the identity. The metric is
        dropped.
        """
        from .operators.field_adapter import FieldAdapter
        return Linearization(self._val[name],
                             FieldAdapter(self.target, name)*self._jac)

    def __neg__(self):
        """Negate value, Jacobian and (if present) metric."""
        return Linearization(
            -self._val, self._jac*(-1),
            None if self._metric is None else self._metric*(-1))

    def __add__(self, other):
        """Sum with another Linearization or with a constant.

        For two Linearizations the values are united and the Jacobians are
        combined with a RelaxedSumOperator. The metrics are combined the
        same way only if both operands have one; otherwise the metric is
        None. Adding a scalar, Field or MultiField shifts the value and
        keeps Jacobian and metric.

        Raises
        ------
        TypeError
            If `other` has an unsupported type.
        """
        if isinstance(other, Linearization):
            from .operators.relaxed_sum_operator import RelaxedSumOperator
            met = None
            if self._metric is not None and other._metric is not None:
                met = RelaxedSumOperator((self._metric, other._metric))
            return Linearization(
                self._val.unite(other._val),
                RelaxedSumOperator((self._jac, other._jac)), met)
        if (isinstance(other, (int, float, complex)) or
                isinstance(other, (Field, MultiField))):
            return Linearization(self._val+other, self._jac, self._metric)
        raise TypeError(
            "unsupported operand type for +: 'Linearization' and '{}'"
            .format(type(other).__name__))

    def __radd__(self, other):
        """Reflected addition; addition is commutative."""
        return self.__add__(other)

    def __sub__(self, other):
        """self - other, computed as self + (-other)."""
        return self.__add__(-other)

    def __rsub__(self, other):
        """other - self, computed as (-self) + other."""
        return (-self).__add__(other)

    def __mul__(self, other):
        """Product with another Linearization, a scalar, or a field.

        Only the plain-scalar case keeps the metric, which is scaled by
        `other`. The other cases drop it.

        Raises
        ------
        TypeError
            If `other` has an unsupported type.
        """
        if isinstance(other, Linearization):
            from .operators.relaxed_sum_operator import RelaxedSumOperator
            # Product rule: d(a*b) = b*da + a*db (pointwise).
            d1 = makeOp(self._val)
            d2 = makeOp(other._val)
            return Linearization(
                self._val*other._val,
                RelaxedSumOperator((d2*self._jac, d1*other._jac)))
        if isinstance(other, (int, float, complex)):
            met = None if self._metric is None else self._metric*other
            return Linearization(self._val*other, self._jac*other, met)
        if isinstance(other, (Field, MultiField)):
            # Pointwise product with a constant field.
            return Linearization(self._val*other, makeOp(other)*self._jac)
        raise TypeError(
            "unsupported operand type for *: 'Linearization' and '{}'"
            .format(type(other).__name__))

    def __rmul__(self, other):
        """Reflected product.

        Scalar and field factors commute with the value, so this defers to
        __mul__. That way `c*lin` and `lin*c` agree, including on whether
        the metric is kept.
        """
        return self.__mul__(other)

    def sum(self):
        """Sum over all entries of the value, as a scalar Linearization.

        The Jacobian is a VdotOperator against a field of ones, composed
        with `jac`. The metric is dropped.
        """
        from .sugar import full
        from .operators.vdot_operator import VdotOperator
        return Linearization(full((), self._val.sum()),
                             VdotOperator(full(self._jac.target, 1))*self._jac)

    def exp(self):
        """Pointwise exponential; the derivative is exp(val) itself."""
        tmp = self._val.exp()
        return Linearization(tmp, makeOp(tmp)*self._jac)

    def log(self):
        """Pointwise logarithm; the derivative is 1/val."""
        tmp = self._val.log()
        return Linearization(tmp, makeOp(1./self._val)*self._jac)

    def tanh(self):
        """Pointwise tanh; the derivative is 1 - tanh(val)**2."""
        tmp = self._val.tanh()
        return Linearization(tmp, makeOp(1.-tmp**2)*self._jac)

    def positive_tanh(self):
        """Pointwise 0.5*(1 + tanh(val)).

        The derivative is 0.5*(1 - tanh(val)**2).
        """
        tmp = self._val.tanh()
        tmp2 = 0.5*(1.+tmp)
        return Linearization(tmp2, makeOp(0.5*(1.-tmp**2))*self._jac)

    def add_metric(self, metric):
        """Return a copy with the same value and Jacobian and the given
        metric."""
        return Linearization(self._val, self._jac, metric)

    @staticmethod
    def make_var(field):
        """Linearization of the identity at `field`.

        The Jacobian is ScalingOperator(1., field.domain).
        """
        from .operators.scaling_operator import ScalingOperator
        return Linearization(field, ScalingOperator(1., field.domain))

    @staticmethod
    def make_const(field):
        """Linearization of a constant `field`.

        The Jacobian is NullOperator({}, field.domain), presumably a zero
        operator from an empty domain onto field.domain.
        """
        from .operators.null_operator import NullOperator
        return Linearization(field, NullOperator({}, field.domain))