# -*- coding: utf-8 -*-

import numpy as np

from imagine.likelihoods.likelihood import Likelihood
from imagine.create_ring_profile import create_ring_profile

class EnsembleLikelihood(Likelihood):
    """Likelihood of measured data given an ensemble of simulated observables.

    The covariance used is ``B = A + U U^dagger``. ``A`` is the data
    covariance operator. ``U`` holds the ensemble members' deviations from
    the ensemble mean (one member per row). ``B^{-1}`` is applied via the
    Woodbury (generalized Sherman-Morrison) identity, so only ``A``'s
    inverse and a ``k x k`` linear solve (``k`` = number of ensemble
    members) are needed.
    """

    def __init__(self, observable_name, measured_data,
                 data_covariance_operator, profile=None):
        """Store the measurement, its covariance and the profile.

        Parameters
        ----------
        observable_name
            Key used to pick the simulated observable out of the mapping
            passed to ``__call__``.
        measured_data
            The measured data field.
        data_covariance_operator
            Operator for the data covariance ``A``. It must provide
            ``inverse_times``.
        profile : optional
            Values that data and simulations are divided by before they are
            compared. If None, it is built with ``create_ring_profile`` from
            the full values of ``measured_data``.
        """
        self.observable_name = observable_name
        self.measured_data = measured_data
        self.data_covariance_operator = data_covariance_operator
        if profile is None:
            profile = create_ring_profile(
                self.measured_data.val.get_full_data())
        self.profile = profile

    def __call__(self, observable):
        """Evaluate the likelihood for ``observable[self.observable_name]``."""
        field = observable[self.observable_name]
        return self._process_simple_field(field,
                                          self.measured_data,
                                          self.data_covariance_operator,
                                          self.profile)

    def _process_simple_field(self, field, measured_data,
                              data_covariance_operator, profile):
        """Return ``-c . (B^{-1} c)`` normalized by ``d . d``.

        Definitions (everything is divided by ``profile`` first):

        * ``d`` is the measured data.
        * ``c = d - mean(ensemble)``.
        * ``B = A + U U^dagger``.

        The Woodbury identity is used:
        ``B^{-1} c = A^{-1} c
        - A^{-1} U (I_k + U^dagger A^{-1} U)^{-1} U^dagger A^{-1} c``.
        See https://en.wikipedia.org/wiki/Sherman%E2%80%93Morrison_formula#Generalization

        No factor 1/2 and no log-determinant term are computed here.

        Parameters
        ----------
        field
            Ensemble field. Its first axis indexes the ensemble members
            (space 1 is the data space passed to ``inverse_times``).
        measured_data
            Measured data field.
        data_covariance_operator
            The data covariance ``A``.
        profile
            Values divided out of data, ensemble and ensemble mean.

        Raises
        ------
        numpy.linalg.LinAlgError
            If the ``k x k`` Woodbury matrix is singular.
        """
        A = data_covariance_operator
        k = field.shape[0]  # number of ensemble members

        # Divide out the profile out-of-place. The arrays returned by
        # get_full_data() may share memory with the field, and in-place
        # division would then silently modify the caller's observable. It
        # would also fail for integer-typed data.
        obs_val = field.val.get_full_data() / profile
        obs_mean = field.ensemble_mean().get_full_data() / profile
        measured_data = measured_data / profile

        # Deviations of every ensemble member from the ensemble mean.
        u_val = obs_val - obs_mean

        # A^{-1} applied to every deviation (space 1 = data space).
        U = field.copy_empty()
        U.val = u_val
        a_u = A.inverse_times(U, spaces=1)
        a_u_val = a_u.val.get_full_data()

        # k x k Woodbury matrix: I_k + U^dagger A^{-1} U.
        middle = np.eye(k) + np.einsum(u_val.conjugate(), [0, 1],
                                       a_u_val, [2, 1])

        c = measured_data - obs_mean

        # U^dagger A^{-1} c, computed as (A^{-1} U)^T c. This assumes
        # A == A^dagger (and presumably real-valued deviations).
        # Pure NIFTy would be c.dot(a_u, spaces=1); einsum on the weighted
        # raw values is used instead.
        # NOTE(review): c is volume-weighted via c.weight() here, but the
        # middle matrix above uses unweighted values. Confirm this matches
        # the NIFTy dot convention.
        c_weighted_val = c.weight().val.get_full_data()
        u_a_c_val = np.einsum(c_weighted_val, [1], a_u_val, [0, 1])

        first_summand = A.inverse_times(c)
        self.logger.debug("Calculated first summand.")

        # (I_k + U^dagger A^{-1} U)^{-1} (U^dagger A^{-1} c), via a linear
        # solve rather than an explicit inverse for numerical stability.
        middle_u_a_c = np.linalg.solve(middle, u_a_c_val)
        self.logger.debug("Intermediate step.")
        # -A^{-1} U (...): sum over ensemble members of the scaled rows.
        second_summand_val = -np.einsum(a_u_val, [0, 1],
                                        middle_u_a_c, [0])
        second_summand = first_summand.copy_empty()
        second_summand.val = second_summand_val

        result_1 = -c.dot(first_summand)
        result_2 = -c.dot(second_summand)
        result = result_1 + result_2
        self.logger.debug("Calculated: %f + %f = %f",
                          result_1, result_2, result)

        # Normalize by the squared norm of the profile-divided data.
        normalization = measured_data.dot(measured_data)
        normalized_result = result / normalization
        self.logger.info("Applied normalization for total result: "
                         "%f / %f = %f",
                         result, normalization, normalized_result)

        return normalized_result