# linearization.py (last committed by Martin Reinecke)
from __future__ import absolute_import, division, print_function

import numpy as np

from .compat import *
from .field import Field
from .multi.multi_field import MultiField
from .sugar import makeOp


class Linearization(object):
    """A value together with its Jacobian (and optionally a metric).

    Arithmetic and the elementwise functions below propagate the value and
    the Jacobian using the product and chain rules, so that expressions
    built from Linearizations carry their own derivative.

    Parameters
    ----------
    val : Field or MultiField (presumably)
        The value of the linearized quantity.
    jac : LinearOperator
        The Jacobian; ``domain``/``target`` of the Linearization are its
        domain and target.
    metric : LinearOperator or None
        Optional metric; only meaningful if the target is a scalar.
    """

    # Scalar types accepted by the arithmetic operators.  ``np.number`` also
    # admits numpy scalars (e.g. ``np.float32``, ``np.int64``) that are not
    # subclasses of the builtin Python number types.
    _scalar_types = (int, float, complex, np.number)

    def __init__(self, val, jac, metric=None):
        self._val = val
        self._jac = jac
        self._metric = metric

    @property
    def domain(self):
        """The domain of the Jacobian."""
        return self._jac.domain

    @property
    def target(self):
        """The target of the Jacobian."""
        return self._jac.target

    @property
    def val(self):
        """The value of the linearized quantity."""
        return self._val

    @property
    def jac(self):
        """The Jacobian (a linear operator)."""
        return self._jac

    @property
    def gradient(self):
        """The adjoint Jacobian applied to a field of ones on the target.

        Only available if target is a scalar.
        """
        from .sugar import full
        return self._jac.adjoint_times(full(self._jac.target, 1.))

    @property
    def metric(self):
        """The metric attached to this Linearization (may be None).

        Only available if target is a scalar.
        """
        return self._metric

    def __getitem__(self, name):
        """Return the Linearization of component ``name`` of ``val``.

        The metric is not carried over.

        NOTE(review): the returned Jacobian is a bare ``FieldAdapter`` and is
        not composed with ``self._jac``; this presumably assumes ``self`` is a
        variable (identity Jacobian, e.g. from ``make_var``) -- confirm.
        """
        from .operators.field_adapter import FieldAdapter
        component = self._val[name]
        return Linearization(component, FieldAdapter(component.domain, name))

    def __neg__(self):
        """Negate value, Jacobian and (if present) metric."""
        met = None if self._metric is None else self._metric*(-1)
        return Linearization(-self._val, self._jac*(-1), met)

    def __add__(self, other):
        """Add a Linearization, scalar, Field or MultiField.

        For two Linearizations the values are united and the Jacobians are
        summed; the metrics are summed only if both are present, otherwise
        the result has no metric.  Adding a constant shifts the value and
        keeps Jacobian and metric unchanged.

        Raises
        ------
        TypeError
            If ``other`` is of an unsupported type.
        """
        if isinstance(other, Linearization):
            from .operators.relaxed_sum_operator import RelaxedSumOperator
            met = None
            if self._metric is not None and other._metric is not None:
                met = RelaxedSumOperator((self._metric, other._metric))
            return Linearization(
                self._val.unite(other._val),
                RelaxedSumOperator((self._jac, other._jac)), met)
        # Scalars are tested first so plain numbers never need the Field
        # types to be resolved.
        if (isinstance(other, self._scalar_types) or
                isinstance(other, (Field, MultiField))):
            return Linearization(self._val+other, self._jac, self._metric)
        raise TypeError(
            "unsupported operand type for Linearization +: {}".format(
                type(other).__name__))

    def __radd__(self, other):
        """Right addition; addition is commutative here."""
        return self.__add__(other)

    def __sub__(self, other):
        """Subtract by adding the negation of ``other``."""
        return self.__add__(-other)

    def __rsub__(self, other):
        """Compute ``other - self`` as ``(-self) + other``."""
        return (-self).__add__(other)

    def __mul__(self, other):
        """Multiply by a Linearization, scalar, Field or MultiField.

        * Linearization: product rule, J = diag(b)*J_a + diag(a)*J_b; the
          result carries no metric.
        * scalar: value, Jacobian and metric are all scaled.
        * Field/MultiField: value is multiplied and the Jacobian is
          left-multiplied by the diagonal operator of ``other``; the result
          carries no metric.

        Raises
        ------
        TypeError
            If ``other`` is of an unsupported type.
        """
        if isinstance(other, Linearization):
            from .operators.relaxed_sum_operator import RelaxedSumOperator
            d1 = makeOp(self._val)
            d2 = makeOp(other._val)
            return Linearization(
                self._val*other._val,
                RelaxedSumOperator((d2*self._jac, d1*other._jac)))
        if isinstance(other, self._scalar_types):
            met = None if self._metric is None else self._metric*other
            return Linearization(self._val*other, self._jac*other, met)
        if isinstance(other, (Field, MultiField)):
            return Linearization(self._val*other, makeOp(other)*self._jac)
        raise TypeError(
            "unsupported operand type for Linearization *: {}".format(
                type(other).__name__))

    def __rmul__(self, other):
        """Right multiplication ``other*self``.

        Delegates to ``__mul__`` so that ``c*lin`` and ``lin*c`` agree
        (in particular, a scalar factor also scales the metric).
        """
        return self.__mul__(other)

    def sum(self):
        """Sum of all entries of ``val`` as a scalar Linearization.

        The Jacobian is a ``VdotOperator`` with a ones field on the old
        target, composed with the old Jacobian.  The metric is not
        propagated.
        """
        from .sugar import full
        from .operators.vdot_operator import VdotOperator
        return Linearization(
            full((), self._val.sum()),
            VdotOperator(full(self._jac.target, 1))*self._jac)

    def exp(self):
        """Elementwise exp; derivative diag(exp(val))."""
        tmp = self._val.exp()
        return Linearization(tmp, makeOp(tmp)*self._jac)

    def log(self):
        """Elementwise log; derivative diag(1/val)."""
        tmp = self._val.log()
        return Linearization(tmp, makeOp(1./self._val)*self._jac)

    def tanh(self):
        """Elementwise tanh; derivative diag(1 - tanh(val)**2)."""
        tmp = self._val.tanh()
        return Linearization(tmp, makeOp(1.-tmp**2)*self._jac)

    def positive_tanh(self):
        """Elementwise 0.5*(1 + tanh(val)); derivative 0.5*(1 - tanh**2)."""
        tmp = self._val.tanh()
        tmp2 = 0.5*(1.+tmp)
        return Linearization(tmp2, makeOp(0.5*(1.-tmp**2))*self._jac)

    def add_metric(self, metric):
        """Return a new Linearization with the same val/jac and ``metric``."""
        return Linearization(self._val, self._jac, metric)

    @staticmethod
    def make_var(field):
        """Linearization of an independent variable.

        The Jacobian is ``ScalingOperator(1., field.domain)``, i.e. the
        identity on the field's domain.
        """
        from .operators.scaling_operator import ScalingOperator
        return Linearization(field, ScalingOperator(1., field.domain))

    @staticmethod
    def make_const(field):
        """Linearization of a constant (zero derivative).

        The Jacobian is ``NullOperator({}, field.domain)``; presumably the
        ``{}`` denotes an empty input domain -- confirm against NullOperator.
        """
        from .operators.null_operator import NullOperator
        return Linearization(field, NullOperator({}, field.domain))