critical_filtering.py 7.43 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232

from nifty import *

import plotly.offline as pl
import plotly.graph_objs as go

from mpi4py import MPI
comm = MPI.COMM_WORLD
rank = comm.rank

np.random.seed(62)

class NonlinearResponse(LinearOperator):
    def __init__(self, FFT, Instrument, function, derivative, default_spaces=None):
        super(NonlinearResponse, self).__init__(default_spaces)
        self._domain = FFT.target
        self._target = Instrument.target
        self.function = function
        self.derivative = derivative
        self.I = Instrument
        self.FFT = FFT


    def _times(self, x, spaces=None):
        return self.I(self.function(self.FFT.adjoint_times(x)))

    def _adjoint_times(self, x, spaces=None):
        return self.FFT(self.function(self.I.adjoint_times(x)))

    def derived_times(self, x, position):
        position_derivative = self.derivative(self.FFT.adjoint_times(position))
        return self.I(position_derivative * self.FFT.adjoint_times(x))

    def derived_adjoint_times(self, x, position):
        position_derivative = self.derivative(self.FFT.adjoint_times(position))
        return self.FFT(position_derivative * self.I.adjoint_times(x))

    @property
    def domain(self):
        return self._domain

    @property
    def target(self):
        return self._target

    @property
    def unitary(self):
        return False

def plot_parameters(m,t,t_true, t_real):
    m = fft.adjoint_times(m)
    m_data = m.val.get_full_data().real
    t_data = t.val.get_full_data().real
    t_true_data = t_true.val.get_full_data().real
    t_real_data = t_real.val.get_full_data().real
    pl.plot([go.Heatmap(z=m_data)], filename='map.html')
    pl.plot([go.Scatter(y=t_data), go.Scatter(y=t_true_data),
             go.Scatter(y=t_real_data)], filename="t.html")

if __name__ == "__main__":

    distribution_strategy = 'not'

    # Set up position space
    s_space = RGSpace([128,128])
    # s_space = HPSpace(32)

    # Define harmonic transformation and associated harmonic space
    fft = FFTOperator(s_space)
    h_space = fft.target[0]

    # Setting up power space
    p_space = PowerSpace(h_space, logarithmic = True,
                         distribution_strategy=distribution_strategy)

    # Choosing the prior correlation structure and defining correlation operator
    pow_spec = (lambda k: (.05 / (k + 1) ** 3))
    S = create_power_operator(h_space, power_spectrum=pow_spec,
                              distribution_strategy=distribution_strategy)

    # Drawing a sample sh from the prior distribution in harmonic space
    sp = Field(p_space,  val=lambda z: pow_spec(z)**(1./2),
               distribution_strategy=distribution_strategy)
    sh = sp.power_synthesize(real_signal=True)


    # Choosing the measurement instrument
#    Instrument = SmoothingOperator(s_space, sigma=0.01)
    Instrument = DiagonalOperator(s_space, diagonal=1.)
