from __future__ import absolute_import, division, print_function

import numpy as np

from .compat import *
from .field import Field
from .multi_field import MultiField
from .sugar import makeOp
from .domain_tuple import DomainTuple


class Linearization(object):
    """A value bundled with its Jacobian and, optionally, a metric.

    Arithmetic and elementwise functions on Linearization objects propagate
    the Jacobian via the chain and product rules, so composite expressions
    carry their own derivative with respect to the input variables.

    Parameters
    ----------
    val : Field or MultiField
        The value at the expansion point. It is assumed to live on
        ``jac.target``; this is not checked here.
    jac : LinearOperator
        The Jacobian, mapping from ``domain`` to ``target``.
    metric : LinearOperator or None
        Optional metric; only available if the target is a scalar.
    """

    def __init__(self, val, jac, metric=None):
        self._val = val
        self._jac = jac
        self._metric = metric

    @property
    def domain(self):
        """Domain of the Jacobian (the space of the input variables)."""
        return self._jac.domain

    @property
    def target(self):
        """Target of the Jacobian."""
        return self._jac.target

    @property
    def val(self):
        """The value at the expansion point."""
        return self._val

    @property
    def jac(self):
        """The Jacobian operator."""
        return self._jac

    @property
    def gradient(self):
        """Adjoint Jacobian applied to the scalar 1.

        Only available if target is a scalar.
        """
        return self._jac.adjoint_times(Field.scalar(1.))

    @property
    def metric(self):
        """The attached metric, or None if there is none.

        Only available if target is a scalar.
        """
        return self._metric

    def __getitem__(self, name):
        """Return the linearization of component ``name`` of a MultiField value.

        The extraction operator acts on ``target`` (where ``val`` lives) and
        is composed with the existing Jacobian. The result is therefore
        correct even when ``jac`` is not the identity. The metric is not
        propagated.
        """
        from .operators.simple_linear_operators import FieldAdapter
        # NOTE(review): FieldAdapter(dom, name) is assumed to extract the
        # `name` entry from a MultiField on `dom`. Chaining with `jac` gives
        # d(val[name]) / d(input).
        return Linearization(self._val[name],
                             FieldAdapter(self.target, name)(self._jac))

    def __neg__(self):
        """Negate value, Jacobian and (if present) metric."""
        return Linearization(
            -self._val, -self._jac,
            None if self._metric is None else -self._metric)

    def conjugate(self):
        """Complex-conjugate value, Jacobian and (if present) metric."""
        return Linearization(
            self._val.conjugate(), self._jac.conjugate(),
            None if self._metric is None else self._metric.conjugate())

    @property
    def real(self):
        """Real part of value and Jacobian; the metric is dropped."""
        return Linearization(self._val.real, self._jac.real)

    def __add__(self, other):
        """Add another Linearization, a scalar, or a Field/MultiField.

        Returns NotImplemented for unsupported operand types, so Python raises
        a TypeError instead of silently producing None.
        """
        if isinstance(other, Linearization):
            from .operators.relaxed_sum_operator import RelaxedSumOperator
            # A metric is only attached when both summands carry one.
            met = None
            if self._metric is not None and other._metric is not None:
                met = RelaxedSumOperator((self._metric, other._metric))
            return Linearization(
                self._val.unite(other._val),
                RelaxedSumOperator((self._jac, other._jac)), met)
        # Constants shift the value but leave Jacobian and metric unchanged.
        # np.number also admits NumPy scalars such as np.float32 / np.int64,
        # which are not subclasses of the builtin numeric types.
        if isinstance(other, (int, float, complex, np.number,
                              Field, MultiField)):
            return Linearization(self._val+other, self._jac, self._metric)
        return NotImplemented

    def __radd__(self, other):
        """Addition is commutative; delegate to __add__."""
        return self.__add__(other)

    def __sub__(self, other):
        """Compute self - other as self + (-other)."""
        return self.__add__(-other)

    def __rsub__(self, other):
        """Compute other - self as (-self) + other."""
        return (-self).__add__(other)

    def __mul__(self, other):
        """Multiply by another Linearization, a scalar, or a Field/MultiField.

        Returns NotImplemented for unsupported operand types.
        """
        if isinstance(other, Linearization):
            # Product rule: d(a*b) = b*da + a*db. The metric is dropped.
            return Linearization(
                self._val*other._val,
                makeOp(other._val)(self._jac) + makeOp(self._val)(other._jac))
        if np.isscalar(other):
            if other == 1:
                return self
            met = None if self._metric is None else self._metric.scale(other)
            return Linearization(self._val*other, self._jac.scale(other), met)
        if isinstance(other, (Field, MultiField)):
            # Constant field factor: scale the Jacobian pointwise.
            return Linearization(self._val*other, makeOp(other)(self._jac))
        return NotImplemented

    def __rmul__(self, other):
        """Multiplication is commutative here; delegate to __mul__."""
        return self.__mul__(other)

    def vdot(self, other):
        """Scalar product with a Field/MultiField or another Linearization.

        The result's value is a scalar Field.
        """
        from .operators.simple_linear_operators import VdotOperator
        if isinstance(other, (Field, MultiField)):
            return Linearization(
                Field.scalar(self._val.vdot(other)),
                VdotOperator(other)(self._jac))
        # Both factors depend on the input: apply the product rule.
        return Linearization(
            Field.scalar(self._val.vdot(other._val)),
            VdotOperator(self._val)(other._jac) +
            VdotOperator(other._val)(self._jac))

    def sum(self):
        """Sum of all entries of the value, as a scalar-valued Linearization."""
        from .operators.simple_linear_operators import SumReductionOperator
        return Linearization(
            Field.scalar(self._val.sum()),
            SumReductionOperator(self._jac.target)(self._jac))

    def exp(self):
        """Elementwise exp; its derivative is exp(val) itself."""
        tmp = self._val.exp()
        return Linearization(tmp, makeOp(tmp)(self._jac))

    def log(self):
        """Elementwise log; its derivative is 1/val."""
        tmp = self._val.log()
        return Linearization(tmp, makeOp(1./self._val)(self._jac))

    def tanh(self):
        """Elementwise tanh; its derivative is 1 - tanh(val)**2."""
        tmp = self._val.tanh()
        return Linearization(tmp, makeOp(1.-tmp**2)(self._jac))

    def positive_tanh(self):
        """Elementwise 0.5*(1 + tanh(val)); derivative 0.5*(1 - tanh(val)**2)."""
        tmp = self._val.tanh()
        tmp2 = 0.5*(1.+tmp)
        return Linearization(tmp2, makeOp(0.5*(1.-tmp**2))(self._jac))

    def add_metric(self, metric):
        """Return a copy of this linearization with ``metric`` attached."""
        return Linearization(self._val, self._jac, metric)

    @staticmethod
    def make_var(field):
        """Linearization of an independent variable (identity Jacobian)."""
        from .operators.scaling_operator import ScalingOperator
        return Linearization(field, ScalingOperator(1., field.domain))

    @staticmethod
    def make_const(field):
        """Linearization of a constant: the Jacobian is a NullOperator.

        The NullOperator is built with an empty domain ``{}`` and
        ``field.domain`` (presumably the (domain, target) order — confirm
        against NullOperator).
        """
        from .operators.simple_linear_operators import NullOperator
        return Linearization(field, NullOperator({}, field.domain))