diff --git a/source/preprocess.py b/source/preprocess.py index 9ca0b7162d8e3550b08c0f89fd1d700266feb7b5..ae3c3921d4523dc128290c35ae8f4f478c328156 100644 --- a/source/preprocess.py +++ b/source/preprocess.py @@ -534,13 +534,19 @@ class table_dict(dict): X_test = X_test * g.reshape(1, window_size, 1) if shuffle: - X_train = np.random.choice(X_train, X_train.shape[0], replace=False) - X_test = np.random.choice(X_test, X_test.shape[0], replace=False) + X_train = X_train[ + np.random.choice(X_train.shape[0], X_train.shape[0], replace=False) + ] + X_test = X_test[ + np.random.choice(X_test.shape[0], X_test.shape[0], replace=False) + ] return X_train, X_test if shuffle: - X_train = np.random.choice(X_train, X_train.shape[0], replace=False) + X_train = X_train[ + np.random.choice(X_train.shape[0], X_train.shape[0], replace=False) + ] return X_train