#    Instrument._diagonal.val[200:400, 200:400] = 0

    #   Choosing nonlinearity

    # log-normal model:

    function = exp
    derivative = exp

    # tan-normal model

    # def function(x):
    #     return 0.5 * tanh(x) + 0.5
    # def derivative(x):
    #     return 0.5*(1 - tanh(x)**2)

    # no nonlinearity, Wiener Filter

    # def function(x):
    #     return x
    # def derivative(x):
    #     return 1

    # small quadratic pertubarion

    # def function(x):
    #     return 0.5*x**2 + x
    # def derivative(x):
    #     return x + 1

    # def function(x):
    #     return 0.9*x**4 +0.2*x**2 + x
    # def derivative(x):
    #     return 0.9*4*x**3 + 0.4*x +1
    #

    #Adding a harmonic transformation to the instrument
    R = NonlinearResponse(fft, Instrument, function, derivative)
    noise = .01
    N = DiagonalOperator(s_space, diagonal=noise, bare=True)
    n = Field.from_random(domain=s_space,
                          random_type='normal',
                          std=sqrt(noise),
                          mean=0)

    # Creating the mock data
    d = R(sh) + n
    realized_power = log(sh.power_analyze(logarithmic=p_space.config["logarithmic"])**2)
    d_data = d.val.get_full_data().real
    if rank == 0:
        pl.plot([go.Heatmap(z=d_data)], filename='data.html')

    # Choosing the minimization strategy

    def convergence_measure(a_energy, iteration): # returns current energy
        x = a_energy.value
        print (x, iteration)

    # minimizer1 = SteepestDescent(convergence_tolerance=0,
    #                            iteration_limit=50,
    #                            callback=convergence_measure)

    minimizer1 = RelaxedNewton(convergence_tolerance=0,
                              convergence_level=1,
                              iteration_limit=5,
                              callback=convergence_measure)
    # minimizer2 = RelaxedNewton(convergence_tolerance=0,
    #                           convergence_level=1,
    #                           iteration_limit=2,
    #                           callback=convergence_measure)
    #
    # minimizer1 = VL_BFGS(convergence_tolerance=0,
    #                    iteration_limit=5,
    #                    callback=convergence_measure,
    #                    max_history_length=3)



    # Setting starting position
    flat_power = Field(p_space,val=10e-8)
    m0 = flat_power.power_synthesize(real_signal=True)

    t0 = Field(p_space, val=log(1./(1+p_space.kindex)**2))
    # t0 = Field(p_space,val=-8)
    # t0 = log(sp.copy()**2)
    S0 = create_power_operator(h_space, power_spectrum=exp(t0),
                               distribution_strategy=distribution_strategy)



    for i in range(100):
        S0 = create_power_operator(h_space, power_spectrum=exp(t0),
                              distribution_strategy=distribution_strategy)

        # Initializing the  nonlinear Wiener Filter energy
        map_energy = NonlinearWienerFilterEnergy(position=m0, d=d, R=R, N=N, S=S0)
        # Minimization with chosen minimizer
        (map_energy, convergence) = minimizer1(map_energy)
        # Updating parameters for correlation structure reconstruction
        m0 = map_energy.position
        D0 = map_energy.curvature
        # Initializing the power energy with updated parameters
        power_energy = CriticalPowerEnergy(position=t0, m=m0, D=D0, sigma=10., samples=3)
        (power_energy, convergence) = minimizer1(power_energy)
        # Setting new power spectrum
        t0 = power_energy.position
        plot_parameters(m0,t0,log(sp**2),realized_power)

    # Transforming fields to position space for plotting

    ss = fft.adjoint_times(sh)
    m = fft.adjoint_times(map_energy.position)


    # Plotting

    d_data = d.val.get_full_data().real
    if rank == 0:
        pl.plot([go.Heatmap(z=d_data)], filename='data.html')

    tt_data = power_energy.position.val.get_full_data().real
    t_data = log(sp**2).val.get_full_data().real
    if rank == 0:
        pl.plot([go.Scatter(y=t_data),go.Scatter(y=tt_data)], filename="t.html")
    ss_data = ss.val.get_full_data().real
    if rank == 0:
        pl.plot([go.Heatmap(z=ss_data)], filename='ss.html')

    sh_data = sh.val.get_full_data().real
    if rank == 0:
        pl.plot([go.Heatmap(z=sh_data)], filename='sh.html')


    m_data = m.val.get_full_data().real
    if rank == 0:
        pl.plot([go.Heatmap(z=m_data)], filename='map.html')

    f_m_data = function(m).val.get_full_data().real
    if rank == 0:
        pl.plot([go.Heatmap(z=f_m_data)], filename='f_map.html')
    f_ss_data = function(ss).val.get_full_data().real
    if rank == 0:
        pl.plot([go.Heatmap(z=f_ss_data)], filename='f_ss.html')