Skip to content
Snippets Groups Projects
Commit b3aa178b authored by Philipp Arras's avatar Philipp Arras
Browse files

Refactor critical filter part

parent c1a60431
No related branches found
No related tags found
No related merge requests found
from scripts.just_plot import *
import numpy as np
import nifty5 as ift
from scripts.generate_data import *
from scripts.just_plot import *
from scripts.responses import *
import nifty5 as ift
import numpy as np
np.random.seed(42)
position_space = ift.RGSpace([256, 256])
harmonic_space = position_space.get_default_codomain()
HT = ift.HarmonicTransformOperator(harmonic_space, target=position_space)
power_space = ift.PowerSpace(harmonic_space)
A = ift.SLAmplitude(
target=power_space, n_pix=64, a=10, k0=.2, sm=-4, sv=.6, im=-2, iv=2)
import numpy as np
import nifty5 as ift
from just_plot import *
from responses import *
from generate_data import *
import nifty5 as ift
import responses as resp
from generate_data import generate_gaussian_data
from just_plot import plot_prior_samples_2d, plot_reconstruction_2d
np.random.seed(42)
position_space = ift.RGSpace([256,256])
position_space = ift.RGSpace([256, 256])
harmonic_space = position_space.get_default_codomain()
HT = ift.HarmonicTransformOperator(harmonic_space, target=position_space)
power_space = ift.PowerSpace(harmonic_space)
# Set up an amplitude operator for the field
# We want to set up a model for the amplitude spectrum with some magic numbers
dct = {
'target': power_space,
'n_pix': 64, # 64 spectral bins
# Spectral smoothness (affects Gaussian process part)
# Smoothness of spectrum
'a': 10, # relatively high variance of spectral curvature
'k0': .2, # quefrency mode below which cepstrum flattens
# Power-law part of spectrum:
# Power-law part of spectrum
'sm': -4, # preferred power-law slope
'sv': .6, # low variance of power-law slope
'im': -2, # y-intercept mean, in-/decrease for more/less contrast
'iv': 2. # y-intercept variance
'im': -2, # y-intercept mean, in-/decrease for more/less contrast
'iv': 2. # y-intercept variance
}
A = ift.SLAmplitude(**dct)
correlated_field = ift.CorrelatedField(position_space, A)
# interactive plotting
# plotting correlated_field(ift.from_random('normal',correlated_field.target))
### SETTING UP SPECIFIC SCENARIO ####
R = checkerboard_response(position_space)
# R = ift.GeometryRemover(position_space)
signal = ift.CorrelatedField(position_space, A)
R = resp.checkerboard_response(position_space)
data_space = R.target
signal = correlated_field
signal_response = R(correlated_field)
signal_response = R @ signal
# Set up likelihood and draw data from the model
# Set up likelihood and generate data from the model
N = ift.ScalingOperator(0.1, data_space)
data, ground_truth = generate_gaussian_data(signal_response, N)
plot_prior_samples_2d(5, signal, R, correlated_field, A, 'gauss', N=N)
likelihood = ift.GaussianEnergy(mean=data,
inverse_covariance=N.inverse)(signal_response)
plot_prior_samples_2d(5, signal, R, signal, A, 'gauss', N=N)
likelihood = ift.GaussianEnergy(
mean=data, inverse_covariance=N.inverse)(signal_response)
#### SOLVING PROBLEM ####
# SOLVE INFERENCE PROBLEM
ic_sampling = ift.GradientNormController(iteration_limit=100)
ic_newton = ift.GradInfNormController(
name='Newton', tol=1e-6, iteration_limit=30)
......@@ -65,18 +52,13 @@ H = ift.StandardHamiltonian(likelihood, ic_sampling)
initial_mean = ift.MultiField.full(H.domain, 0.)
mean = initial_mean
# number of samples used to estimate the KL
N_samples = 5
# Draw new samples to approximate the KL five times
for i in range(10):
# Draw new samples and minimize KL
KL = ift.MetricGaussianKL(mean, H, N_samples)
# Draw five samples and minimize KL, iterate 10 times
for _ in range(10):
KL = ift.MetricGaussianKL(mean, H, 5)
KL, convergence = minimizer(KL)
mean = KL.position
# Draw posterior samples and plotting
# Draw posterior samples and plot
N_posterior_samples = 30
KL = ift.MetricGaussianKL(mean, H, N_posterior_samples)
plot_reconstruction_2d(data, ground_truth, KL, signal, R, A)
import nifty5 as ift
import numpy as np
def checkerboard_response(position_space):
# Checkerboard mask for 2D mode
'''Checkerboard mask for 2D mode'''
mask = np.ones(position_space.shape)
x, y = position_space.shape
for i in range(8):
for j in range(8):
if (i + j) % 2 == 0:
mask[i*x//8:(i + 1)*x//8, j*y//8:(j + 1)*y//8] = 0
mask = ift.from_global_data(position_space,mask)
mask = ift.from_global_data(position_space, mask)
return ift.MaskOperator(mask)
def exposure_response(position_space):
# Structured exposure for 2D mode
'''Structured exposure for 2D mode'''
x_shape, y_shape = position_space.shape
exposure = np.ones(position_space.shape)
......@@ -25,31 +27,24 @@ def exposure_response(position_space):
exposure[:, x_shape//2:x_shape*3//2] *= 3.
exposure = ift.Field.from_global_data(position_space, exposure)
E = ift.DiagonalOperator(exposure)
E = ift.makeOp(exposure)
G = ift.GeometryRemover(E.target)
return G @ E
def psf_response(position_space):
C = ift.HarmonicSmoothingOperator(position_space,0.01)
C = ift.HarmonicSmoothingOperator(position_space, 0.01)
G = ift.GeometryRemover(C.target)
return G @ C
def radial_tomography_response(position_space, lines_of_sight=100):
def radial_los(n_los):
starts = list(np.random.uniform(0, 1, (n_los, 2)).T)
ends = list(0.5 + 0 * np.random.uniform(0, 1, (n_los, 2)).T)
return starts, ends
LOS_starts, LOS_ends = radial_los(lines_of_sight)
R = ift.LOSResponse(position_space, starts=LOS_starts, ends=LOS_ends)
return R
def random_tomography_response(position_space, lines_of_sight=100):
def random_los(n_los):
starts = list(np.random.uniform(0, 1, (n_los, 2)).T)
ends = list(np.random.uniform(0, 1, (n_los, 2)).T)
return starts, ends
LOS_starts, LOS_ends = random_los(lines_of_sight)
R = ift.LOSResponse(position_space, starts=LOS_starts, ends=LOS_ends)
return R
def radial_tomography_response(position_space, lines_of_sight=100):
starts = list(np.random.uniform(0, 1, (lines_of_sight, 2)).T)
ends = list(0.5 + 0*np.random.uniform(0, 1, (lines_of_sight, 2)).T)
return ift.LOSResponse(position_space, starts=starts, ends=ends)
def random_tomography_response(position_space, lines_of_sight=100):
starts = list(np.random.uniform(0, 1, (lines_of_sight, 2)).T)
ends = list(np.random.uniform(0, 1, (lines_of_sight, 2)).T)
return ift.LOSResponse(position_space, starts=starts, ends=ends)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment