# Used to prepare the example datasets for hyper-parameter optimisation

import numpy as np

# Seed NumPy's global RNG before cmlkit is imported; this ordering is kept
# deliberately. NOTE(review): presumably this makes the split below
# reproducible -- confirm that threeway_split draws from NumPy's global RNG.
np.random.seed(123)
import cmlkit

# Both subsets are written to the same directory.
OUTPUT_DIR = "data/cmlkit"

full_dataset = cmlkit.load_dataset("nmd18_train")

# threeway_split returns three values; the first one is not used here.
# NOTE(review): 800 and 200 look like the train/test subset sizes -- verify
# against the cmlkit.utility.threeway_split signature.
_unused_idx, train_idx, test_idx = cmlkit.utility.threeway_split(
    full_dataset.n, 800, 200
)

# Build, report and save each subset in turn: train first, then test.
for subset_name, subset_idx in (
    ("nmd18_hpo_train", train_idx),
    ("nmd18_hpo_test", test_idx),
):
    subset = cmlkit.dataset.Subset.from_dataset(
        full_dataset, idx=subset_idx, name=subset_name
    )
    print(subset.n)
    subset.save(directory=OUTPUT_DIR)