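# critical_filtering.py
#
# Critical-filter demo with NIFTy: alternately Wiener-filters mock data to get
# a map estimate and re-estimates the log power spectrum, writing HTML plots.
# Contributors: Philipp Arras, Jakob Knollmueller, Martin Reinecke.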
import numpy as np
from nifty import (VL_BFGS, DiagonalOperator, FFTOperator, Field,
                   LinearOperator, PowerSpace, RelaxedNewton, RGSpace,
                   SteepestDescent, create_power_operator, exp, log, sqrt)
from nifty.library.critical_filter import CriticalPowerEnergy
from nifty.library.wiener_filter import WienerFilterEnergy

import plotly.graph_objs as go
import plotly.offline as pl
from mpi4py import MPI

# MPI communicator; ``rank`` is used below to let only one process write
# plot files.
comm = MPI.COMM_WORLD
rank = comm.rank

# Seed NumPy's global RNG so runs are reproducible (assumes nifty's random
# draws go through numpy.random -- TODO confirm).
np.random.seed(42)

def plot_parameters(m, t, p, p_d, fft_op=None):
    """Write diagnostic plots of the current reconstruction to HTML files.

    Produces ``map.html`` (heatmap of the map estimate in position space) and
    ``t.html`` (the real parts of ``t``, ``p`` and ``p_d`` plotted against the
    log of the k-values of ``t``'s power space). Only MPI rank 0 writes files.

    Parameters
    ----------
    m : Field
        Map estimate; transformed with ``fft_op.adjoint_times`` before
        plotting.
    t, p, p_d : Field
        Spectra to plot as three curves. They share one x-axis taken from
        ``t.domain[0]``, so ``p`` and ``p_d`` are assumed to use the same
        binning as ``t``. In this script they are the current log-spectrum
        estimate, the log of the true spectrum and the log data power.
    fft_op : LinearOperator, optional
        Harmonic transform used to bring ``m`` back to position space.
        Defaults to the module-level ``fft`` created in the ``__main__``
        section, so omitting it only works when this file runs as a script.
    """
    if fft_op is None:
        # Fall back to the transform defined in the __main__ section.
        fft_op = fft

    # Horizontal axis: log of the k-values of t's power space.
    # NOTE(review): if kindex contains k = 0 (the script elsewhere uses k + 1
    # to sidestep it), the first point is -inf -- confirm this is intended.
    x = log(t.domain[0].kindex)

    # Gather plain real-valued arrays on every rank. This mirrors the data
    # plot in __main__, which calls get_full_data() before its rank-0 check
    # (presumably because gathering distributed data needs all ranks).
    m_data = fft_op.adjoint_times(m).val.get_full_data().real
    t_data = t.val.get_full_data().real
    p_data = p.val.get_full_data().real
    p_d_data = p_d.val.get_full_data().real

    # Only one process writes, so parallel runs don't all overwrite the
    # same HTML files.
    if rank != 0:
        return

    pl.plot([go.Heatmap(z=m_data)], filename='map.html', auto_open=False)
    pl.plot([go.Scatter(x=x, y=t_data), go.Scatter(x=x, y=p_data),
             go.Scatter(x=x, y=p_d_data)], filename="t.html", auto_open=False)


class AdjointFFTResponse(LinearOperator):
    """Response operator acting on harmonic-space fields.

    ``times`` applies ``FFT.adjoint_times`` (harmonic -> position space)
    followed by the instrument response ``R``; ``adjoint_times`` applies
    ``R.adjoint_times`` followed by ``FFT``.

    Parameters
    ----------
    FFT : LinearOperator
        Harmonic transform; its target is this operator's domain.
    R : LinearOperator
        Instrument response; its target is this operator's target.
    default_spaces : optional
        Forwarded unchanged to ``LinearOperator.__init__``.
    """

    def __init__(self, FFT, R, default_spaces=None):
        super(AdjointFFTResponse, self).__init__(default_spaces)
        self._domain = FFT.target
        self._target = R.target
        self.R = R
        self.FFT = FFT

    def _times(self, x, spaces=None):
        # Back-transform to position space, then apply the instrument.
        # ``spaces`` is accepted for interface compatibility but not forwarded.
        return self.R(self.FFT.adjoint_times(x))

    def _adjoint_times(self, x, spaces=None):
        # Reverse order: adjoint instrument, then forward transform.
        return self.FFT(self.R.adjoint_times(x))

    @property
    def domain(self):
        """Domain of the operator: the FFT's target (harmonic) domain."""
        return self._domain

    @property
    def target(self):
        """Target of the operator: the instrument response's target."""
        return self._target

    @property
    def unitary(self):
        """Always False: the composition is not declared unitary."""
        return False


if __name__ == "__main__":

    distribution_strategy = 'not'

    # Set up position space
    s_space = RGSpace([128, 128])
    # s_space = HPSpace(32)

    # Define harmonic transformation and associated harmonic space.
    # ``fft`` must stay a module-level name: plot_parameters resolves it as a
    # global by default.
    fft = FFTOperator(s_space)
    h_space = fft.target[0]

    # Set up power space
    p_space = PowerSpace(h_space, logarithmic=True,
                         distribution_strategy=distribution_strategy)

    # Choose the prior correlation structure and define the correlation
    # operator. Note: S (the true prior) is not used by the reconstruction
    # loop below, which builds S0 from the current estimate instead.
    p_spec = (lambda k: (.5 / (k + 1) ** 3))
    S = create_power_operator(h_space, power_spectrum=p_spec,
                              distribution_strategy=distribution_strategy)

    # Draw a sample sh from the prior distribution in harmonic space
    sp = Field(p_space, val=p_spec,
               distribution_strategy=distribution_strategy)
    sh = sp.power_synthesize(real_signal=True)

    # Choose the measurement instrument
    # Instrument = SmoothingOperator(s_space, sigma=0.01)
    Instrument = DiagonalOperator(s_space, diagonal=1.)
    # NOTE: the mask indices below were written for a grid larger than the
    # 128x128 space used here; adapt them before re-enabling.
    # Instrument._diagonal.val[200:400, 200:400] = 0
    # Instrument._diagonal.val[64:512-64, 64:512-64] = 0

    # Add a harmonic transformation to the instrument
    R = AdjointFFTResponse(fft, Instrument)

    # Diagonal noise covariance and a matching Gaussian noise realization
    noise = 1.
    N = DiagonalOperator(s_space, diagonal=noise, bare=True)
    n = Field.from_random(domain=s_space,
                          random_type='normal',
                          std=sqrt(noise),
                          mean=0)

    # Create mock data
    d = R(sh) + n

    # The information source
    j = R.adjoint_times(N.inverse_times(d))
    data_power = log(fft(d).power_analyze(binbounds=p_space.binbounds))
    # The true log spectrum does not change during the loop; compute it once.
    log_true_power = log(sp)

    # Gather on all ranks, but let only rank 0 write the plot file.
    d_data = d.val.get_full_data().real
    if rank == 0:
        pl.plot([go.Heatmap(z=d_data)], filename='data.html', auto_open=False)

    # Minimization strategy
    def convergence_measure(a_energy, iteration):
        """Minimizer callback: print the current energy value and iteration.

        Prints on rank 0 only; returns None.
        """
        if rank == 0:
            print(a_energy.value, iteration)

    # Only minimizer2 is used below; the others are kept as alternatives.
    minimizer1 = RelaxedNewton(convergence_tolerance=1e-4,
                               convergence_level=1,
                               iteration_limit=5,
                               callback=convergence_measure)
    minimizer2 = VL_BFGS(convergence_tolerance=1e-4,
                         convergence_level=1,
                         iteration_limit=20,
                         callback=convergence_measure,
                         max_history_length=20)
    minimizer3 = SteepestDescent(convergence_tolerance=1e-4,
                                 iteration_limit=100,
                                 callback=convergence_measure)

    # Set starting position
    flat_power = Field(p_space, val=1e-8)
    m0 = flat_power.power_synthesize(real_signal=True)

    # Initial guess for the log power spectrum
    t0 = Field(p_space, val=log(1./(1+p_space.kindex)**2))

    for i in range(500):
        # Prior covariance from the current power-spectrum estimate exp(t0)
        S0 = create_power_operator(h_space, power_spectrum=exp(t0),
                                   distribution_strategy=distribution_strategy)

        # Wiener filter energy for the current prior
        map_energy = WienerFilterEnergy(position=m0, d=d, R=R, N=N, S=S0)
        # Solve the Wiener filter analytically: m0 = curvature^{-1} j
        D0 = map_energy.curvature
        m0 = D0.inverse_times(j)

        # Initialize power energy with updated parameters
        power_energy = CriticalPowerEnergy(position=t0, m=m0, D=D0,
                                           smoothness_prior=10., samples=3)

        power_energy, _ = minimizer2(power_energy)

        # Set new power spectrum
        t0.val = power_energy.position.val.real

        # Report progress and plot the current estimate. plot_parameters is
        # called on every rank because it gathers field data; it decides
        # internally (or via the caller's setup) which rank writes files.
        if rank == 0:
            print(i)
        if i % 5 == 0:
            plot_parameters(m0, t0, log_true_power, data_power)