ensemble_likelihood.py 4.82 KB
Newer Older
1 2
# -*- coding: utf-8 -*-

Theo Steininger's avatar
Theo Steininger committed
3 4
import numpy as np

5
from nifty import DiagonalOperator, FieldArray, Field
6

7
from imagine.likelihoods.likelihood import Likelihood
8

9 10

class EnsembleLikelihood(Likelihood):
11
    def __init__(self, observable_name,  measured_data,
12
                 data_covariance_operator, profile=None):
13
        self.observable_name = observable_name
14
        self.measured_data = self._strip_data(measured_data)
Theo Steininger's avatar
Theo Steininger committed
15
        self.data_covariance_operator = data_covariance_operator
16

17 18 19 20
    def _strip_data(self, data):
        # if the first element in the domain tuple is a FieldArray we must
        # extract the data
        if isinstance(data.domain[0], FieldArray):
Theo Steininger's avatar
bugfix  
Theo Steininger committed
21 22 23 24
            data = Field(domain=data.domain[1:],
                         val=data.val.get_full_data()[0],
                         distribution_strategy='not')
        return data
25

26
    def __call__(self, observable):
27 28 29
        field = observable[self.observable_name]
        return self._process_simple_field(field,
                                          self.measured_data,
30
                                          self.data_covariance_operator)
31

32
    def _process_simple_field(self, observable, measured_data,
33
                              data_covariance_operator):
Theo Steininger's avatar
Theo Steininger committed
34 35 36 37 38
        # https://en.wikipedia.org/wiki/Sherman%E2%80%93Morrison_formula#Generalization
        # B = A^{-1} + U U^dagger
        # A = data_covariance
        # B^{-1} c = (A_inv -
        #             A_inv U (I_k + U^dagger A_inv U)^{-1} U^dagger A_inv) c
39

Theo Steininger's avatar
Theo Steininger committed
40
        k = observable.shape[0]
41
        n = observable.shape[1]
Theo Steininger's avatar
Theo Steininger committed
42 43

        obs_val = observable.val.get_full_data()
44
        obs_mean = observable.ensemble_mean().val.get_full_data()
Theo Steininger's avatar
Theo Steininger committed
45

46
        u_val = obs_val - obs_mean
47 48

        # compute quantities for OAS estimator
49
        mu = np.vdot(u_val, u_val)/k/n
50 51
        self.logger.debug("mu: %f" % mu)

Theo Steininger's avatar
Theo Steininger committed
52
        alpha = (np.einsum(u_val, [0, 1], u_val, [2, 1])**2).sum()
53
        alpha /= k*2
54

55 56
        numerator = (1 - 2./n)*alpha + (mu*n)**2
        denominator = (k + 1 - 2./n) * (alpha - ((mu*n)**2)/n)
57 58 59 60

        if denominator == 0:
            rho = 1
        else:
61
            rho = np.min([1, numerator/denominator])
62

Theo Steininger's avatar
Theo Steininger committed
63 64
        self.logger.debug("rho: %f = %f / %f" % (rho, numerator, denominator))

65
        # rescale U half/half
66
        u_val *= np.sqrt(1-rho) / np.sqrt(k)
67 68 69 70 71 72

        # we assume that data_covariance_operator is a DiagonalOperator
        if not isinstance(data_covariance_operator, DiagonalOperator):
            raise TypeError("data_covariance_operator must be a NIFTY "
                            "DiagonalOperator.")

73 74 75 76 77 78
        A_diagonal_val = data_covariance_operator.diagonal(bare=False).val
        self.logger.info(('rho*mu', rho*mu,
                          'rho', rho,
                          'mu', mu,
                          'alhpa', alpha))
        A_diagonal_val += rho*mu
79

80
        a_u_val = u_val/A_diagonal_val
Theo Steininger's avatar
Theo Steininger committed
81 82 83 84

        # build middle-matrix (kxk)
        middle = (np.eye(k) +
                  np.einsum(u_val.conjugate(), [0, 1],
85
                            a_u_val, [2, 1]))
Theo Steininger's avatar
Theo Steininger committed
86
        middle = np.linalg.inv(middle)
87 88
        c = measured_data - obs_mean

89 90
        # If the data was incomplete, i.e. contains np.NANs, set those values
        # to zero.
Theo Steininger's avatar
bugfix  
Theo Steininger committed
91
        c.val.data = np.nan_to_num(c.val.data)
92 93 94 95 96 97 98 99 100 101 102
        # assuming that A == A^dagger, this can be shortend
        # a_c = A.inverse_times(c)
        # u_a_c = a_c.dot(U, spaces=1)
        # u_a_c = u_a_c.conjugate()

        # and: double conjugate shouldn't make a difference
        # u_a_c = c.conjugate().dot(a_u, spaces=1).conjugate()

        # Pure NIFTy is
        # u_a_c = c.dot(a_u, spaces=1)
        # u_a_c_val = u_a_c.val.get_full_data()
103 104
        c_val = c.val.get_full_data()
        u_a_c_val = np.einsum(c_val, [1], a_u_val, [0, 1])
105

106
        first_summand_val = c_val/A_diagonal_val
107 108 109
        second_summand_val = np.einsum(middle, [0, 1], u_a_c_val, [1])
        second_summand_val = np.einsum(a_u_val, [0, 1],
                                       second_summand_val, [0])
110 111 112
#        # second_summand_val *= -1
#        second_summand = first_summand.copy_empty()
#        second_summand.val = second_summand_val
113

114 115
        result_1 = np.vdot(c_val, first_summand_val)
        result_2 = -np.vdot(c_val, second_summand_val)
116 117 118 119 120 121 122 123 124 125 126 127 128

        # compute regularizing determinant of the covariance
        # det(A + UV^T) =  det(A) det(I + V^T A^-1 U)
        log_det_1 = np.sum(np.log(A_diagonal_val))
        (sign, log_det_2) = np.slogdet(middle)
        if sign < 0:
            self.logger.error("Negative determinant of covariance!")

        result = -0.5*(result_1 + result_2 + log_det_1 + log_det_2)

        self.logger.info("Calculated (%s): -(%g + %g + %g + %g) = %f" %
                         (self.observable_name,
                          result_1, result_2, log_det_1, log_det_2))
129 130
#        result_array[i] = result
#        total_result = result_array.mean()
131 132

        return result