linearization.py 4.89 KB
Newer Older
Martin Reinecke's avatar
Martin Reinecke committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
from __future__ import absolute_import, division, print_function

import numpy as np

from .compat import *
from .field import Field
from .multi.multi_field import MultiField
from .sugar import makeOp


class Linearization(object):
    """A value bundled with its Jacobian and an optional metric.

    Arithmetic and elementwise functions on Linearizations propagate the
    Jacobian via the chain rule, so a chain of operations yields both the
    result value and the linear operator describing its derivative.

    Parameters
    ----------
    val : Field or MultiField
        The value; presumably lives on ``jac.target`` (TODO confirm at
        call sites).
    jac : LinearOperator
        The Jacobian; ``domain`` and ``target`` are taken from it.
    metric : LinearOperator or None
        Optional metric, only meaningful if the target is a scalar.
    """

    def __init__(self, val, jac, metric=None):
        self._val = val
        self._jac = jac
        self._metric = metric

    @property
    def domain(self):
        """The domain of the Jacobian."""
        return self._jac.domain

    @property
    def target(self):
        """The target of the Jacobian."""
        return self._jac.target

    @property
    def val(self):
        """The value of the linearization."""
        return self._val

    @property
    def jac(self):
        """The Jacobian (a linear operator)."""
        return self._jac

    @property
    def gradient(self):
        """Only available if target is a scalar"""
        # The adjoint Jacobian applied to a unit field on the (scalar) target
        # gives the gradient on the domain.
        return self._jac.adjoint_times(Field(self._jac.target, 1.))

    @property
    def metric(self):
        """Only available if target is a scalar"""
        return self._metric

    def __getitem__(self, name):
        """Return the Linearization of component ``name`` of the value.

        The metric is not propagated.
        """
        from .operators.field_adapter import FieldAdapter
        # Chain rule: first apply the Jacobian, then extract component `name`
        # from the target. Building the adapter on `self.domain` and ignoring
        # the Jacobian is only correct when the Jacobian is the identity.
        return Linearization(
            self._val[name],
            FieldAdapter(self.target, name).chain(self._jac))

    def __neg__(self):
        """Return ``-self``; Jacobian and metric are negated as well."""
        return Linearization(
            -self._val, self._jac.chain(-1),
            None if self._metric is None else self._metric.chain(-1))

    def __add__(self, other):
        """Return ``self + other``.

        For two Linearizations, the Jacobians are combined with a
        RelaxedSumOperator. The metrics are combined the same way only if
        both operands carry one; otherwise the result has no metric. Adding
        a constant (number, Field or MultiField) only shifts the value.

        Raises
        ------
        TypeError
            If ``other`` is of an unsupported type.
        """
        if isinstance(other, Linearization):
            from .operators.relaxed_sum_operator import RelaxedSumOperator
            met = None
            if self._metric is not None and other._metric is not None:
                met = RelaxedSumOperator((self._metric, other._metric))
            return Linearization(
                self._val.unite(other._val),
                RelaxedSumOperator((self._jac, other._jac)), met)
        if isinstance(other, (int, float, complex, Field, MultiField)):
            # A constant summand does not change the derivative.
            return Linearization(self._val+other, self._jac, self._metric)
        # Previously fell off the end and silently returned None.
        raise TypeError(
            "unsupported operand type for +: 'Linearization' and '{}'"
            .format(type(other).__name__))

    def __radd__(self, other):
        """Return ``other + self`` (addition commutes)."""
        return self.__add__(other)

    def __sub__(self, other):
        """Return ``self - other``."""
        return self.__add__(-other)

    def __rsub__(self, other):
        """Return ``other - self``."""
        return (-self).__add__(other)

    def __mul__(self, other):
        """Return the elementwise product ``self * other``.

        For two Linearizations the product rule is applied and no metric is
        kept. A scalar factor scales the Jacobian and, if present, the
        metric. A Field/MultiField factor scales the Jacobian pointwise and
        drops the metric.

        Raises
        ------
        TypeError
            If ``other`` is of an unsupported type.
        """
        if isinstance(other, Linearization):
            # Product rule: d(a*b) = diag(b) da + diag(a) db
            d1 = makeOp(self._val)
            d2 = makeOp(other._val)
            return Linearization(
                self._val*other._val,
                d2.chain(self._jac) + d1.chain(other._jac))
        if isinstance(other, (int, float, complex)):
            met = None if self._metric is None else self._metric.chain(other)
            return Linearization(self._val*other, self._jac.chain(other), met)
        if isinstance(other, (Field, MultiField)):
            d2 = makeOp(other)
            return Linearization(self._val*other, d2.chain(self._jac))
        raise TypeError(
            "unsupported operand type for *: 'Linearization' and '{}'"
            .format(type(other).__name__))

    def __rmul__(self, other):
        """Return ``other * self``.

        Elementwise multiplication commutes, so this delegates to
        ``__mul__``. Consequently ``c*lin`` scales the metric exactly like
        ``lin*c``; previously the metric was dropped here. Unsupported types
        now raise TypeError instead of returning None.
        """
        return self.__mul__(other)

    def vdot(self, other):
        """Return the scalar product of the value with ``other``.

        ``other`` may be a constant Field/MultiField, or a Linearization. In
        the latter case the product rule combines both Jacobians.
        """
        from .domain_tuple import DomainTuple
        from .operators.vdot_operator import VdotOperator
        if isinstance(other, (Field, MultiField)):
            return Linearization(
                Field(DomainTuple.scalar_domain(), self._val.vdot(other)),
                VdotOperator(other).chain(self._jac))
        return Linearization(
            Field(DomainTuple.scalar_domain(), self._val.vdot(other._val)),
            VdotOperator(self._val).chain(other._jac) +
            VdotOperator(other._val).chain(self._jac))

    def sum(self):
        """Return the Linearization of the sum over all value entries."""
        from .domain_tuple import DomainTuple
        from .operators.vdot_operator import SumReductionOperator
        return Linearization(
            Field(DomainTuple.scalar_domain(), self._val.sum()),
            SumReductionOperator(self._jac.target).chain(self._jac))

    def exp(self):
        """Elementwise exponential; d/dx exp(x) = exp(x)."""
        tmp = self._val.exp()
        return Linearization(tmp, makeOp(tmp).chain(self._jac))

    def log(self):
        """Elementwise natural logarithm; d/dx log(x) = 1/x."""
        tmp = self._val.log()
        return Linearization(tmp, makeOp(1./self._val).chain(self._jac))

    def tanh(self):
        """Elementwise tanh; d/dx tanh(x) = 1 - tanh(x)**2."""
        tmp = self._val.tanh()
        return Linearization(tmp, makeOp(1.-tmp**2).chain(self._jac))

    def positive_tanh(self):
        """Elementwise 0.5*(1 + tanh(x)); derivative 0.5*(1 - tanh(x)**2)."""
        tmp = self._val.tanh()
        tmp2 = 0.5*(1.+tmp)
        return Linearization(tmp2, makeOp(0.5*(1.-tmp**2)).chain(self._jac))

    def add_metric(self, metric):
        """Return a new Linearization with the same value and Jacobian and
        the given metric."""
        return Linearization(self._val, self._jac, metric)

    @staticmethod
    def make_var(field):
        """Return a Linearization of ``field`` with an identity Jacobian
        (a ScalingOperator with factor 1 on ``field.domain``)."""
        from .operators.scaling_operator import ScalingOperator
        return Linearization(field, ScalingOperator(1., field.domain))

    @staticmethod
    def make_const(field):
        """Return a Linearization of ``field`` whose Jacobian is a
        NullOperator from an empty domain to ``field.domain``."""
        from .operators.null_operator import NullOperator
        return Linearization(field, NullOperator({}, field.domain))