diff --git a/deepof/train_model.py b/deepof/train_model.py index c604c5f943a7d36a6993ffbd8ab31afe9b3d6fa8..0e5ae830b8072438f8128eb1ac989fc55187eb8e 100644 --- a/deepof/train_model.py +++ b/deepof/train_model.py @@ -130,6 +130,13 @@ parser.add_argument( default=10, type=int, ) +parser.add_argument( + "--output-path", + "-o", + help="Sets the base directory where to output results. Default is the current directory", + type=str, + default=".", +) parser.add_argument( "--overlap-loss", "-ol", @@ -214,6 +221,7 @@ kl_wu = args.kl_warmup logparam = args.logparam loss = args.loss mmd_wu = args.mmd_warmup +output_path = os.path.join(args.output_path) overlap_loss = args.overlap_loss pheno_class = float(args.phenotype_classifier) predictor = float(args.predictor) @@ -433,8 +441,11 @@ if not tune: ) gmvaep.save_weights( - "GMVAE_loss={}_encoding={}_run_{}_final_weights.h5".format( - loss, encoding_size, run + os.path.join( + output_path, + "GMVAE_loss={}_encoding={}_run_{}_final_weights.h5".format( + loss, encoding_size, run + ), ) )