# nonlinear_critical_filter.py
# Demo script: nonlinear critical filter built with NIFTy4
# (author per version-control history: Martin Reinecke).
import nifty4 as ift
from nifty4.library.nonlinearities import Exponential
import numpy as np
np.random.seed(42)


def adjust_zero_mode(m0, t0):
    """Move the amplitude of the excitation zero mode into the log-spectrum.

    The reconstruction is built from ``power * m`` with
    ``power = Projection.adjoint_times(exp(0.5 * t))``. The overall amplitude
    of the zero mode is therefore shared between the excitation ``m`` and the
    log-power ``t``. This function breaks that degeneracy. It scales the zero
    mode of ``m0`` to unit modulus and adds ``2 * log|zero_mode|`` to bin 0
    of ``t0``, which leaves ``exp(0.5 * t0[0]) * m0[zero]`` unchanged.
    (Assumes bin 0 of the power space corresponds to the harmonic zero
    mode -- the usual PowerSpace convention; TODO confirm.)

    Parameters
    ----------
    m0 : ift.Field
        Excitation field in harmonic space.
    t0 : ift.Field
        Log power spectrum on the power space.

    Returns
    -------
    (ift.Field, ift.Field)
        The adjusted ``(m0, t0)``. The input fields are not modified.
        If the zero mode of ``m0`` is exactly zero, the degeneracy cannot be
        broken, and the inputs are returned unchanged.
    """
    # Copy the data: to_global_data may hand back the field's own buffer.
    # Writing into it would silently mutate the caller's fields.
    mtmp = ift.dobj.to_global_data(m0.val).copy()
    zero_position = len(m0.shape)*(0,)
    zero_mode = mtmp[zero_position]
    scale = abs(zero_mode)
    if scale == 0:
        # Normalizing would divide by zero and write -inf into t0.
        return m0, t0
    mtmp[zero_position] = zero_mode / scale
    ttmp = ift.dobj.to_global_data(t0.val).copy()
    # power ~ exp(0.5 * t), so +2*log(scale) multiplies power by `scale`.
    ttmp[0] += 2 * np.log(scale)
    return (ift.Field(m0.domain, ift.dobj.from_global_data(mtmp)),
            ift.Field(t0.domain, ift.dobj.from_global_data(ttmp)))

if __name__ == "__main__":

    noise_level = 1.

    def p_spec(k):
        # True power spectrum used to generate the mock signal.
        return 1. / (k + 1) ** 2

    # Signal nonlinearity (a linear model, e.g. Linear(), can be swapped in).
    nonlinearity = Exponential()

    # Set up position space. An ift.RGSpace([1024]) also works, but the mask
    # slice below is chosen for the HEALPix pixel count and must be adapted.
    s_space = ift.HPSpace(32)

    # Define harmonic transformation and associated harmonic space
    FFT = ift.FFTOperator(s_space)
    h_space = FFT.target[0]

    # Setting up power space with logarithmically spaced bins
    p_space = ift.PowerSpace(h_space,
                             binbounds=ift.PowerSpace.useful_binbounds(
                                 h_space, logarithmic=True))

    # The excitations get a flat (unit) prior spectrum. The actual power
    # spectrum enters multiplicatively through `power` below.
    s_spec = ift.Field(p_space, val=1.)

    # True power spectrum and its logarithm
    p = ift.PS_field(p_space, p_spec)
    log_p = ift.log(p)
    S = ift.create_power_operator(h_space, power_spectrum=s_spec)

    # Drawing a sample sh from the prior distribution in harmonic space
    sh = ift.power_synthesize(s_spec)

    # Choosing the measurement instrument: a mask that blanks a range of
    # pixels, followed by a response without smoothing.
    mask = np.ones(s_space.shape)
    mask[6000:8000] = 0.
    mask = ift.Field(s_space, val=ift.dobj.from_global_data(mask))
    MaskOperator = ift.DiagonalOperator(mask)
    InstrumentResponse = ift.ResponseOperator(s_space, sigma=[0.0],
                                              exposure=[1.0])
    MeasurementOperator = InstrumentResponse*MaskOperator
    d_space = MeasurementOperator.target

    # Noise covariance and one noise realization
    noise_covariance = ift.Field(d_space, val=noise_level**2).weight()
    N = ift.DiagonalOperator(noise_covariance)
    n = ift.Field.from_random(domain=d_space, random_type='normal',
                              std=noise_level)

    Projection = ift.PowerProjectionOperator(domain=h_space,
                                             power_space=p_space)
    # Amplitude spectrum sqrt(p) projected back onto the harmonic space
    power = Projection.adjoint_times(ift.exp(0.5*log_p))

    # Creating the mock data
    true_sky = nonlinearity(FFT.adjoint_times(power*sh))
    d = MeasurementOperator(true_sky) + n

    # Starting point: tiny random excitations and a flat log-spectrum
    m0 = ift.power_synthesize(ift.Field(p_space, val=1e-7))
    t0 = ift.Field(p_space, val=-4.)

    IC1 = ift.GradientNormController(verbose=True, iteration_limit=100,
                                     tol_abs_gradnorm=1e-3)
    LS = ift.LineSearchStrongWolfe(c2=0.02)
    minimizer = ift.RelaxedNewton(IC1, line_searcher=LS)

    ICI = ift.GradientNormController(verbose=False, name="ICI",
                                     iteration_limit=500,
                                     tol_abs_gradnorm=1e-3)
    inverter = ift.ConjugateGradient(controller=ICI)

    for _ in range(20):
        power0 = Projection.adjoint_times(ift.exp(0.5*t0))
        map0_energy = ift.library.NonlinearWienerFilterEnergy(
            m0, d, MeasurementOperator, nonlinearity, FFT, power0, N, S,
            inverter=inverter)

        # Minimization with chosen minimizer
        map0_energy, _ = minimizer(map0_energy)
        m0 = map0_energy.position

        # Updating parameters for correlation structure reconstruction
        D0 = map0_energy.curvature

        # Initializing the power energy with updated parameters
        power0_energy = ift.library.NonlinearPowerEnergy(
            position=t0, d=d, N=N, m=m0, D=D0, FFT=FFT,
            Instrument=MeasurementOperator, nonlinearity=nonlinearity,
            Projection=Projection, sigma=1., samples=2, inverter=inverter)

        power0_energy = minimizer(power0_energy)[0]

        # Setting new power spectrum
        t0 = power0_energy.position

        # break degeneracy between amplitude and excitation by setting
        # excitation monopole to 1
        m0, t0 = adjust_zero_mode(m0, t0)

    # t0 changed after power0 was last computed (power minimization and
    # adjust_zero_mode in the final iteration). Recompute power0 so that it
    # matches the final, zero-mode-adjusted m0.
    power0 = Projection.adjoint_times(ift.exp(0.5*t0))

    ift.plotting.plot(true_sky)
    ift.plotting.plot(nonlinearity(FFT.adjoint_times(power0*m0)),
                      title='reconstructed_sky')
    ift.plotting.plot(MeasurementOperator.adjoint_times(d))