polynomial_fit.py 4.59 KB
Newer Older
Martin Reinecke's avatar
Martin Reinecke committed
1
2
3
4
5
6
7
8
9
10
11
12
13
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
#
Philipp Arras's avatar
Philipp Arras committed
14
15
# Copyright(C) 2013-2020 Max-Planck-Society
# Author: Philipp Arras
Martin Reinecke's avatar
Martin Reinecke committed
16
17
18
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.

19
20
21
import matplotlib.pyplot as plt
import numpy as np

Martin Reinecke's avatar
Martin Reinecke committed
22
import nifty7 as ift
23
24
25
26
27
28
29
30
31


def polynomial(coefficients, sampling_points):
    """Computes values of polynomial whose coefficients are stored in
    coefficients at sampling points. This is a quick version of the
    PolynomialResponse.

    The evaluated polynomial is ``sum_k c[k] * sampling_points**k``, where
    ``c`` is the (flattened) array of coefficient values, i.e. the
    coefficient with index ``k`` multiplies the ``k``-th power. Evaluation
    uses Horner's scheme, which needs one multiply-add per coefficient and
    no explicit powers.

    Parameters
    ----------
    coefficients: Field
        Polynomial coefficients, ordered from the constant term upwards.
    sampling_points: Numpy array
        Points at which the polynomial is evaluated. Any shape is accepted.

    Returns
    -------
    Numpy array
        Polynomial values with the same shape as ``sampling_points``. The
        dtype is the common type of ``sampling_points`` and the coefficient
        values, e.g. integer points with float coefficients give floats.

    Raises
    ------
    TypeError
        If ``coefficients`` is not an ``ift.Field`` or ``sampling_points``
        is not a numpy array.
    """
    if not isinstance(coefficients, ift.Field):
        raise TypeError('coefficients must be an ift.Field, got {}'.format(
            type(coefficients).__name__))
    if not isinstance(sampling_points, np.ndarray):
        raise TypeError('sampling_points must be a numpy array, got {}'.format(
            type(sampling_points).__name__))
    params = np.ravel(coefficients.val)

    # Allocate with the promoted dtype: np.zeros_like(sampling_points) would
    # inherit an integer dtype from integer sampling points, and the in-place
    # accumulation of float (or complex) terms would then fail to cast.
    out = np.zeros(sampling_points.shape,
                   dtype=np.result_type(sampling_points, params))
    # Horner's scheme, starting at the highest power:
    # (...((c_n * x + c_{n-1}) * x + ...) * x + c_0.
    # With no coefficients the result stays all zeros.
    for coeff in params[::-1]:
        out *= sampling_points
        out += coeff
    return out


class PolynomialResponse(ift.LinearOperator):
    """Calculates values of a polynomial parameterized by input at sampling
    points.

    The operator is the linear map ``c -> V @ c`` with the Vandermonde
    matrix ``V[i, k] = x_i**k``. Here ``x`` is the flattened array of
    sampling points and ``c`` the flattened coefficient values. The adjoint
    applies the conjugate transpose of ``V``.

    Parameters
    ----------
    domain: UnstructuredDomain
        The domain on which the coefficients of the polynomial are defined.
        The coefficient with (flat) index ``k`` multiplies ``x**k``.
    sampling_points: Numpy array
        x-values of the sampling points. Any shape is accepted; the target
        domain is an UnstructuredDomain of the same shape.

    Raises
    ------
    TypeError
        If ``domain`` is not an UnstructuredDomain or ``sampling_points`` is
        not a numpy array.
    """

    def __init__(self, domain, sampling_points):
        if not isinstance(domain, ift.UnstructuredDomain):
            raise TypeError('domain must be an ift.UnstructuredDomain, '
                            'got {}'.format(type(domain).__name__))
        if not isinstance(sampling_points, np.ndarray):
            raise TypeError('sampling_points must be a numpy array, '
                            'got {}'.format(type(sampling_points).__name__))
        self._domain = ift.DomainTuple.make(domain)
        tgt = ift.UnstructuredDomain(sampling_points.shape)
        self._target = ift.DomainTuple.make(tgt)
        self._capability = self.TIMES | self.ADJOINT_TIMES

        # Flatten so that sampling points of any shape give a 2D matrix of
        # shape (number of sampling points, number of coefficients).
        pts = sampling_points.ravel()
        # Compute the powers in at least double precision. Integer points
        # would otherwise overflow silently in x**k for larger k. Complex
        # points keep their imaginary part.
        pts = pts.astype(np.result_type(pts.dtype, np.float64), copy=False)
        # Column k holds pts**k (increasing powers, starting at x**0 == 1).
        self._mat = np.vander(pts, N=domain.size, increasing=True)

    def apply(self, x, mode):
        self._check_input(x, mode)
        # Read-only access suffices here; val_rw() would copy the data.
        val = x.val.ravel()
        if mode == self.TIMES:
            out = self._mat.dot(val)
        else:
            # Conjugate transpose; for real sampling points the
            # conjugation is a no-op.
            out = self._mat.conj().T.dot(val)
        tgt = self._tgt(mode)
        # The matrix works on flat vectors; restore the domain's shape.
        return ift.makeField(tgt, out.reshape(tgt.shape))
