# wiener_filter_hamiltonian.py: Wiener-filter reconstruction of a mock 2D signal with NIFTy.

from nifty import *

import numpy as np
import plotly.offline as pl
import plotly.graph_objs as go

from mpi4py import MPI
# MPI context; ``rank`` is used later to restrict file output to rank 0.
comm = MPI.COMM_WORLD
rank = comm.rank

# Seed NumPy's global RNG so the random draws in this script repeat between
# runs (assumes nifty's random fields draw from this RNG -- confirm).
np.random.seed(42)


class WienerFilterEnergy(Energy):
    """Wiener-filter Hamiltonian as a NIFTy ``Energy``.

    Evaluates H(x) = 0.5 * (D^{-1} x).dot(x) - j.dot(x) at ``position``,
    together with its gradient and an object exposing the inverse curvature.

    Parameters
    ----------
    position : Field
        Point at which the energy is evaluated.
    D : operator
        Propagator; must provide ``times`` and ``inverse_times``.
    j : Field
        Information source, combined with ``position`` via ``dot``.
    """

    def __init__(self, position, D, j):
        # Not strictly necessary, but keeps the constructor signature explicit.
        super(WienerFilterEnergy, self).__init__(position)
        self.D = D
        self.j = j

    def at(self, position):
        """Return a new energy at ``position`` sharing this one's ``D`` and ``j``."""
        return self.__class__(position, D=self.D, j=self.j)

    @property
    def value(self):
        """Real part of 0.5 * (D^{-1} x).dot(x) - j.dot(x) at the current position."""
        D_inv_x = self.D_inverse_x()
        H = 0.5 * D_inv_x.dot(self.position) - self.j.dot(self.position)
        return H.real

    @property
    def gradient(self):
        """Real part of D^{-1} x - j, returned as a float64 field."""
        D_inv_x = self.D_inverse_x()
        g = D_inv_x - self.j
        # np.float was removed in NumPy 1.24; np.float64 is the same dtype.
        return_g = g.copy_empty(dtype=np.float64)
        return_g.val = g.val.real
        return return_g

    @property
    def curvature(self):
        """Return an object whose ``inverse_times`` applies ``D``.

        The Hessian of H is D^{-1} (for self-adjoint D), so applying the
        inverse curvature amounts to applying D itself.
        """
        class Dummy(object):
            def __init__(self, x):
                self.x = x

            def inverse_times(self, *args, **kwargs):
                return self.x.times(*args, **kwargs)

        my_dummy = Dummy(self.D)
        return my_dummy

    @memo
    def D_inverse_x(self):
        """Apply D^{-1} to the current position.

        Presumably cached per instance by nifty's ``memo`` decorator, since
        both ``value`` and ``gradient`` call it.
        """
        # Use the operator this energy was built with, not a module-level ``D``
        # (which only exists when the script runs as __main__).
        return self.D.inverse_times(self.position)


if __name__ == "__main__":

    distribution_strategy = 'not'

    # Set up spaces and fft transformation.
    # (np.float was removed in NumPy 1.24; np.float64 is the same dtype.)
    s_space = RGSpace([512, 512], dtype=np.float64)
    fft = FFTOperator(s_space)
    h_space = fft.target[0]
    p_space = PowerSpace(h_space, distribution_strategy=distribution_strategy)

    # Create the field instances and power operator.
    pow_spec = (lambda k: (42 / (k + 1) ** 3))
    S = create_power_operator(h_space, power_spectrum=pow_spec,
                              distribution_strategy=distribution_strategy)

    # Mock signal: amplitudes sqrt(P(k)) in power space, synthesized into
    # h_space and transformed back to s_space.
    sp = Field(p_space, val=lambda z: pow_spec(z)**(1./2),
               distribution_strategy=distribution_strategy)
    sh = sp.power_synthesize(real_signal=True)
    ss = fft.inverse_times(sh)

    # Model the measurement process: smoothing response plus noise whose
    # variance is the signal variance divided by signal_to_noise.
    R = SmoothingOperator(s_space, sigma=0.01)
    # Alternative response (masked identity):
    # R = DiagonalOperator(s_space, diagonal=1.)
    # R._diagonal.val[200:400, 200:400] = 0

    signal_to_noise = 1
    N = DiagonalOperator(s_space, diagonal=ss.var()/signal_to_noise, bare=True)
    n = Field.from_random(domain=s_space,
                          random_type='normal',
                          std=ss.std()/np.sqrt(signal_to_noise),
                          mean=0)

    # Create mock data.
    d = R(ss) + n

    # Set up reconstruction objects: source j = R^adjoint N^{-1} d and
    # propagator D.
    j = R.adjoint_times(N.inverse_times(d))
    D = PropagatorOperator(S=S, N=N, R=R)

    def distance_measure(energy, iteration):
        """Minimizer callback: print the error norm relative to ``ss``."""
        x = energy.position
        print(iteration, ((x-ss).norm()/ss.norm()).real)

    # Alternative minimizers:
    # minimizer = SteepestDescent(convergence_tolerance=0,
    #                             iteration_limit=50,
    #                             callback=distance_measure)
    # minimizer = VL_BFGS(convergence_tolerance=0,
    #                     iteration_limit=50,
    #                     callback=distance_measure,
    #                     max_history_length=3)
    minimizer = RelaxedNewton(convergence_tolerance=0,
                              iteration_limit=2,
                              callback=distance_measure)

    m0 = Field(s_space, val=1)

    energy = WienerFilterEnergy(position=m0, D=D, j=j)

    (energy, convergence) = minimizer(energy)

    m = energy.position

    def plot_on_root(data, filename):
        """Write a heatmap of ``data`` to ``filename`` on MPI rank 0 only."""
        if rank == 0:
            pl.plot([go.Heatmap(z=data)], filename=filename)

    # get_full_data() is called on every rank (presumably it may gather
    # distributed data and must run everywhere -- confirm); only rank 0
    # writes the HTML files.
    plot_on_root(d.val.get_full_data().real, 'data.html')
    plot_on_root(ss.val.get_full_data().real, 'ss.html')
    plot_on_root(sh.val.get_full_data().real, 'sh.html')

    # Fetch j's full data once and reuse it for the real/abs/phase plots.
    j_full = j.val.get_full_data()
    plot_on_root(j_full.real, 'j.html')
    plot_on_root(np.abs(j_full), 'j_abs.html')
    plot_on_root(np.angle(j_full), 'j_phase.html')

    plot_on_root(m.val.get_full_data().real, 'map.html')