ensemble_likelihood.py
# -*- coding: utf-8 -*-

import numpy as np

from nifty import Field

from imagine.likelihoods.likelihood import Likelihood


class EnsembleLikelihood(Likelihood):
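    """
    Gaussian log-likelihood in which the covariance is the (diagonal)
    measurement covariance plus an OAS-regularized estimate of the
    ensemble covariance of the simulated observable.
    """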
    def __init__(self, observable_name, measured_data,
                 data_covariance, profile=None):
        self.observable_name = observable_name
        self.measured_data = self._strip_data(measured_data)
        if isinstance(data_covariance, Field):
            data_covariance = data_covariance.val.get_full_data()
        self.data_covariance = data_covariance
        self.use_determinant = True

    def __call__(self, observable):
        field = observable[self.observable_name]
        return self._process_simple_field(field,
                                          self.measured_data,
                                          self.data_covariance)

    def _process_simple_field(self, observable, measured_data,
                              data_covariance):
        # https://en.wikipedia.org/wiki/Sherman%E2%80%93Morrison_formula#Generalization
        # B = A + U U^dagger
        # A = data_covariance
        # B^{-1} c = (A_inv -
        #             A_inv U (I_k + U^dagger A_inv U)^{-1} U^dagger A_inv) c
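        #
        # Since A is diagonal here, applying A^{-1} is an elementwise
        # division; only the (k x k) middle matrix is inverted densely.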
        data_covariance = data_covariance.copy()
        k = observable.shape[0]
        n = observable.shape[1]

        obs_val = observable.val.get_full_data()
        obs_mean = observable.ensemble_mean().val.get_full_data()

        U = obs_val - obs_mean
        U *= np.sqrt(n)

        # compute quantities for OAS estimator
        mu = np.vdot(U, U)/k/n
        alpha = (np.einsum(U, [0, 1], U, [2, 1])**2).sum()
        alpha /= k**2

        numerator = (1 - 2./n)*alpha + (mu*n)**2
        denominator = (k + 1 - 2./n) * (alpha - ((mu*n)**2)/n)

        if denominator == 0:
            rho = 1
        else:
            rho = np.min([1, numerator/denominator])
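        # OAS shrinkage (Chen et al. 2010): the regularized ensemble
        # covariance is (1-rho)*S + rho*mu*I, with S = U U^dagger / k the
        # sample covariance and mu = tr(S)/n its normalized trace.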
        self.logger.debug("rho: %f = %f / %f" % (rho, numerator, denominator))

        # rescale U: fold the shrinkage factor (1-rho) and the ensemble
        # normalization 1/k into V, half (a square root) on each side of
        # V V^dagger
        V = U * np.sqrt(1-rho) / np.sqrt(k)

        self.logger.info(('data_cov', np.mean(data_covariance),
                          'rho*mu', rho*mu,
                          'rho', rho,
                          'mu', mu,
                          'alpha', alpha))
        B = data_covariance + rho*mu
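        # B is the diagonal part of the regularized covariance: the data
        # covariance plus the scalar shrinkage target rho*mu; dividing by
        # B therefore applies its inverse elementwise.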

        V_B = V/B

        # build the (k x k) middle matrix I_k + V^dagger B^{-1} V
        middle = (np.eye(k) +
                  np.einsum(V.conjugate(), [0, 1],
                            V_B, [2, 1]))
        middle = np.linalg.inv(middle)
        c = measured_data - obs_mean

        # If the data is incomplete, i.e. contains NaNs, set those values
        # to zero.
        c = np.nan_to_num(c)
        # assuming that A == A^dagger, this can be shortened
        # a_c = A.inverse_times(c)
        # u_a_c = a_c.dot(U, spaces=1)
        # u_a_c = u_a_c.conjugate()

        # and: double conjugate shouldn't make a difference
        # u_a_c = c.conjugate().dot(a_u, spaces=1).conjugate()

        # Pure NIFTy is
        # u_a_c = c.dot(a_u, spaces=1)
        # u_a_c_val = u_a_c.val.get_full_data()
        V_B_c = np.einsum(c, [1], V_B, [0, 1])
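        # i.e. V_B_c[i] = sum_x V_B[i, x] * c[x], the k-vector
        # V^dagger B^{-1} c of the Woodbury correction (real-valued data
        # assumed, so no conjugation is needed here)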

        first_summand_val = c/B
        second_summand_val = np.einsum(middle, [0, 1], V_B_c, [1])
        second_summand_val = np.einsum(V_B, [0, 1],
                                       second_summand_val, [0])
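        # together this forms the Woodbury correction term
        # B^{-1} V (I_k + V^dagger B^{-1} V)^{-1} V^dagger B^{-1} c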
#        # second_summand_val *= -1
#        second_summand = first_summand.copy_empty()
#        second_summand.val = second_summand_val

        result_1 = np.vdot(c, first_summand_val)
        result_2 = -np.vdot(c, second_summand_val)

        # compute regularizing determinant of the covariance
        # det(A + UV^T) =  det(A) det(I + V^T A^-1 U)
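        # Only diagonal terms are kept below, i.e. log(det) is
        # approximated by the sum of the logs of the diagonal entries,
        # normalized by n.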
        if self.use_determinant:
            log_det = np.sum(np.log(data_covariance +
                                    np.sum((obs_val-obs_mean)**2, axis=0)/k))/n
        else:
            log_det = 0.

        result = -0.5*(result_1 + result_2 + log_det)

        self.logger.info("Calculated (%s): -1/2(%g + %g + %g) = %g" %
                         (self.observable_name,
                          result_1, result_2, log_det, result))
#        result_array[i] = result
#        total_result = result_array.mean()

        return result
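

# ---------------------------------------------------------------------------
# Illustrative self-check (a hypothetical addition, not used by the class
# above): verifies with plain numpy that the Woodbury route implemented in
# _process_simple_field matches a direct dense inversion for a small random
# diagonal covariance and a low-rank ensemble matrix.
if __name__ == "__main__":
    rng = np.random.RandomState(42)
    n, k = 100, 5                    # data dimension, ensemble size
    A = rng.uniform(1., 2., n)       # diagonal covariance entries
    V = rng.normal(size=(k, n))      # k rescaled ensemble deviations (rows)
    c = rng.normal(size=n)           # data residual

    # direct route: build the dense n x n matrix A + sum_i v_i v_i^T
    full = np.diag(A) + V.T.dot(V)
    direct = c.dot(np.linalg.solve(full, c))

    # Woodbury route, mirroring the contractions used above
    V_B = V / A                      # V times the inverse diagonal
    middle = np.linalg.inv(np.eye(k) + V.dot(V_B.T))
    V_B_c = V_B.dot(c)
    woodbury = c.dot(c / A) - V_B_c.dot(middle).dot(V_B_c)

    assert np.allclose(direct, woodbury)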