# wiener_filter_hamiltonian.py (4.35 KB)
 Theo Steininger committed Oct 24, 2016 1 2 `````` from nifty import * `````` Theo Steininger committed Oct 25, 2016 3 `````` `````` Theo Steininger committed Feb 09, 2017 4 5 ``````import plotly.offline as pl import plotly.graph_objs as go `````` Theo Steininger committed Oct 24, 2016 6 7 8 9 10 `````` from mpi4py import MPI comm = MPI.COMM_WORLD rank = comm.rank `````` Theo Steininger committed Feb 09, 2017 11 ``````np.random.seed(42) `````` Theo Steininger committed Oct 24, 2016 12 `````` `````` Theo Steininger committed Oct 25, 2016 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 ``````class WienerFilterEnergy(Energy): def __init__(self, position, D, j): # in principle not necessary, but useful in order to make the signature # explicit super(WienerFilterEnergy, self).__init__(position) self.D = D self.j = j def at(self, position): return self.__class__(position, D=self.D, j=self.j) @property def value(self): D_inv_x = self.D_inverse_x() H = 0.5 * D_inv_x.dot(self.position) - self.j.dot(self.position) return H.real @property def gradient(self): D_inv_x = self.D_inverse_x() g = D_inv_x - self.j return_g = g.copy_empty(dtype=np.float) return_g.val = g.val.real return return_g `````` Theo Steininger committed Feb 09, 2017 38 39 40 41 42 43 44 45 46 47 48 `````` @property def curvature(self): class Dummy(object): def __init__(self, x): self.x = x def inverse_times(self, *args, **kwargs): return self.x.times(*args, **kwargs) my_dummy = Dummy(self.D) return my_dummy `````` Theo Steininger committed Oct 25, 2016 49 `````` @memo `````` Theo Steininger committed Oct 25, 2016 50 51 52 53 `````` def D_inverse_x(self): return D.inverse_times(self.position) `````` Theo Steininger committed Oct 24, 2016 54 55 ``````if __name__ == "__main__": `````` Pumpe, Daniel (dpumpe) committed Apr 10, 2017 56 `````` distribution_strategy = 'not' `````` Theo Steininger committed Oct 24, 2016 57 `````` `````` Theo Steininger committed Oct 25, 2016 58 `````` # Set 
up spaces and fft transformation `````` Theo Steininger committed Oct 24, 2016 59 60 61 62 63 `````` s_space = RGSpace([512, 512], dtype=np.float) fft = FFTOperator(s_space) h_space = fft.target[0] p_space = PowerSpace(h_space, distribution_strategy=distribution_strategy) `````` Theo Steininger committed Oct 25, 2016 64 `````` # create the field instances and power operator `````` Theo Steininger committed Oct 24, 2016 65 66 67 68 69 70 71 72 73 `````` pow_spec = (lambda k: (42 / (k + 1) ** 3)) S = create_power_operator(h_space, power_spectrum=pow_spec, distribution_strategy=distribution_strategy) sp = Field(p_space, val=lambda z: pow_spec(z)**(1./2), distribution_strategy=distribution_strategy) sh = sp.power_synthesize(real_signal=True) ss = fft.inverse_times(sh) `````` Theo Steininger committed Oct 25, 2016 74 `````` # model the measurement process `````` Theo Steininger committed Oct 24, 2016 75 76 77 78 79 80 81 82 83 84 85 `````` R = SmoothingOperator(s_space, sigma=0.01) # R = DiagonalOperator(s_space, diagonal=1.) 
# R._diagonal.val[200:400, 200:400] = 0 signal_to_noise = 1 N = DiagonalOperator(s_space, diagonal=ss.var()/signal_to_noise, bare=True) n = Field.from_random(domain=s_space, random_type='normal', std=ss.std()/np.sqrt(signal_to_noise), mean=0) `````` Theo Steininger committed Oct 25, 2016 86 `````` # create mock data `````` Theo Steininger committed Oct 24, 2016 87 `````` d = R(ss) + n `````` Theo Steininger committed Oct 25, 2016 88 89 `````` # set up reconstruction objects `````` Theo Steininger committed Oct 24, 2016 90 91 92 `````` j = R.adjoint_times(N.inverse_times(d)) D = PropagatorOperator(S=S, N=N, R=R) `````` Theo Steininger committed Oct 25, 2016 93 `````` def distance_measure(energy, iteration): `````` Theo Steininger committed Oct 25, 2016 94 95 `````` x = energy.position print (iteration, ((x-ss).norm()/ss.norm()).real) `````` Theo Steininger committed Oct 24, 2016 96 `````` `````` Theo Steininger committed Feb 09, 2017 97 98 99 100 101 102 103 ``````# minimizer = SteepestDescent(convergence_tolerance=0, # iteration_limit=50, # callback=distance_measure) minimizer = RelaxedNewton(convergence_tolerance=0, iteration_limit=2, callback=distance_measure) `````` Theo Steininger committed Oct 24, 2016 104 `````` `````` Theo Steininger committed Feb 09, 2017 105 106 107 108 ``````# minimizer = VL_BFGS(convergence_tolerance=0, # iteration_limit=50, # callback=distance_measure, # max_history_length=3) `````` Theo Steininger committed Oct 24, 2016 109 110 111 `````` m0 = Field(s_space, val=1) `````` Theo Steininger committed Oct 25, 2016 112 113 114 115 `````` energy = WienerFilterEnergy(position=m0, D=D, j=j) (energy, convergence) = minimizer(energy) `````` Theo Steininger committed Feb 09, 2017 116 `````` m = energy.position `````` Theo Steininger committed Oct 25, 2016 117 `````` `````` Theo Steininger committed Feb 09, 2017 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 `````` d_data = 
d.val.get_full_data().real if rank == 0: pl.plot([go.Heatmap(z=d_data)], filename='data.html') ss_data = ss.val.get_full_data().real if rank == 0: pl.plot([go.Heatmap(z=ss_data)], filename='ss.html') sh_data = sh.val.get_full_data().real if rank == 0: pl.plot([go.Heatmap(z=sh_data)], filename='sh.html') j_data = j.val.get_full_data().real if rank == 0: pl.plot([go.Heatmap(z=j_data)], filename='j.html') jabs_data = np.abs(j.val.get_full_data()) jphase_data = np.angle(j.val.get_full_data()) if rank == 0: pl.plot([go.Heatmap(z=jabs_data)], filename='j_abs.html') pl.plot([go.Heatmap(z=jphase_data)], filename='j_phase.html') m_data = m.val.get_full_data().real if rank == 0: pl.plot([go.Heatmap(z=m_data)], filename='map.html') `````` Theo Steininger committed Oct 25, 2016 144 145 146 147 `````` # grad_data = grad.val.get_full_data().real # if rank == 0: # pl.plot([go.Heatmap(z=grad_data)], filename='grad.html')``````