Commit f54038f7 by Martin Reinecke

### cleanup

parent d7b80547
 ... @@ -25,51 +25,47 @@ def get_random_LOS(n_los): ... @@ -25,51 +25,47 @@ def get_random_LOS(n_los): ends = list(np.random.uniform(0, 1, (n_los, 2)).T) ends = list(np.random.uniform(0, 1, (n_los, 2)).T) return starts, ends return starts, ends if __name__ == '__main__': if __name__ == '__main__': # FIXME description of the tutorial # FIXME description of the tutorial np.random.seed(42) np.random.seed(42) position_space = ift.RGSpace([128, 128]) position_space = ift.RGSpace([128, 128]) # Setting up an amplitude model # Setting up an amplitude model A, amplitude_internals = ift.make_amplitude_model( A = ift.AmplitudeModel(position_space, 16, 1, 10, -4., 1, 0., 1.) position_space, 16, 1, 10, -4., 1, 0., 1.) dummy = ift.from_random('normal', A.domain) # Building the model for a correlated signal # Building the model for a correlated signal harmonic_space = position_space.get_default_codomain() harmonic_space = position_space.get_default_codomain() ht = ift.HarmonicTransformOperator(harmonic_space, position_space) ht = ift.HarmonicTransformOperator(harmonic_space, position_space) power_space = A.value.domain[0] power_space = A.target[0] power_distributor = ift.PowerDistributor(harmonic_space, power_space) power_distributor = ift.PowerDistributor(harmonic_space, power_space) position = ift.MultiField.from_dict( dummy = ift.Field.from_random('normal', harmonic_space) {'xi': ift.Field.from_random('normal', harmonic_space)}) xi = ift.Variable(position)['xi'] correlated_field = lambda inp: ht(power_distributor(A(inp))*inp["xi"]) Amp = power_distributor(A) correlated_field_h = Amp * xi correlated_field = ht(correlated_field_h) # alternatively to the block above one can do: # alternatively to the block above one can do: # correlated_field,_ = ift.make_correlated_field(position_space, A) # correlated_field,_ = ift.make_correlated_field(position_space, A) # apply some nonlinearity # apply some nonlinearity signal = ift.PointwisePositiveTanh(correlated_field) signal = lambda inp: correlated_field(inp).positive_tanh() # Building the Line of Sight response # Building the Line of Sight response LOS_starts, LOS_ends = get_random_LOS(100) LOS_starts, LOS_ends = get_random_LOS(100) R = ift.LOSResponse(position_space, starts=LOS_starts, R = ift.LOSResponse(position_space, starts=LOS_starts, ends=LOS_ends) ends=LOS_ends) # build signal response model and model likelihood # build signal response model and model likelihood signal_response = R(signal) signal_response = lambda inp: R(signal(inp)) # specify noise # specify noise data_space = R.target data_space = R.target noise = .001 noise = .001 N = ift.ScalingOperator(noise, data_space) N = ift.ScalingOperator(noise, data_space) # generate mock data # generate mock data MOCK_POSITION = ift.from_random('normal', signal.position.domain) domain = ift.MultiDomain.union((A.domain, ift.MultiDomain.make({'xi': harmonic_space}))) data = signal_response.at(MOCK_POSITION).value + N.draw_sample() MOCK_POSITION = ift.from_random('normal', domain) data = signal_response(MOCK_POSITION) + N.draw_sample() # set up model likelihood # set up model likelihood likelihood = ift.GaussianEnergy(signal_response, mean=data, covariance=N) likelihood = lambda inp: ift.GaussianEnergy(mean=data, covariance=N)(signal_response(inp)) # set up minimization and inversion schemes # set up minimization and inversion schemes ic_cg = ift.GradientNormController(iteration_limit=10) ic_cg = ift.GradientNormController(iteration_limit=10) ... @@ -80,40 +76,38 @@ if __name__ == '__main__': ... @@ -80,40 +76,38 @@ if __name__ == '__main__': # build model Hamiltonian # build model Hamiltonian H = ift.Hamiltonian(likelihood, ic_sampling) H = ift.Hamiltonian(likelihood, ic_sampling) INITIAL_POSITION = ift.from_random('normal', H.position.domain) INITIAL_POSITION = ift.from_random('normal', domain) position = INITIAL_POSITION position = INITIAL_POSITION ift.plot(signal.at(MOCK_POSITION).value, title='ground truth') ift.plot(signal(MOCK_POSITION), title='ground truth') ift.plot(R.adjoint_times(data), title='data') ift.plot(R.adjoint_times(data), title='data') ift.plot([A.at(MOCK_POSITION).value], title='power') ift.plot([A(MOCK_POSITION)], title='power') ift.plot_finish(nx=3, xsize=16, ysize=5, title="setup", name="setup.png") ift.plot_finish(nx=3, xsize=16, ysize=5, title="setup", name="setup.png") # number of samples used to estimate the KL # number of samples used to estimate the KL N_samples = 20 N_samples = 20 for i in range(2): for i in range(2): H = H.at(position) metric = H(ift.Linearization.make_var(position)).metric samples = [H.metric.draw_sample(from_inverse=True) samples = [metric.draw_sample(from_inverse=True) for _ in range(N_samples)] for _ in range(N_samples)] KL = ift.SampledKullbachLeiblerDivergence(H, samples) KL = ift.SampledKullbachLeiblerDivergence(H, samples) KL = ift.EnergyAdapter(position, KL) KL = KL.make_invertible(ic_cg) KL = KL.make_invertible(ic_cg) KL, convergence = minimizer(KL) KL, convergence = minimizer(KL) position = KL.position position = KL.position ift.plot(signal.at(position).value, title="reconstruction") ift.plot(signal(position), title="reconstruction") ift.plot([A(position), A(MOCK_POSITION)], title="power") ift.plot([A.at(position).value, A.at(MOCK_POSITION).value], title="power") ift.plot_finish(nx=2, xsize=12, ysize=6, title="loop", name="loop.png") ift.plot_finish(nx=2, xsize=12, ysize=6, title="loop", name="loop.png") sc = ift.StatCalculator() sc = ift.StatCalculator() for sample in samples: for sample in samples: sc.add(signal.at(sample+position).value) sc.add(signal(sample+position)) ift.plot(sc.mean, title="mean") ift.plot(sc.mean, title="mean") ift.plot(ift.sqrt(sc.var), title="std deviation") ift.plot(ift.sqrt(sc.var), title="std deviation") powers = [A.at(s+position).value for s in samples] powers = [A(s+position) for s in samples] ift.plot([A.at(position).value, A.at(MOCK_POSITION).value]+powers, ift.plot([A(position), A(MOCK_POSITION)]+powers, title="power") title="power") ift.plot_finish(nx=3, xsize=16, ysize=5, title="results", ift.plot_finish(nx=3, xsize=16, ysize=5, title="results", name="results.png") name="results.png")