# wiener_filter_hamiltonian.py (4.34 KB)
 # Wiener filter demo built on NIFTy.
"""Wiener filter reconstruction of a mock 2D signal with NIFTy.

The script draws a Gaussian random signal with a prescribed power spectrum,
passes it through a smoothing response, adds white noise to obtain mock
data, and reconstructs the signal by minimising the Wiener filter energy

    H(m) = 1/2 m^dagger D^{-1} m - j^dagger m

with information source j = R^dagger N^{-1} d and propagator D.  Data,
signal, source and reconstruction are written as plotly heatmaps (HTML
files) by MPI rank 0.
"""

from nifty import *

import plotly.offline as pl
import plotly.graph_objs as go

from mpi4py import MPI
comm = MPI.COMM_WORLD
rank = comm.rank

# `np` is not imported explicitly here; it is presumably re-exported by
# `from nifty import *`.  Seeding the global NumPy RNG is meant to make the
# mock signal and noise reproducible (assumes NIFTy draws from numpy.random
# -- TODO confirm).
np.random.seed(42)


class _PropagatorAsInverseCurvature(object):
    """Expose a propagator D as the *inverse* of an energy's curvature.

    The curvature (second derivative) of the Wiener filter energy is D^{-1},
    so applying the inverse curvature to a field is the same as applying D.
    Only ``inverse_times`` is provided.
    """

    def __init__(self, propagator):
        self.propagator = propagator

    def inverse_times(self, *args, **kwargs):
        """Apply the inverse curvature, i.e. ``propagator.times``."""
        return self.propagator.times(*args, **kwargs)


class WienerFilterEnergy(Energy):
    """Wiener filter energy H(m) = 1/2 m^dagger D^{-1} m - j^dagger m.

    Parameters
    ----------
    position : Field
        Point m at which the energy is evaluated.
    D : operator
        Propagator; ``D.inverse_times`` applies D^{-1} and ``D.times``
        applies D.
    j : Field
        Information source (in the demo below: R^dagger N^{-1} d).
    """

    def __init__(self, position, D, j):
        # Not strictly necessary, but spells out the constructor signature
        # that `at` relies on.
        super(WienerFilterEnergy, self).__init__(position)
        self.D = D
        self.j = j

    def at(self, position):
        """Return the same energy (same D and j) evaluated at `position`."""
        return self.__class__(position, D=self.D, j=self.j)

    @property
    def value(self):
        """Real part of 1/2 <D^{-1} m, m> - <j, m> at the current position."""
        D_inv_x = self.D_inverse_x()
        H = 0.5 * D_inv_x.dot(self.position) - self.j.dot(self.position)
        # `dot` may yield a complex number; only its real part is used.
        return H.real

    @property
    def gradient(self):
        """Gradient D^{-1} m - j, returned as a real-valued field."""
        g = self.D_inverse_x() - self.j
        # Copy into a real-valued field and drop the imaginary part.  The
        # builtin `float` is what `np.float` aliased (removed in NumPy 1.24).
        return_g = g.copy_empty(dtype=float)
        return_g.val = g.val.real
        return return_g

    @property
    def curvature(self):
        """Curvature D^{-1}, represented by an object whose inverse applies D."""
        return _PropagatorAsInverseCurvature(self.D)

    @memo
    def D_inverse_x(self):
        """Return D^{-1} applied to the current position.

        Both `value` and `gradient` need this product; `memo` (from nifty)
        is presumably meant to cache it so the inversion runs only once per
        instance -- TODO confirm memo's caching semantics.
        """
        # Use this instance's propagator, never a module-level `D`.
        return self.D.inverse_times(self.position)


def _plot_heatmap(data, filename):
    """Write 2D array `data` as a plotly heatmap to `filename` on rank 0 only."""
    if rank == 0:
        pl.plot([go.Heatmap(z=data)], filename=filename)


if __name__ == "__main__":

    distribution_strategy = 'not'

    # Set up spaces and FFT transformation.
    s_space = RGSpace([512, 512])
    fft = FFTOperator(s_space)
    h_space = fft.target[0]
    p_space = PowerSpace(h_space, distribution_strategy=distribution_strategy)

    # Create the field instances and the power operator.
    def pow_spec(k):
        """Power spectrum P(k) = 42 / (k + 1)^3."""
        return 42 / (k + 1) ** 3

    S = create_power_operator(h_space, power_spectrum=pow_spec,
                              distribution_strategy=distribution_strategy)

    # Signal realisation: amplitudes sqrt(P(k)) on the power space,
    # synthesised into a harmonic-space field `sh` and transformed back to
    # position space as `ss`.
    sp = Field(p_space, val=lambda z: pow_spec(z)**(1./2),
               distribution_strategy=distribution_strategy)
    sh = sp.power_synthesize(real_signal=True)
    ss = fft.inverse_times(sh)

    # Model the measurement process: the response R smooths the signal.
    R = SmoothingOperator(s_space, sigma=0.01)
    # Alternative response: unit response with a zeroed (unobserved) region.
    # R = DiagonalOperator(s_space, diagonal=1.)
    # R._diagonal.val[200:400, 200:400] = 0

    # White noise whose variance is the signal variance divided by the SNR.
    signal_to_noise = 1
    N = DiagonalOperator(s_space, diagonal=ss.var()/signal_to_noise, bare=True)
    n = Field.from_random(domain=s_space,
                          random_type='normal',
                          std=ss.std()/np.sqrt(signal_to_noise),
                          mean=0)

    # Create mock data.
    d = R(ss) + n

    # Set up the reconstruction objects: information source and propagator.
    j = R.adjoint_times(N.inverse_times(d))
    D = PropagatorOperator(S=S, N=N, R=R)

    def distance_measure(energy, iteration):
        """Print the relative error ||m - ss|| / ||ss|| of the current estimate."""
        x = energy.position
        print(iteration, ((x-ss).norm()/ss.norm()).real)

    # Alternative minimizers:
    # minimizer = SteepestDescent(convergence_tolerance=0,
    #                             iteration_limit=50,
    #                             callback=distance_measure)
    # minimizer = VL_BFGS(convergence_tolerance=0,
    #                     iteration_limit=50,
    #                     callback=distance_measure,
    #                     max_history_length=3)
    minimizer = RelaxedNewton(convergence_tolerance=0,
                              iteration_limit=2,
                              callback=distance_measure)

    # Start the minimisation from a constant field.
    m0 = Field(s_space, val=1.)

    energy = WienerFilterEnergy(position=m0, D=D, j=j)

    (energy, convergence) = minimizer(energy)

    m = energy.position

    # get_full_data() is called on every rank (it presumably gathers
    # distributed data -- TODO confirm); only rank 0 writes HTML files.
    _plot_heatmap(d.val.get_full_data().real, 'data.html')
    _plot_heatmap(ss.val.get_full_data().real, 'ss.html')
    _plot_heatmap(sh.val.get_full_data().real, 'sh.html')

    j_full = j.val.get_full_data()
    _plot_heatmap(j_full.real, 'j.html')
    _plot_heatmap(np.abs(j_full), 'j_abs.html')
    _plot_heatmap(np.angle(j_full), 'j_phase.html')

    _plot_heatmap(m.val.get_full_data().real, 'map.html')

    # To also inspect the gradient at the final position:
    # grad_data = energy.gradient.val.get_full_data().real
    # _plot_heatmap(grad_data, 'grad.html')