82
83


Philipp Arras's avatar
Philipp Arras committed
84
def main():
    """Fit a polynomial to mock data and visualise the posterior.

    Draws sampling locations from NIFTy's current random generator and builds
    data ``y = sin(x**2) * x**3`` with a diagonal noise covariance. It then
    minimizes the standard Hamiltonian with NewtonCG and draws posterior
    samples from the inverse metric at the minimum. The data, samples and
    fit are plotted to ``fit.png``, and the sample mean and standard
    deviation of every coefficient are printed.
    """
    n_params, n_samples = 10, 100

    # Mock data and per-point noise variances
    x = ift.random.current_rng().random((12,)) * 10
    y = np.sin(x**2) * x**3
    var = np.full_like(y, y.var() / 10)
    var[-2] *= 4
    var[5] /= 2

    # Response from polynomial coefficients to data space
    coeff_space = ift.UnstructuredDomain(n_params)
    start = ift.full(coeff_space, 0.)
    response = PolynomialResponse(coeff_space, x)
    ift.extra.check_linear_operator(response)

    data_space = response.target
    data = ift.makeField(data_space, y)
    noise_cov = ift.DiagonalOperator(ift.makeField(data_space, var))

    # The same controller drives the minimizer and the sampling solver
    controller = ift.DeltaEnergyController(tol_rel_deltaE=1e-12, iteration_limit=200)
    likelihood = ift.GaussianEnergy(data, noise_cov) @ response
    hamiltonian = ift.StandardHamiltonian(likelihood, controller)
    energy = ift.EnergyAdapter(start, hamiltonian, want_metric=True)

    # Minimize
    energy, _ = ift.NewtonCG(controller)(energy)
    best = energy.position

    # Posterior samples around the minimum
    lin = ift.Linearization.make_var(best, want_metric=True)
    metric = hamiltonian(lin).metric
    samples = [best + metric.draw_sample(from_inverse=True)
               for _ in range(n_samples)]

    # Plot data, posterior samples and the fit
    plt.errorbar(x, y, np.sqrt(var), fmt='ko', label='Data with error bars')
    grid = np.linspace(x.min(), x.max(), 100)

    stats = ift.StatCalculator()
    for idx, sample in enumerate(samples):
        stats.add(sample)
        curve = polynomial(sample, grid)
        if idx:
            plt.plot(grid, curve, 'k', alpha=.05)
        else:
            # Label only the first sample so the legend has a single entry
            plt.plot(grid, curve, 'k', alpha=.05, label='Posterior samples')
    plt.plot(grid, polynomial(best, grid), 'r', linewidth=2., label='Interpolation')
    plt.legend()
    plt.savefig('fit.png')
    plt.close()

    # Report each coefficient as sample mean +/- sample standard deviation
    means = stats.mean.val
    stds = np.sqrt(stats.var.val)
    for power, (mu, sd) in enumerate(zip(means, stds)):
        print('Coefficient x**{}: {:.2E} +/- {:.2E}'.format(power, mu, sd))
Philipp Arras's avatar
Philipp Arras committed
144
145
146
147


if __name__ == "__main__":
    # Run the demo only when executed as a script, not when imported.
    main()