linearization.py 4.99 KB
Newer Older
Martin Reinecke's avatar
Martin Reinecke committed
1
2
3
4
5
6
from __future__ import absolute_import, division, print_function

import numpy as np

from .compat import *
from .field import Field
Martin Reinecke's avatar
Martin Reinecke committed
7
from .multi_field import MultiField
Martin Reinecke's avatar
Martin Reinecke committed
8
9
10
11
12
13
14
from .sugar import makeOp


class Linearization(object):
    """Value of an operator at some point, together with its Jacobian.

    Arithmetic and the elementwise functions below transform the value and,
    via the chain rule, the Jacobian at the same time, so derivatives are
    propagated automatically.

    Parameters
    ----------
    val : Field or MultiField
        The value.
    jac : linear operator
        The Jacobian. Its target must equal ``val.domain``; its domain is
        exposed as :attr:`domain`.
    metric : operator or None
        Optional metric carried along with the linearization (presumably
        only meaningful for scalar-valued linearizations — see
        :attr:`metric`).

    Raises
    ------
    ValueError
        If ``val.domain`` differs from ``jac.target``.
    """

    def __init__(self, val, jac, metric=None):
        self._val = val
        self._jac = jac
        # The Jacobian must map into the space the value lives in.
        if self._val.domain != self._jac.target:
            raise ValueError("domain mismatch")
        self._metric = metric

    @property
    def domain(self):
        """Domain of the Jacobian (the space the derivative is taken in)."""
        return self._jac.domain

    @property
    def target(self):
        """Target of the Jacobian (equal to the domain of the value)."""
        return self._jac.target

    @property
    def val(self):
        """The value."""
        return self._val

    @property
    def jac(self):
        """The Jacobian."""
        return self._jac

    @property
    def gradient(self):
        """Adjoint Jacobian applied to the scalar 1.

        Only available if target is a scalar.
        """
        return self._jac.adjoint_times(Field.scalar(1.))

    @property
    def metric(self):
        """The attached metric, or None if none was attached.

        Only available if target is a scalar.
        """
        return self._metric

    def __getitem__(self, name):
        """Return the linearization of component ``name`` of the value.

        The new Jacobian is the component-extraction adapter composed with
        the existing Jacobian, so the chain rule holds for arbitrary
        linearizations, not just for plain variables. The metric is not
        propagated.
        """
        from .operators.simple_linear_operators import FieldAdapter
        # NOTE(review): assumes FieldAdapter(dom, name) maps dom -> dom[name],
        # as the previous FieldAdapter(self.domain, name) usage implied.
        # Composing with self._jac keeps the derivative correct when the
        # Jacobian is not the identity (previously it was discarded).
        return Linearization(self._val[name],
                             FieldAdapter(self.target, name)(self._jac))

    def __neg__(self):
        """Negate value, Jacobian and (if present) metric."""
        return Linearization(
            -self._val, -self._jac,
            None if self._metric is None else -self._metric)

    def conjugate(self):
        """Complex-conjugate value, Jacobian and (if present) metric."""
        return Linearization(
            self._val.conjugate(), self._jac.conjugate(),
            None if self._metric is None else self._metric.conjugate())

    @property
    def real(self):
        """Real part of value and Jacobian; the metric is not propagated."""
        return Linearization(self._val.real, self._jac.real)

    def _myadd(self, other, neg):
        """Return ``self + other``, or ``self - other`` if ``neg`` is true.

        If ``other`` is a Linearization, values and Jacobians are combined;
        the result carries a metric only if both operands have one. If
        ``other`` is a constant (Python or NumPy number, Field, MultiField),
        only the value is shifted and Jacobian and metric are kept.
        For any other operand type NotImplemented is returned, so Python
        can try the reflected operation or raise TypeError.
        """
        if isinstance(other, Linearization):
            met = None
            if self._metric is not None and other._metric is not None:
                met = self._metric._myadd(other._metric, neg)
            return Linearization(
                self._val.flexible_addsub(other._val, neg),
                self._jac._myadd(other._jac, neg), met)
        if isinstance(other, (int, float, complex, np.number,
                              Field, MultiField)):
            # Adding a constant does not change the derivative.
            newval = self._val-other if neg else self._val+other
            return Linearization(newval, self._jac, self._metric)
        return NotImplemented

    def __add__(self, other):
        return self._myadd(other, False)

    def __radd__(self, other):
        return self._myadd(other, False)

    def __sub__(self, other):
        return self._myadd(other, True)

    def __rsub__(self, other):
        # other - self == (-self) + other
        return (-self).__add__(other)

    def __mul__(self, other):
        """Multiply by another Linearization, a scalar, or a (Multi)Field.

        Raises
        ------
        ValueError
            If the operand's target/domain does not match ``self.target``.

        Returns NotImplemented for unsupported operand types.
        """
        if isinstance(other, Linearization):
            if self.target != other.target:
                raise ValueError("domain mismatch")
            # Product rule: d(a*b) = b*da + a*db. The metric is dropped.
            return Linearization(
                self._val*other._val,
                (makeOp(other._val)(self._jac))._myadd(
                 makeOp(self._val)(other._jac), False))
        if np.isscalar(other):
            if other == 1:
                return self
            met = None if self._metric is None else self._metric.scale(other)
            return Linearization(self._val*other, self._jac.scale(other), met)
        if isinstance(other, (Field, MultiField)):
            if self.target != other.domain:
                raise ValueError("domain mismatch")
            return Linearization(self._val*other, makeOp(other)(self._jac))
        return NotImplemented

    def __rmul__(self, other):
        return self.__mul__(other)

    def vdot(self, other):
        """Scalar product of the value with ``other``.

        ``other`` is either a constant Field/MultiField or another
        Linearization (in which case both Jacobians contribute).
        """
        from .operators.simple_linear_operators import VdotOperator
        if isinstance(other, (Field, MultiField)):
            return Linearization(
                Field.scalar(self._val.vdot(other)),
                VdotOperator(other)(self._jac))
        return Linearization(
            Field.scalar(self._val.vdot(other._val)),
            VdotOperator(self._val)(other._jac) +
            VdotOperator(other._val)(self._jac))

    def sum(self):
        """Sum of all value entries, as a scalar-valued linearization."""
        from .operators.simple_linear_operators import SumReductionOperator
        return Linearization(
            Field.scalar(self._val.sum()),
            SumReductionOperator(self._jac.target)(self._jac))

    def exp(self):
        """Elementwise exp; d exp(x) = exp(x) dx. Metric is dropped."""
        tmp = self._val.exp()
        return Linearization(tmp, makeOp(tmp)(self._jac))

    def log(self):
        """Elementwise log; d log(x) = dx / x. Metric is dropped."""
        tmp = self._val.log()
        return Linearization(tmp, makeOp(1./self._val)(self._jac))

    def tanh(self):
        """Elementwise tanh; d tanh(x) = (1 - tanh(x)**2) dx."""
        tmp = self._val.tanh()
        return Linearization(tmp, makeOp(1.-tmp**2)(self._jac))

    def positive_tanh(self):
        """Elementwise 0.5*(1 + tanh(x)), with derivative 0.5*(1 - tanh**2)."""
        tmp = self._val.tanh()
        tmp2 = 0.5*(1.+tmp)
        return Linearization(tmp2, makeOp(0.5*(1.-tmp**2))(self._jac))

    def add_metric(self, metric):
        """Return a copy with the same value and Jacobian and ``metric``."""
        return Linearization(self._val, self._jac, metric)

    @staticmethod
    def make_var(field):
        """Linearization of the identity at ``field`` (unit Jacobian)."""
        from .operators.scaling_operator import ScalingOperator
        return Linearization(field, ScalingOperator(1., field.domain))

    @staticmethod
    def make_const(field):
        """Linearization of a constant ``field`` (null Jacobian)."""
        from .operators.simple_linear_operators import NullOperator
        return Linearization(field, NullOperator(field.domain, field.domain))