Commit 70543edb authored by Philipp's avatar Philipp


update acknowledgements; refactor `syn_sign_ratio_celltype` to allow caching of results; update cmn with multi-views
parent 03276c80
Pipeline #91883 passed with stage
in 2 minutes and 47 seconds
......@@ -44,6 +44,8 @@ for providing egl extension code to handle multi-gpu rendering on the
same machine. The original code snippet (under the Apache License 2.0)
used for our project can be found
[here](https://github.com/deepmind/dm_control/blob/30069ac11b60ee71acbd9159547d0bc334d63281/dm_control/_render/pyopengl/egl_ext.py).
SyConn uses the packages [zmesh](https://github.com/seung-lab/zmesh) for mesh generation and
[kimimaro](https://github.com/seung-lab/kimimaro) for skeleton generation, both developed in the Seung Lab.
Thanks to Julia Kuhl (see http://somedonkey.com/ for more beautiful
work) for designing and creating the logo and to Rangoli Saxena, Mariana
Shumliakivska, Josef Mark, Maria Kawula, Atul Mohite, Carl Constantin v. Wedemeyer,
......
# Copyright (c) 2016 - now
# Max-Planck-Institute of Neurobiology, Munich, Germany
# Authors: Philipp Schubert, Joergen Kornfeld
from syconn.reps.super_segmentation import *
from syconn.handler import log_main
from syconn import global_params
import numpy as np
import pandas
if __name__ == "__main__":
WD ="/wholebrain/songbird/j0126/areaxfs_v6/"
global_params.wd = WD
global_params.config['batch_proc_system'] = None
str2int_label = dict(STN=0, DA=1, MSN=2, LMAN=3, HVC=4, GP=5, TAN=6, GPe=5, INT=7, FS=8, GLIA=9)
str2int_label["GP "] = 5 # typo
csv_p = '/wholebrain/songbird/j0126/GT/celltype_gt/j0126_cell_type_gt_areax_fs6_v3.csv'
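# GT table with two columns (SSV ID, cell type string), parsed without header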
df = pandas.io.parsers.read_csv(csv_p, header=None, names=['ID', 'type']).values
ssv_ids = df[:, 0].astype(np.uint)
if len(np.unique(ssv_ids)) != len(ssv_ids):
raise ValueError('Multi-usage of IDs!')
str_labels = df[:, 1]
ssv_labels = np.array([str2int_label[el] for el in str_labels], dtype=np.uint16)
classes, c_cnts = np.unique(ssv_labels, return_counts=True)
if np.max(classes) > 7:
raise ValueError('Unexpected cell type label > 7 - GLIA or FS must not be part of this GT.')
log_main.setLevel(20) # This is INFO level (to filter copied file messages)
log_main.info('Successfully parsed "{}" with the following cell type class '
'distribution [labels, counts]: {}, {}'.format(csv_p, classes,
c_cnts))
log_main.info('Total #cells: {}'.format(np.sum(c_cnts)))
gt_version = "ctgt_v4"
new_ssd = SuperSegmentationDataset(working_dir=WD, version=gt_version)
# --------------------------------------------------------------------------
# TEST PREDICTIONS OF TRAIN AND VALIDATION DATA
from syconn.handler.prediction import get_celltype_model_e3
from syconn.proc.stats import cluster_summary, projection_tSNE, model_performance
from elektronn3.models.base import InferenceModel
from syconn.reps.super_segmentation import SuperSegmentationDataset, SuperSegmentationObject
import tqdm
np.set_printoptions(precision=4)
da_equals_tan = True
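# merge DA and TAN into a single modulatory class for evaluation (GT labels are remapped below accordingly)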
# --------------------------------------------------------------------------
# analysis of VALIDATION set
for m_name in ['celltype_GTv4_syntype_CV{}_adam_nbviews20_longRUN_2ratios_BIG_bs40_10fold_eval0',
'celltype_GTv4_syntype_CV{}_adam_nbviews20_longRUN_2ratios_BIG_bs40_10fold_eval1',
'celltype_GTv4_syntype_CV{}_adam_nbviews20_longRUN_2ratios_BIG_bs40_10fold_eval2']:
# CV1: valid dataset: split_dc['valid'], CV2: valid_dataset: split_dc['train']
# Perform train data set eval as counter check
gt_l = []
certainty = []
pred_l = []
pred_proba = []
pred_l_large = []
pred_proba_large = []
latent_morph_d = []
latent_morph_l = []
loaded_ssv_ids = []
# pbar = tqdm.tqdm(total=len(new_ssd.ssv_ids))
for cv in range(10):
split_dc = load_pkl2obj(path=new_ssd.path + "/{}_splitting_cv{}_10fold.pkl".format(
gt_version, cv))
ssv_ids = split_dc['valid']
loaded_ssv_ids.extend(ssv_ids)
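# only the held-out 'valid' split of each fold is predicted; the union over all 10 folds must cover every GT cell (asserted after the loop)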
pred_key_appendix2 = m_name.format(str(cv))
print('Loading cv-{}-data of model {}'.format(cv, pred_key_appendix2))
m_path = '/wholebrain/u/pschuber/e3_training_10fold_eval/' + pred_key_appendix2
m = InferenceModel(m_path, bs=80)
for ssv_id in ssv_ids:
ssv = new_ssd.get_super_segmentation_object(ssv_id)
# predict
ssv.nb_cpus = 20
ssv._view_caching = True
# ssv.predict_celltype_cnn(model=m_large, pred_key_appendix=pred_key_appendix1,
# model_tnet=m_tnet)
ssv.predict_celltype_cnn(model=m, pred_key_appendix=pred_key_appendix2,
view_props={"overwrite": False, 'use_syntype': True,
'nb_views': 20, 'da_equals_tan': da_equals_tan})
ssv.load_attr_dict()
curr_l = ssv.attr_dict["cellttype_gt"]
if da_equals_tan:
# adapt GT labels
if curr_l == 6: curr_l = 1 # TAN and DA are the same now
if curr_l == 7: curr_l = 6 # INT now has label 6
gt_l.append(curr_l)
# small FoV
pred_l.append(ssv.attr_dict["celltype_cnn_e3" + pred_key_appendix2])
preds_small = ssv.attr_dict["celltype_cnn_e3{}_probas".format(pred_key_appendix2)]
major_dec = np.zeros(preds_small.shape[1])
preds_small = np.argmax(preds_small, axis=1)
# For printing with all classes (in case da_equals_tan is True)
for ii in range(len(major_dec)):
major_dec[ii] = np.sum(preds_small == ii)
major_dec /= np.sum(major_dec)
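# major_dec now holds the fraction of views voting for each class, i.e. a soft majority-vote prediction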
pred_proba.append(major_dec)
if pred_l[-1] != gt_l[-1]:
print(f'{pred_l[-1]}\t{gt_l[-1]}\t{ssv.id}\t{major_dec}')
certainty.append(ssv.certainty_celltype("celltype_cnn_e3{}_probas".format(pred_key_appendix2)))
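# per-cell prediction certainty, used below to split the evaluation into a more and a less certain half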
# pbar.update(1)
# # large FoV
# pred_l_large.append(ssv.attr_dict["celltype_cnn_e3" + pred_key_appendix1])
# probas_large = ssv.attr_dict["celltype_cnn_e3{}_probas".format(pred_key_appendix1)]
# preds_large = np.argmax(probas_large, axis=1)
# major_dec = np.zeros(10)
# for ii in range(len(major_dec)):
# major_dec[ii] = np.sum(preds_large == ii)
# major_dec /= np.sum(major_dec)
# pred_proba_large.append(major_dec)
# # morphology embedding
# latent_morph_d.append(ssv.attr_dict["latent_morph_ct" + pred_key_appendix2])
# latent_morph_l.append(len(latent_morph_d[-1]) * [gt_l[-1]])
assert set(loaded_ssv_ids) == set(new_ssd.ssv_ids.tolist())
# # WRITE OUT COMBINED RESULTS
# train_d = np.concatenate(latent_morph_d)
# train_l = np.concatenate(latent_morph_l)
# pred_proba_large = np.array(pred_proba_large)
pred_proba = np.array(pred_proba)
certainty = np.array(certainty)
gt_l = np.array(gt_l)
int2str_label = {v: k for k, v in str2int_label.items()}
dest_p = f"/wholebrain/scratch/pschuber/celltype_comparison_syntype/{m_name}_valid" \
f"{'DA_eq_TAN' if da_equals_tan else ''}/"
os.makedirs(dest_p, exist_ok=True)
target_names = [int2str_label[kk] for kk in range(8)]
# SET TAN AND DA TO THE SAME CLASS
if da_equals_tan:
target_names[1] = 'Modulatory'
target_names.remove('TAN')
# # large
# classes, c_cnts = np.unique(np.argmax(pred_proba_large, axis=1), return_counts=True)
# log_main.info('Successful prediction [large FoV] with the following cell type class '
# 'distribution [labels, counts]: {}, {}'.format(classes, c_cnts))
# model_performance(pred_proba_large, gt_l, dest_p, n_labels=9,
# target_names=target_names, prefix="large_")
# standard
classes, c_cnts = np.unique(pred_l, return_counts=True)
log_main.info('Successful prediction [standard] with the following cell type class '
'distribution [labels, counts]: {}, {}'.format(classes, c_cnts))
perc_50 = np.percentile(certainty, 50)
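# split cells at the median certainty and report performance for each half separately, then for the full set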
model_performance(pred_proba[certainty > perc_50], gt_l[certainty > perc_50],
dest_p + '/upperhalf/', n_labels=7, target_names=target_names,
add_text=f'Percentile-50: {perc_50}')
model_performance(pred_proba[certainty <= perc_50], gt_l[certainty <= perc_50],
dest_p + '/lowerhalf/', n_labels=7, target_names=target_names,
add_text=f'Percentile-50: {perc_50}')
model_performance(pred_proba, gt_l, dest_p, n_labels=7,
target_names=target_names)
......@@ -34,7 +34,7 @@ if __name__ == "__main__":
ssv_labels = np.array([str2int_label[el] for el in str_labels], dtype=np.uint16)
classes, c_cnts = np.unique(ssv_labels, return_counts=True)
if np.max(classes) > 7:
raise ValueError('Now we got Glia or FS?!')
raise ValueError('Unexpected cell type label > 7 - GLIA or FS must not be part of this GT.')
log_main.setLevel(20) # This is INFO level (to filter copied file messages)
log_main.info('Successfully parsed "{}" with the following cell type class '
'distribution [labels, counts]: {}, {}'.format(csv_p, classes,
......
......@@ -224,6 +224,7 @@ if __name__ == '__main__':
state_dict_fname = 'state_dict.pth'
wd = "/ssdscratch/pschuber/songbird/j0251/rag_flat_Jan2019_v3/"
# TODO: update!
bbase_dir = '/wholebrain/scratch/pschuber/e3_trainings_convpoint_celltypes_j0251/OLD_4Dev2020/'
for ctx, npts in [[20000, 25000], [20000, 50000], [20000, 75000], [20000, 5000], [4000, 25000]]:
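# presumably pairs of (context size, number of sampled points) for the point-cloud model configurations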
......
from syconn.handler import basics, training
from syconn.mp.batchjob_utils import batchjob_script
if __name__ == '__main__':
nfold = 10
params = []
cnn_script = '/wholebrain/u/pschuber/devel/SyConn/syconn/cnn/cnn_celltype_ptcnv_j0251.py'
for run in range(3):
base_dir = '/wholebrain/scratch/pschuber/e3_trainings_cmn_celltypes_j0251/'
for cval in range(nfold):
params.append([cnn_script, dict(sr=f'{base_dir}/celltype_CV{cval}/', cval=cval, seed=run)])
params = list(basics.chunkify_successive(params, 2))
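# 3 random seeds x 10 CV folds = 30 trainings, submitted two per batch job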
batchjob_script(params, 'launch_trainer', n_cores=20, additional_flags='--time=7-0 --qos=720h --gres=gpu:2',
disable_batchjob=False,
batchjob_folder=f'/wholebrain/scratch/pschuber/batchjobs/launch_trainer_celltypes_cmn_j0251/',
remove_jobfolder=False, overwrite=True)
......@@ -693,12 +693,15 @@ if elektronn3_avail:
class CelltypeViewsE3(Dataset):
"""
Wrapper class for the CelltypeViews data loader.
Views need to be available. If `view_key` is specified, make sure they exist by running the appropriate
rendering for every SSV in the GT, e.g. ``ssv._render_rawviews(4)`` for 4 views per location.
"""
def __init__(
self,
train=True,
transform: Callable = Identity(),
use_syntype_scal=False,
is_j0251=False,
**kwargs
):
super().__init__()
......@@ -706,17 +709,22 @@ if elektronn3_avail:
self.use_syntype_scal = use_syntype_scal
self.transform = transform
# TODO: add gt paths to config
self.ctv = CelltypeViews(None, None, **kwargs)
if not is_j0251:
raise RuntimeError('This version is deprecated!')
self.ctv = CelltypeViews(None, None, **kwargs)
else:
self.ctv = CelltypeViewsJ0251(None, None, **kwargs)
def __getitem__(self, index):
if self.use_syntype_scal:
inp, target, syn_signs = self.ctv.getbatch_alternative(1, source='train' if self.train else 'valid')
inp, _ = self.transform(inp, None) # Do not flip target label ^.^
return inp[0], target.squeeze().astype(np.int), syn_signs[0].astype(np.float32) # target should just be a scalar
# target should just be a scalar
return {'inp': (inp[0], syn_signs[0].astype(np.float32)), 'target': target.squeeze().astype(np.int)}
else:
inp, target = self.ctv.getbatch_alternative_noscal(1, source='train' if self.train else 'valid')
inp, _ = self.transform(inp, None) # Do not flip target label ^.^
return inp[0], target.squeeze().astype(np.int)
return {'inp': inp[0], 'target': target.squeeze().astype(np.int)}
def __len__(self):
"""Determines epoch size(s)"""
......@@ -1018,7 +1026,9 @@ class Data(object):
class MultiViewData(Data):
def __init__(self, working_dir, gt_type, nb_cpus=20,
label_dict=None, view_kwargs=None, naive_norm=True,
load_data=True, train_fraction=None, random_seed=0):
load_data=True, train_fraction=None, random_seed=0,
splitting_dict=None):
self.splitting_dict = splitting_dict
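# optional pre-computed train/valid split; if provided it takes precedence over the pickled splitting dict loaded below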
if view_kwargs is None:
view_kwargs = dict(raw_only=False,
nb_cpus=nb_cpus, ignore_missing=True,
......@@ -1028,7 +1038,7 @@ class MultiViewData(Data):
label_dc_path = self.gt_dir + "%s_labels.pkl" % gt_type
if label_dict is None:
self.label_dict = load_pkl2obj(label_dc_path)
if not os.path.isfile(splitting_dc_path) or train_fraction is \
if (not os.path.isfile(splitting_dc_path) and self.splitting_dict is None) or train_fraction is \
not None:
if train_fraction is None:
msg = f'Did not find splitting dictionary at {splitting_dc_path} ' \
......@@ -1064,7 +1074,8 @@ class MultiViewData(Data):
else:
if train_fraction is not None:
raise ValueError('Value fraction can only be set if splitting dict is not available.')
self.splitting_dict = load_pkl2obj(splitting_dc_path)
if self.splitting_dict is None:
self.splitting_dict = load_pkl2obj(splitting_dc_path)
self.ssd = SuperSegmentationDataset(working_dir, version=gt_type)
if not load_data:
......@@ -1204,9 +1215,8 @@ class CelltypeViews(MultiViewData):
self.nb_cpus = nb_cpus
self.raw_only = raw_only
self.reduce_context = reduce_context
self.cache_size = 4000 * 2 # random permutations/subset in selected SSV views,
# RandomFlip augmentation etc.
self.max_nb_cache_uses = self.cache_size
self.max_nb_cache_uses = 4000 * 2
self.current_cache_uses = 0
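# the view cache is re-drawn once max_nb_cache_uses batches have been served from it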
assert n_classes == len(class_weights)
self.n_classes = n_classes
......@@ -1228,7 +1238,7 @@ class CelltypeViews(MultiViewData):
k, v in self.splitting_dict.items()]
now_splits = [(k, np.unique(v, return_counts=False)) for
k, v in splitting_dict.items()]
log_cnn.critical('Splitting dict was passed explicitely. Overwriting '
log_cnn.critical('Splitting dict was passed explicitly. Overwriting '
'default splitting of super-class. Support '
f'previous: {prev_splits}'
f'Support now: {now_splits}.')
......@@ -1274,15 +1284,17 @@ class CelltypeViews(MultiViewData):
self.valid_l = self.valid_l[ixs]
# NOTE: also performs 'naive_view_normalization'
if self.view_cache[source] is None or self.current_cache_uses == self.max_nb_cache_uses:
sample_fac = np.max([int(self.nb_views / 10), 2]) # draw more ssv if number of views
sample_fac = np.max([int(self.nb_views / 20), 1]) # draw more ssv if number of views
# is high
nb_ssv = self.n_classes * sample_fac
sample_ixs = []
l = []
labels2draw = np.arange(self.n_classes)
np.random.shuffle(labels2draw) # change order
for i in labels2draw:
for cnt, i in enumerate(labels2draw):
curr_nb_samples = max(nb_ssv // self.n_classes * self.class_weights[i], 1)
if source == 'valid' and cnt > 2:
break
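# for the validation source, at most three (shuffled) classes are drawn per cache refresh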
try:
if source == "train":
sample_ixs.append(np.random.choice(self.train_d[self.train_l == i],
......@@ -1309,6 +1321,7 @@ class CelltypeViews(MultiViewData):
ssos.append(sso)
self.view_cache[source] = [sso.load_views(view_key=self.view_key) for sso in ssos]
# pre- and postsynapse type ratios
start = time.time()
self.syn_sign_cache[source] = np.array(
[[syn_sign_ratio_celltype(sso), syn_sign_ratio_celltype(sso, comp_types=[0, ])]
for sso in ssos])
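# cache both synapse-sign ratios per cell (default compartment types and comp_types=[0]); they become the scalar model input in the data loader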
......@@ -1437,13 +1450,14 @@ class CelltypeViews(MultiViewData):
if self.view_cache[source] is None or self.current_cache_uses == self.max_nb_cache_uses:
sample_fac = np.max([int(self.nb_views / 20), 1]) # draw more ssv if number of views is high
nb_ssv = self.n_classes * sample_fac # 1 for each class
sample_ixs = []
l = []
labels2draw = np.arange(self.n_classes)
class_sample_weight = self.class_weights
np.random.shuffle(labels2draw) # change order
for i in labels2draw:
curr_nb_samples = nb_ssv // self.n_classes * class_sample_weight[i] # sample more EA and MSN
curr_nb_samples = nb_ssv // self.n_classes * class_sample_weight[i]
try:
if source == "train":
sample_ixs.append(np.random.choice(self.train_d[self.train_l == i],
......@@ -1495,6 +1509,107 @@ class CelltypeViews(MultiViewData):
return tuple([d, l])
class CelltypeViewsJ0251(CelltypeViews):
def __init__(self, inp_node, out_node, raw_only=False, nb_views=20, nb_views_renderinglocations=2,
reduce_context=0, binary_views=False, reduce_context_fact=1, n_classes=4,
class_weights=(2, 2, 1, 1), load_data=False, nb_cpus=1,
random_seed=0, view_key=None, cv_val=None):
"""
USES NAIVE_VIEW_NORMALIZATION_NEW, i.e. `/ 255. - 0.5`
Parameters
----------
inp_node :
out_node :
raw_only :
nb_views : int
Number of sampled views used for prediction of cell type
nb_views_renderinglocations : int
Number of views per rendering location
reduce_context :
binary_views :
reduce_context_fact :
n_classes : int
Number of cell type classes; must equal ``len(class_weights)``.
class_weights : tuple of int
Per-class sampling weights used when refilling the view cache.
load_data :
nb_cpus :
random_seed : int
view_key : str
Key of the view storage to load; defaults to ``"raw{nb_views_renderinglocations}"``.
cv_val : int
Index (0-9) of the 10-fold cross-validation split used as validation set;
``None`` or ``-1`` trains on all GT cells.
"""
global_params.wd = "/ssdscratch/pschuber/songbird/j0251/rag_flat_Jan2019_v3/"
ctgt_key = None # use standard ssv store
assert "rag_flat_Jan2019_v3" in global_params.config.working_dir
assert os.path.isdir(global_params.config.working_dir)
if view_key is None:
self.view_key = "raw{}".format(nb_views_renderinglocations)
else:
self.view_key = view_key
self.nb_views = nb_views
self.nb_cpus = nb_cpus
self.raw_only = raw_only
self.reduce_context = reduce_context
# number of batches served from the view cache (random permutations/subsets of
# the selected SSV views, RandomFlip augmentation etc.) before it is re-drawn
self.max_nb_cache_uses = 4000 * 2
self.current_cache_uses = 0
assert n_classes == len(class_weights)
self.n_classes = n_classes
self.class_weights = np.array(class_weights)
self.view_cache = {'train': None, 'valid': None, 'test': None}
self.label_cache = {'train': None, 'valid': None, 'test': None}
self.syn_sign_cache = {'train': None, 'valid': None, 'test': None}
self.sample_weights = {'train': None, 'valid': None, 'test': None}
self.reduce_context_fact = reduce_context_fact
self.binary_views = binary_views
self.example_shape = (nb_views, 4, 2, 128, 256)
self.cv_val = cv_val
# load GT
self.csv_p = "/wholebrain/songbird/j0251/groundtruth/celltypes/j0251_celltype_gt_v4.csv"
df = pandas.io.parsers.read_csv(self.csv_p, header=None, names=['ID', 'type']).values
ssv_ids = df[:, 0].astype(np.uint)
if len(np.unique(ssv_ids)) != len(ssv_ids):
ixs, cnt = np.unique(ssv_ids, return_counts=True)
raise ValueError(f'Multi-usage of IDs! {ixs[cnt > 1]}')
str_labels = df[:, 1]
ssv_labels = np.array([str2int_converter(el, gt_type='ctgt_j0251_v2') for el in str_labels], dtype=np.uint16)
if self.cv_val is not None and self.cv_val != -1:
assert self.cv_val < 10
kfold = StratifiedKFold(n_splits=10, shuffle=True, random_state=0)
for ii, (train_ixs, test_ixs) in enumerate(kfold.split(ssv_ids, y=ssv_labels)):
if ii == self.cv_val:
self.splitting_dict = {'train': ssv_ids[train_ixs], 'valid': ssv_ids[test_ixs]}
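# deterministic 10-fold stratified split (fixed random_state=0); fold `cv_val` serves as validation set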
else:
self.splitting_dict = {'train': ssv_ids, 'valid': ssv_ids} # use all data
log_cnn.critical(f'Using all GT data for training!')
self.label_dict = {k: v for k, v in zip(ssv_ids, ssv_labels)}
self.sso_ids = self.splitting_dict['train']
for k, v in self.splitting_dict.items():
classes, c_cnts = np.unique([self.label_dict[ix] for ix in
self.splitting_dict[k]], return_counts=True)
log_cnn.debug(f"{k} [labels, counts]: {classes}, {c_cnts}")
log_cnn.debug(f'{len(self.sso_ids)} SSV IDs in training set: {self.sso_ids}')
dc_split_prev = dict(self.splitting_dict)
dc_label_prev = dict(self.label_dict)
print("Initializing CelltypeViewsJ0251:", self.__dict__) # TODO: add gt paths to config
super(CelltypeViews, self).__init__(global_params.config.working_dir, ctgt_key, train_fraction=None,
naive_norm=False, load_data=load_data, random_seed=random_seed,
splitting_dict=dc_split_prev, label_dict=dc_label_prev)
# check that super left dicts unmodified
assert self.splitting_dict == dc_split_prev
assert self.label_dict == dc_label_prev
self.train_d = np.array(self.splitting_dict["train"])
self.valid_d = np.array(self.splitting_dict["valid"])
ssv_gt_dict = self.label_dict
self.train_l = np.array([ssv_gt_dict[ix] for ix in self.train_d], np.int16)[:, None]
self.valid_l = np.array([ssv_gt_dict[ix] for ix in self.valid_d], np.int16)[:, None]
self.train_d = self.train_d[:, None]
self.valid_d = self.valid_d[:, None]
super(MultiViewData, self).__init__()
for k, v in self.splitting_dict.items():
classes, c_cnts = np.unique([self.label_dict[ix] for ix in
self.splitting_dict[k]], return_counts=True)
print(f"{k} [labels, counts]: {classes}, {c_cnts}")
class GliaViews(Data):
def __init__(self, inp_node, out_node, raw_only=True, nb_views=2,
reduce_context=0, binary_views=False, reduce_context_fact=1,
......
......@@ -4,12 +4,6 @@
# Max Planck Institute of Neurobiology, Munich, Germany
# Authors: Philipp Schubert
"""
Workflow of spine semantic segmentation based on multi-views (2D semantic segmentation).
It learns how to differentiate between spine head, spine neck and spine shaft.
Caution! The input dataset was not manually corrected.
"""
from syconn import global_params
from syconn.cnn.TrainData import CelltypeViewsE3
import argparse
......@@ -262,14 +256,13 @@ if __name__ == "__main__":
device=device,
train_dataset=train_dataset,
valid_dataset=valid_dataset,
batchsize=batch_size,
batch_size=batch_size,
num_workers=0,
save_root=save_root,
exp_name=args.exp_name,
schedulers=schedulers,
valid_metrics=valid_metrics,
ipython_shell=False,
mixed_precision=False, # Enable to use Apex for mixed precision training
)
# Archiving training script, src folder, env info
......
#!/usr/bin/env python3
# Copyright (c) 2017 - now
# Max Planck Institute of Neurobiology, Munich, Germany
# Authors: Philipp Schubert
"""
"""
from syconn import global_params
from syconn.cnn.TrainData import CelltypeViewsE3
import argparse
import _pickle
import zipfile
import numpy as np
import os
import torch
from torch import nn
from torch import optim
from elektronn3.models.simple import Conv3DLayer, StackedConv2Scalar
from elektronn3.data.transforms import RandomFlip
from elektronn3.data import transforms
class StackedConv2ScalarWithLatentAdd(nn.Module):
def __init__(self, in_channels, n_classes, dropout_rate=0.08, act='relu',
n_scalar=1):
super().__init__()
if act == 'relu':
act = nn.ReLU()
elif act == 'leaky_relu':
act = nn.LeakyReLU()
self.seq = nn.Sequential(
Conv3DLayer(in_channels, 20, (1, 5, 5), pooling=(1, 2, 2),
dropout_rate=dropout_rate, act=act),
Conv3DLayer(20, 30, (1, 5, 5), pooling=(1, 2, 2),
dropout_rate=dropout_rate, act=act),
Conv3DLayer(30, 40, (1, 4, 4), pooling=(1, 2, 2),
dropout_rate=dropout_rate, act=act),
Conv3DLayer(40, 50, (1, 4, 4), pooling=(1, 2, 2),
dropout_rate=dropout_rate, act=act),
Conv3DLayer(50, 60, (1, 2, 2), pooling=(1, 2, 2),
dropout_rate=dropout_rate, act=act),
Conv3DLayer(60, 70, (1, 1, 1), pooling=(1, 2, 2),
dropout_rate=dropout_rate, act=act),
Conv3DLayer(70, 70, (1, 1, 1), pooling=(1, 1, 1),
dropout_rate=dropout_rate, act=act),
) # given: torch.Size([1, 4, 20, 128, 256]), returns torch.Size([1, 70, 20, 1, 3])
self.fc = nn.Sequential(
nn.Linear(4200 + n_scalar, 100),
act,
nn.Linear(100, 50),
act,
nn.Linear(50, n_classes),
)
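# the flattened conv output has 70 * 20 * 1 * 3 = 4200 features; the n_scalar
# extra inputs (here the two synapse-sign ratios from the data loader) are
# concatenated to it before the first fully connected layer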
def forward(self, *args):
x, scal = args
x = self.seq(x)
x = x.view(x.size()[0], -1)
x = torch.cat((x, scal), 1)
x = self.fc(x)
return x
def get_model():
model = StackedConv2ScalarWithLatentAdd(in_channels=4, n_classes=11, n_scalar=2)
# model = StackedConv2Scalar(in_channels=4, n_classes=8)
return model
if __name__ == "__main__":
lr = 1e-3
lr_stepsize = 500
lr_dec = 0.985
batch_size = 40
n_classes = 11
parser = argparse.ArgumentParser(description='Train a network.')
parser.add_argument('--disable-cuda', action='store_true', help='Disable CUDA')
parser.add_argument('-n', '--exp-name',
default='',
help='Manually set experiment name')
parser.add_argument(
'-m', '--max-steps', type=int, default=int(200e3),
help='Maximum number of training steps to perform.'
)
parser.add_argument(
'-r', '--resume', metavar='PATH',
help='Path to pretrained model state dict or a compiled and saved '
'ScriptModule from which to resume training.'
)
parser.add_argument(
'-j', '--jit', metavar='MODE', default='onsave',
choices=['disabled', 'train', 'onsave'],
help="""Options:
"disabled": Completely disable JIT tracing;
"onsave": Use regular Python model for training, but trace it on-demand for saving training state;
"train": Use traced model for training and serialize it on disk"""
)
parser.add_argument('--sr', type=str, help='Save root', default=None)
parser.add_argument('--seed', default=0, help='Random seed', type=int)
parser.add_argument('--cval', default=None, help='Cross-validation split indicator.', type=int)
args = parser.parse_args()
if not args.disable_cuda and torch.cuda.is_available():
device = torch.device('cuda')
else:
device = torch.device('cpu')
print('Running on device: {}'.format(device))
# Don't move this stuff, it needs to be run this early to work
import elektronn3