Skip to content
Snippets Groups Projects
Commit 8ffc8a5e authored by Maximilian Kurthen's avatar Maximilian Kurthen
Browse files

fixed bug in benchmark script

parent a2ef4923
No related branches found
No related tags found
No related merge requests found
......@@ -16,11 +16,8 @@ NAME = args.name
FIRST_ID = args.first_id
LAST_ID = args.last_id
MODEL = args.model
N_BINS = args.nbins
NOISE_VAR = args.noise_var
ITERATION_LIMIT = args.iteration_limit
TOL_REL_GRADNORM = args.tol_rel_gradnorm
BENCHMARK = args.benchmark
VERBOSITY = args.verbosity
......@@ -35,15 +32,12 @@ if CONFIG is not None:
with open('./model_configurations.txt') as f:
configs = eval(f.read())
parameters = configs[CONFIG]
MODEL = parameters.get('model', MODEL)
N_BINS = parameters.get('nbins', N_BINS)
NOISE_VAR = parameters.get('noise_var', NOISE_VAR)
POWER_SPECTRUM_BETA_STR = parameters.get(
'power_spectrum_beta', POWER_SPECTRUM_BETA_STR)
POWER_SPECTRUM_F_STR = parameters.get(
'power_spectrum_f', POWER_SPECTRUM_F_STR)
ITERATION_LIMIT = parameters.get('iteration_limit', ITERATION_LIMIT)
TOL_REL_GRADNORM = parameters.get('tol_rel_gradnorm', TOL_REL_GRADNORM)
if LAST_ID is None:
LAST_ID = get_benchmark_default_length(BENCHMARK)
......@@ -55,14 +49,17 @@ print(
'power spectrum beta: {}\n'
'power spectrum f: {}\n'
'rho: {}\n'
'scale_max: {}\n'
'storing results with suffix {}'.format(
BENCHMARK, FIRST_ID, LAST_ID, N_BINS,
NOISE_VAR,
POWER_SPECTRUM_BETA_STR,
POWER_SPECTRUM_F_STR,
RHO,
SCALE_MAX,
NAME))
np.random.seed(1)
POWER_SPECTRUM_BETA = lambda q: eval(POWER_SPECTRUM_BETA_STR)
POWER_SPECTRUM_F = lambda q: eval(POWER_SPECTRUM_F_STR)
scale = (0, SCALE_MAX)
......@@ -75,7 +72,8 @@ if os.path.isfile(prediction_file):
prediction_file = './benchmark_predictions/{}_{}_{}.txt'.format(
BENCHMARK, NAME, c)
accuracy = 0
np.random.seed(1)
sum_of_weights = 0
weighted_correct = 0
......@@ -88,35 +86,19 @@ for i in range(FIRST_ID-1, LAST_ID):
scaler = MinMaxScaler(scale)
x, y = scaler.fit_transform(np.array((x, y)).T).T
minimizer = nifty5.RelaxedNewton(controller=nifty5.GradientNormController(
tol_rel_gradnorm=TOL_REL_GRADNORM,
iteration_limit=ITERATION_LIMIT,
convergence_level=5,
))
if MODEL == 1:
bcm = bayesian_causal_model.cause_model_shallow.CausalModelShallow(
N_bins=N_BINS,
noise_var=NOISE_VAR,
rho=RHO,
power_spectrum_beta=POWER_SPECTRUM_BETA,
power_spectrum_f=POWER_SPECTRUM_F,
)
elif MODEL == 2:
bcm = bayesian_causal_model_nifty.cause_model_shallow.CausalModelShallow(
N_bins=N_BINS,
noise_var=NOISE_VAR,
rho=RHO,
power_spectrum_beta=POWER_SPECTRUM_BETA,
power_spectrum_f=POWER_SPECTRUM_F,
minimizer=minimizer,
)
bcm = bayesian_causal_model.cause_model_shallow.CausalModelShallow(
N_bins=N_BINS,
noise_var=NOISE_VAR,
rho=RHO,
power_spectrum_beta=POWER_SPECTRUM_BETA,
power_spectrum_f=POWER_SPECTRUM_F,
)
bcm.set_data(x, y)
H1 = bcm.get_evidence(direction=1, verbosity=1)
H2 = bcm.get_evidence(direction=-1, verbosity=1)
predicted_direction = 1 if int(H1 < H2) else 0
H1 = bcm.get_evidence(direction=1, verbosity=VERBOSITY - 1)
H2 = bcm.get_evidence(direction=-1, verbosity=VERBOSITY - 1)
predicted_direction = 1 if int(H1 < H2) else -1
if predicted_direction == true_direction:
fore = colorama.Fore.GREEN
......@@ -126,18 +108,19 @@ for i in range(FIRST_ID-1, LAST_ID):
sum_of_weights += weight
accuracy = weighted_correct / sum_of_weights
print(
'dataset {}, {} true direction: {}, predicted direction {}\n'
'H1: {:.2e},\n H2: {:.2e},\n{}'
'accuracy so far: {:.2f}'.format(
i,
fore,
true_direction,
predicted_direction,
H1,
H2,
colorama.Style.RESET_ALL,
accuracy))
if VERBOSITY > 0:
print(
'dataset {}, {} true direction: {}, predicted direction {}\n'
'H1: {:.2e},\n H2: {:.2e},\n{}'
'accuracy so far: {:.2f}'.format(
i,
fore,
true_direction,
predicted_direction,
H1,
H2,
colorama.Style.RESET_ALL,
accuracy))
with open(prediction_file, 'a') as f:
f.write('{} {} {} {}\n'.format(i+1, predicted_direction, H1, H2))
......@@ -147,7 +130,6 @@ print('accuracy: {:.2f}'.format(accuracy))
benchmark_information = {
'benchmark': BENCHMARK,
'model': MODEL,
'n_bins': N_BINS,
'noise_var': NOISE_VAR,
'rho': RHO,
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment