diff --git a/syconn/cnn/TrainData.py b/syconn/cnn/TrainData.py
index 3095d93c188116ec57c2d493a22db427cf6b8b6f..95becb4fcc158172d24609187ab4bb216bccde36 100755
--- a/syconn/cnn/TrainData.py
+++ b/syconn/cnn/TrainData.py
@@ -275,12 +275,12 @@ if elektronn3_avail:
         Uses the same data for train and valid set.
         """
 
         def __init__(self, cv_val=None, **kwargs):
-            ssd_kwargs = dict(working_dir="/ssdscratch/songbird/j0251/j0251_72_seg_20210127_agglo2")
+            ssd_kwargs = dict(working_dir="/cajal/nvmescratch/projects/data/songbird/j0251/j0251_72_seg_20210127_agglo2_syn_20220811_celltypes_20230822")
             super().__init__(ssd_kwargs=ssd_kwargs, cv_val=cv_val, **kwargs)
             # load GT
             #assert self.train, "Other mode than 'train' is not implemented."
-            self.csv_p = "/wholebrain/songbird/j0251/groundtruth/celltypes/j0251_celltype_gt_v6_j0251_72_seg_20210127_agglo2_IDs.csv"
+            self.csv_p = "/cajal/nvmescratch/projects/songbird/j0251/groundtruth/celltypes/j0251_celltype_gt_v7_j0251_72_seg_20210127_agglo2_IDs.csv"
             #self.csv_p = "cajal/nvmescratch/users/arother/cnn_training/j0251_celltype_gt_short_test.csv"
             df = pandas.io.parsers.read_csv(self.csv_p, header=None, names=['ID', 'type']).values
             ssv_ids = df[:, 0].astype(np.uint64)
@@ -288,7 +288,7 @@ if elektronn3_avail:
                 ixs, cnt = np.unique(ssv_ids, return_counts=True)
                 raise ValueError(f'Multi-usage of IDs! {ixs[cnt > 1]}')
             str_labels = df[:, 1]
-            ssv_labels = np.array([str2int_converter(el, gt_type='ctgt_j0251_v3') for el in str_labels], dtype=np.uint16)
+            ssv_labels = np.array([str2int_converter(el, gt_type='ctgt_j0251_v4') for el in str_labels], dtype=np.uint16)
             if self.cv_val is not None and self.cv_val != -1:
                 assert self.cv_val < 10
                 kfold = StratifiedKFold(n_splits=10, shuffle=True, random_state=0)
diff --git a/syconn/cnn/cnn_celltype_ptcnv_j0251.py b/syconn/cnn/cnn_celltype_ptcnv_j0251.py
index 918dfbc0e31122884e447d7b7e3eecce56ecc0f1..b504c1b202f12f274c04e23d19430afe89eb9f0a 100755
--- a/syconn/cnn/cnn_celltype_ptcnv_j0251.py
+++ b/syconn/cnn/cnn_celltype_ptcnv_j0251.py
@@ -47,7 +47,6 @@ parser.add_argument(
 )
 
 parser.add_argument('--cval', default=None, help='Cross-validation split indicator.', type=int)
-
 args = parser.parse_args()
 
 # SET UP ENVIRONMENT #
@@ -80,7 +79,7 @@ eval_nr = random_seed # number of repetition
 dr = 0.3
 track_running_stats = False
 use_norm = 'gn'
-num_classes = 15
+num_classes = 17
 onehot = True
 act = 'relu'
 use_myelin = False
@@ -123,6 +122,7 @@ print(f'Running on device: {device}')
 
 # set paths
 #save_root = "cajal/nvmescratch/projects/data/songbird_tmp/j0251/j0251_72_seg_20210127_agglo2_syn_20220811/celltype_training/221216_celltype_noval/"
+save_root = '/cajal/nvmescratch/users/arother/cnn_training/231207_celltype_training_testval/'
 if save_root is None:
     save_root = '~/e3_training_convpoint/'
 save_root = os.path.expanduser(save_root)
@@ -165,17 +165,19 @@ train_transform = clouds.Compose([clouds.RandomVariation((-40, 40), distr='norma
                                   clouds.RandomRotate(apply_flip=True),
                                   clouds.ElasticTransform(res=(40, 40, 40), sigma=6),
                                   clouds.RandomScale(distr_scale=0.1, distr='uniform')])
-valid_transform = clouds.Compose([clouds.Center(), clouds.Normalization(scale_norm)])
-
 train_ds = CellCloudDataJ0251(npoints=npoints, transform=train_transform, cv_val=cval,
                               cellshape_only=cellshape_only,
                               use_syntype=use_syntype, onehot=onehot,
                               batch_size=batch_size, ctx_size=ctx,
                               map_myelin=use_myelin)
-valid_ds = CellCloudDataJ0251(npoints=npoints, transform=valid_transform, train=False,
-                              cv_val=cval, cellshape_only=cellshape_only,
-                              use_syntype=use_syntype, onehot=onehot, batch_size=batch_size,
-                              ctx_size=ctx, map_myelin=use_myelin)
-#valid_ds = None
+
+if cval is not None and cval != -1:
+    valid_transform = clouds.Compose([clouds.Center(), clouds.Normalization(scale_norm)])
+    valid_ds = CellCloudDataJ0251(npoints=npoints, transform=valid_transform, train=False,
+                                  cv_val=cval, cellshape_only=cellshape_only,
+                                  use_syntype=use_syntype, onehot=onehot, batch_size=batch_size,
+                                  ctx_size=ctx, map_myelin=use_myelin)
+else:
+    valid_ds = None
 
 # PREPARE AND START TRAINING #
diff --git a/syconn/handler/prediction_pts.py b/syconn/handler/prediction_pts.py
index 327168818bf52ddbf90c7a077c54d7f2b56a9eb8..a404c9e2e19c61eeaf1bb3cca17435f1ae124db3 100755
--- a/syconn/handler/prediction_pts.py
+++ b/syconn/handler/prediction_pts.py
@@ -58,16 +58,16 @@ hc_cache_gt = {}
 
 def init_hc_cache_gt():
     print("initialising cache")
-    v6_gt = pd.read_csv(
-        "wholebrain/songbird/j0251/groundtruth/celltypes/j0251_celltype_gt_v6_j0251_72_seg_20210127_agglo2_IDs.csv",
+    v7_gt = pd.read_csv(
+        "/cajal/nvmescratch/projects/songbird/j0251/groundtruth/celltypes/j0251_celltype_gt_v7_j0251_72_seg_20210127_agglo2_IDs.csv",
         names=["cellids", "celltype"])
-    cellids = np.array(v6_gt["cellids"])
+    cellids = np.array(v7_gt["cellids"])
     for cellid in cellids:
-        hc = load_pkl2obj('cajal/nvmescratch/projects/data/songbird_tmp/j0251/j0251_72_seg_20210127_agglo2_syn_20220811/celltype_training/hybrid_clouds_gt/%i_hc.pkl' % cellid)
+        hc = load_pkl2obj('/cajal/nvmescratch/users/arother/cnn_training/231202_hybrid_clouds_gt/%i_hc.pkl' % cellid)
         hc_cache_gt[cellid] = hc
 
 
-#init_hc_cache_gt()
+init_hc_cache_gt()
 
 # TODO: move to handler.basics
 def write_ply(fn, verts, colors):
diff --git a/syconn/mp/batchjob_utils.py b/syconn/mp/batchjob_utils.py
index 99784274c9b71e412a5b884a3c4e149783db6ecf..9d956fa6228b6374354e9c6ab66f702b494ae127 100755
--- a/syconn/mp/batchjob_utils.py
+++ b/syconn/mp/batchjob_utils.py
@@ -138,7 +138,7 @@ def batchjob_script(params: list, name: str,
         if not overwrite:
             raise FileExistsError(f'Batchjob folder already exists at "{batchjob_folder}". Please'
                                   f' make sure it is safe for deletion, then set overwrite=True')
-        if batchjob_folder == wd:
+        if batchjob_folder == global_params.wd:
             raise ValueError('The directory you want to delete is the whole working directory')
         shutil.rmtree(batchjob_folder, ignore_errors=True)
         batchjob_folder = batchjob_folder.rstrip('/')
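Note: both `TrainData.py` and the training script above rely on the same convention: `cv_val`/`cval` set to `None` or `-1` disables validation entirely, while `0`-`9` selects one held-out fold of the ten-fold `StratifiedKFold` split. A minimal sketch of that fold selection, assuming `split_gt` and its arguments are illustrative names rather than SyConn API:

```python
import numpy as np
from sklearn.model_selection import StratifiedKFold

def split_gt(ssv_ids: np.ndarray, ssv_labels: np.ndarray, cv_val):
    """Return (train_ids, valid_ids); valid_ids is None when no fold is requested."""
    if cv_val is None or cv_val == -1:
        # no validation split: train on all ground-truth cells
        return ssv_ids, None
    assert 0 <= cv_val < 10
    kfold = StratifiedKFold(n_splits=10, shuffle=True, random_state=0)
    # stratification keeps the cell-type proportions similar across folds
    splits = list(kfold.split(ssv_ids, ssv_labels))
    train_ix, valid_ix = splits[cv_val]
    return ssv_ids[train_ix], ssv_ids[valid_ix]
```

Because `shuffle=True` is paired with a fixed `random_state=0`, the fold assignment is deterministic: every run with the same `cv_val` holds out the same cells, which keeps the ten cross-validation trainings comparable.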