Commit 34d9d058 authored by YangLiu14's avatar YangLiu14
Browse files

last commit

parent fc722024
......@@ -11,19 +11,21 @@ from syconn.reps.super_segmentation import *
from syconn.reps.segmentation import SegmentationDataset
from syconn.proc.meshes import merge_meshes
from syconn.handler.basics import data2kzip, write_obj2pkl, load_pkl2obj
from syconn.proc.ssd_proc import merge_ssv
# Parameters
### Path to load dataset
global_params.wd = '/wholebrain/songbird/j0126/areaxfs_v6/'
# global_params.wd = '/wholebrain/songbird/j0126/areaxfs_v6/'
# cs_version = 'agg_0'
# global_params.wd = '/wholebrain/songbird/j0126/areaxfs_v10_v4b_base_20180214_full_agglo_cbsplit/'
global_params.wd = '/ssdscratch/pschuber/songbird/j0126/areaxfs_v10_v4b_base_20180214_full_agglo_cbsplit/'
# global_params.wd = '/wholebrain/songbird/j0126/areaxfs_v10_newcb_cbsplits/'
cs_version = 0
# global_params.wd = '/home/kloping/wholebrain/songbird/j0126/areaxfs_v6/' # local test
# Path to store output kzip files
folder_name = "/merger_(256_128)_30720_(2e3_20e3)_10000/"
# folder_name = "/merger_(256_128)_30720_(2e3_20e3)_10000/"
folder_name = "/merger_with_edge/"
data_folder = global_params.wd.split('/')[-2]
suffix_list = data_folder.split('_')[1:]
pkl_version = suffix_list[0]
......@@ -59,10 +61,12 @@ def write_dict_to_txt(dict, fname):
f.write(str(dict))
f.close()
def read_dict_from_txt(fname: str) -> dict:
    """Load a dict that was serialized with ``write_dict_to_txt`` (``str(dict)``).

    Parameters
    ----------
    fname : str
        Path to a text file containing the ``repr`` of a dictionary.

    Returns
    -------
    dict
        The evaluated dictionary.
    """
    # NOTE(review): eval() executes arbitrary code from the file -- only use on
    # trusted files; consider ast.literal_eval for literal-only content.
    # 'with' guarantees the file handle is closed (it was leaked before).
    with open(fname, 'r') as f:
        return eval(f.read())
def cs_partner(id) -> Optional[List[int]]:
"""
Contact site specific attribute.
......@@ -74,35 +78,8 @@ def cs_partner(id) -> Optional[List[int]]:
partner.append(id - (partner[0] << 32))
return partner
def merge_superseg_objects(cell_obj1, cell_obj2):
    """Merge two super-segmentation objects (cells) into one artificial merger.

    Concatenates the meshes of all sub-cellular object types and the two
    skeletons (nodes, edges, diameters) into a fresh temporary
    ``SuperSegmentationObject``.

    Parameters
    ----------
    cell_obj1, cell_obj2 : SuperSegmentationObject
        The two cells to merge; both must provide loadable meshes and skeletons.

    Returns
    -------
    SuperSegmentationObject
        A new in-memory object (ssv_id=-1, version='tmp') holding the merged
        meshes and skeleton. NOTE(review): no edge is added between the two
        skeletons here, so they stay disconnected components (see TODO below).
    """
    # TODO: test why working_dir='/tmp/' raise some errors
    # TODO: add the edge. General two closest nodes, add an edge between, add to SperSeghelper
    # merge meshes
    merged_cell = SuperSegmentationObject(ssv_id=-1, working_dir=None, version='tmp')
    for mesh_type in ['sv', 'sj', 'syn_ssv', 'vc', 'mi']:
        # each mesh is a tuple (indices, vertices); merge_meshes concatenates
        # the vertex arrays and re-offsets the second index array accordingly
        mesh1 = cell_obj1.load_mesh(mesh_type)
        mesh2 = cell_obj2.load_mesh(mesh_type)
        ind_lst = [mesh1[0], mesh2[0]]
        vert_lst = [mesh1[1], mesh2[1]]
        merged_cell._meshes[mesh_type] = merge_meshes(ind_lst, vert_lst)
        merged_cell._meshes[mesh_type] += ([None, None], )  # add normals
    # merge skeletons
    merged_cell.skeleton = {}
    cell_obj1.load_skeleton()
    cell_obj2.load_skeleton()
    # edges of the second cell must be shifted by the node count of the first
    # cell, because the node arrays are concatenated below in that order
    merged_cell.skeleton['edges'] = np.concatenate([cell_obj1.skeleton['edges'],
                                                    cell_obj2.skeleton['edges'] +
                                                    len(cell_obj1.skeleton['nodes'])])  # additional offset
    merged_cell.skeleton['nodes'] = np.concatenate([cell_obj1.skeleton['nodes'],
                                                    cell_obj2.skeleton['nodes']])
    merged_cell.skeleton['diameters'] = np.concatenate([cell_obj1.skeleton['diameters'],
                                                        cell_obj2.skeleton['diameters']])
    return merged_cell
def create_lookup_table(num_cs_id, dict_sv2svv):
def create_lookup_table(num_cs_id, contact_sites, dict_sv2svv):
"""Loop through all contact_sites ids and if two corresponding cells are found,
store the cell_id and corresponding cs_id into dictionary
......@@ -152,9 +129,8 @@ if __name__ == "__main__":
contact_sites = SegmentationDataset(obj_type='cs', version=cs_version) # class holding all contact-site between SVs
dict_sv2ssv = ssd.mapping_dict_reversed # dict: {supervoxel : super-supervoxel}
if create_new_cs_ids:
cell_pair2cs_ids, cell_pairs = create_lookup_table(num_cs_id, dict_sv2ssv)
cell_pair2cs_ids, cell_pairs = create_lookup_table(num_cs_id, contact_sites, dict_sv2ssv)
write_obj2pkl(path_pkl_file, cell_pair2cs_ids)
else:
cell_pair2cs_ids = dict()
......@@ -166,7 +142,7 @@ if __name__ == "__main__":
print("Loading cell_pair2cs_ids from pkl file successful.")
except:
print("cell_pair2cs_ids_" + pkl_version + ".pkl not found in {}".format(path_pkl_file))
cell_pair2cs_ids, cell_pairs = create_lookup_table(num_cs_id, dict_sv2ssv)
cell_pair2cs_ids, cell_pairs = create_lookup_table(num_cs_id, contact_sites, dict_sv2ssv)
write_obj2pkl(path_pkl_file, cell_pair2cs_ids)
assert len(cell_pair2cs_ids) == len(cell_pairs), "inconsistent length"
......@@ -178,16 +154,16 @@ if __name__ == "__main__":
print("in which {} pairs have more than one cs.".format(count))
print("Generating merged cells")
count = 1
# cell_pairs = cell_pairs[501:501+num_generated_cells]
cell_pairs = cell_pairs[:num_generated_cells]
for i in trange(len(cell_pairs), desc='cell_pairs'):
cell_pair = cell_pairs[i]
c1, c2 = cell_pair[0], cell_pair[1]
assert c1 != c2, "same cells cannot be merged."
cell_obj1, cell_obj2 = ssd.get_super_segmentation_object([c1, c2])
merged_cell = merge_superseg_objects(cell_obj1, cell_obj2)
# merged_cell = merge_superseg_objects(cell_obj1, cell_obj2)
merged_cell = merge_ssv(cell_obj1, cell_obj2)
merged_cell_nodes = merged_cell.skeleton['nodes'] * merged_cell.scaling # coordinates of all nodes
# labels:
......@@ -206,6 +182,8 @@ if __name__ == "__main__":
cs_coord = cs_obj.rep_coord * merged_cell.scaling
cs_coord_list.append(cs_coord)
if len(cs_coord_list) == 0:
print("No cs found: {}.".format(count))
count += 1
continue
# find medium cube around artificial merger and set it to 0 (no-merger/cell_body)
......@@ -249,4 +227,5 @@ if __name__ == "__main__":
# TODO: export2kzip should run through
count += 1
print("Total merged_cells generated: {}".format(count))
......@@ -116,7 +116,8 @@ def gt_generation(kzip_paths, dest_dir=None):
if not os.path.isdir(dest_dir):
os.makedirs(dest_dir)
dest_p_results = "{}/gt_convpoint/".format(dest_dir)
# dest_p_results = "{}/gt_convpoint/".format(dest_dir)
dest_p_results = "{}/gt_convpoint_edge/".format(dest_dir)
if not os.path.isdir(dest_p_results):
os.makedirs(dest_p_results)
......@@ -134,7 +135,8 @@ if __name__ == "__main__":
# set paths
dest_gt_dir = "/wholebrain/scratch/yliu/merger_gt_semseg_pointcloud/"
# global_params.wd = "/wholebrain/songbird/j0126/areaxfs_v6/"
label_file_folder = "/wholebrain/scratch/yliu/false_merger_generation/merger_CSfilter_kzip_v10_02/"
# label_file_folder = "/wholebrain/scratch/yliu/false_merger_generation/merger_CSfilter_kzip_v10_02/"
label_file_folder = "/wholebrain/scratch/yliu/false_merger_generation/merger_with_edge/"
file_paths = glob.glob(label_file_folder + '*.k.zip', recursive=False)
......
......@@ -47,21 +47,10 @@ setup(
keywords='machinelearning imageprocessing connectomics',
packages=find_packages(exclude=['scripts']),
python_requires='>=3.6, <4',
<<<<<<< HEAD
install_requires=['numpy==1.16.4', 'scipy', 'lz4', 'h5py', 'networkx',
'fasteners', 'flask', 'coloredlogs', 'opencv-python',
'pyopengl', 'scikit-learn>=0.21.3', 'scikit-image',
'plyfile', 'termcolor', 'dill', 'tqdm', 'zmesh',
'seaborn', 'pytest-runner', 'prompt-toolkit',
'numba==0.45.0', 'matplotlib', 'vtki', 'joblib',
'pyyaml', 'cython'],
setup_requires=setup_requires, tests_require=['pytest', 'pytest-cov'],
=======
setup_requires=setup_requires,
package_data={'syconn': ['handler/config.yml']},
include_package_data=True,
tests_require=['pytest', 'pytest-cov', 'pytest-xdist'],
>>>>>>> d7b8a8b9652050cca79384c362fc1099bced807c
ext_modules=cython_out,
entry_points={
'console_scripts': [
......
......@@ -40,22 +40,34 @@ from sys import getsizeof
# ==========================
# Directory pointed to the training / validation dataset
dataset_dir = '/wholebrain/scratch/yliu/merger_gt_semseg_v10_5views_200_6000/'
# dataset_dir = '/wholebrain/scratch/yliu/merger_gt_semseg_v10_5views_200_6000_v02/'
dataset_dir = '/wholebrain/scratch/yliu/merger_gt_semseg_mesh/merger_(512_256)_15360_(2e3_20e3)_10k/'
# Path to where the trained model is saved
save_root = os.path.expanduser('~/e3training/')
# save_root = os.path.expanduser('~/e3training/')
save_root = os.path.expanduser('~/Unet_(512_256)_15360_(2e3_20e3)_10k/')
# optimizer, choose from ['SGD', 'Adam']
opt = 'Adam'
# IMPORTANT: change this strictly according to the `ws` of your data
# see `ws` in generate_merger_gt_semseg.py
example_input = torch.randn(1, 4, 256, 512)
# example_input = torch.randn(1, 4, 128, 256)
# Hyper-parameters
# lr = 0.002
# lr_stepsize = 1000
# lr_dec = 0.99
# batch_size = 20
# Hyper-parameters
lr = 0.004
lr_stepsize = 500
lr_dec = 0.995
batch_size = 6
lr_stepsize = 1000
lr_dec = 0.99
batch_size = 10
def get_model():
vgg_model = VGGNet(model='vgg13', requires_grad=True, in_channels=4)
vgg_model = VGGNet(model='vgg19', requires_grad=True, in_channels=4)
model = FCNs(base_net=vgg_model, n_class=3)
# model = UNet(in_channels=4, out_channels=6, n_blocks=5, start_filts=32,
# merge_mode='concat', planar_blocks=(), #up_mode='resize',
......@@ -104,7 +116,7 @@ if __name__ == "__main__":
model = get_model()
model.to(device)
example_input = torch.randn(1, 4, 128, 256)
# example_input = torch.randn(1, 4, 256, 512)
enable_save_trace = False if args.jit == 'disabled' else True
if args.jit == 'onsave':
......@@ -196,7 +208,9 @@ if __name__ == "__main__":
# valid_dataset = ModMultiviewData(train=False, transform=transform, base_dir=global_params.config['compartments']['gt_path_axonseg'])
# criterion = LovaszLoss().to(device)
criterion = DiceLoss().to(device)
# criterion = DiceLoss().to(device)
criterion = DiceLoss(apply_softmax=True, weight=torch.tensor([0.4, 0.5, 0.1]).to(device)).to(device)
valid_metrics = {
# 'val_accuracy': metrics.bin_accuracy,
......
......@@ -30,7 +30,8 @@ except ImportError:
from elektronn3.training import metrics
from elektronn3.models.fcn_2d import *
from elektronn3.models.unet import UNet
from elektronn3.models.unet_plusplus import NestedUNet
from elektronn3.models.unets.nested_unet import NestedUNet
from elektronn3.models.unets.attention_unet import AttU_Net, R2AttU_Net
from elektronn3.data.transforms import RandomFlip
from elektronn3.data import transforms
from sys import getsizeof
......@@ -40,39 +41,44 @@ from sys import getsizeof
# ==========================
# Directory pointed to the training / validation dataset
dataset_dir = '/wholebrain/scratch/yliu/merger_gt_semseg_v10_5views_200_6000/'
# dataset_dir = '/wholebrain/scratch/yliu/merger_gt_semseg_v10_5views_200_6000_v02/'
dataset_dir = '/wholebrain/scratch/yliu/merger_gt_semseg_mesh/merger_(512_256)_15360_(2e3_20e3)_10k/'
# Path to where the trained model is saved
save_root = os.path.expanduser('~/e3training/')
# save_root = os.path.expanduser('~/e3training/')
save_root = os.path.expanduser('~/Unet_(512_256)_15360_(2e3_20e3)_10k/')
# optimizer, choose from ['SGD', 'Adam']
opt = 'Adam'
# Hyper-parameters
lr = 0.004
lr_stepsize = 500
lr_dec = 0.995
batch_size = 6
def get_model():
# ===================
# FCN
# ===================
# vgg_model = VGGNet(model='vgg13', requires_grad=True, in_channels=4)
# model = FCNs(base_net=vgg_model, n_class=3)
# ===================
# U-Net
# ===================
model = UNet(in_channels=4, out_channels=3, n_blocks=5, start_filts=32,
merge_mode='concat', planar_blocks=(), #up_mode='resize',
activation='relu', batch_norm=True, dim=2,)
# ===================
# U-Net++
# ===================
# model = NestedUNet(in_channels=4, out_channels=3, deepsupervision=False)
# model = NestedUNet(in_channels=4, out_channels=3, deepsupervision=True)
# lr = 0.001
lr_stepsize = 1000
lr_dec = 0.99
batch_size = 10 # batch_size=20 doesn't work with Unet++
# IMPORTANT: change this strictly according to the `ws` of your data
# see `ws` in generate_merger_gt_semseg.py
example_input = torch.randn(1, 4, 256, 512)
# example_input = torch.randn(1, 4, 128, 256)
def get_model(network="unet", unet_blocks=6, deepsupervision=False):
    """Build the requested 2D segmentation network (4 input channels, 3 classes).

    Parameters
    ----------
    network : str
        Architecture name: 'unet' (or ''), 'unet++', 'attention-unet' or
        'rcnn-attention-unet'.
    unet_blocks : int
        Number of down-/up-sampling blocks; only used by 'unet'.
    deepsupervision : bool
        Enable deep supervision; only used by 'unet++'.

    Returns
    -------
    torch.nn.Module
        The instantiated model.

    Raises
    ------
    ValueError
        If `network` is not one of the supported architecture names.
    """
    if network in ("unet", ""):
        return UNet(in_channels=4, out_channels=3, n_blocks=unet_blocks, start_filts=32,
                    merge_mode='concat', planar_blocks=(),  # up_mode='resize',
                    activation='relu', batch_norm=True, dim=2)
    if network == "unet++":
        return NestedUNet(in_channels=4, out_channels=3, deepsupervision=deepsupervision)
    if network == "attention-unet":
        return AttU_Net(in_channels=4, out_channels=3)
    if network == "rcnn-attention-unet":
        return R2AttU_Net(in_channels=4, out_channels=3)
    # BUG FIX: the original constructed an Exception but never raised it,
    # silently returning None for unknown names.
    raise ValueError("Invalid network name, currently only "
                     "['unet', 'unet++', 'attention-unet', 'rcnn-attention-unet']")
......@@ -98,6 +104,10 @@ if __name__ == "__main__":
"onsave": Use regular Python model for training, but trace it on-demand for saving training state;
"train": Use traced model for training and serialize it on disk"""
)
parser.add_argument('--network', type=str, default="rcnn-attention-unet", help='network architecture')
parser.add_argument('--unet-blocks', type=int, default=6, help='number of downsampling/upsampling layers in unet')
parser.add_argument('--deepsupervision', type=bool, default=False, help='number of downsampling/upsampling layers in unet')
args = parser.parse_args()
if not args.disable_cuda and torch.cuda.is_available():
device = torch.device('cuda')
......@@ -114,22 +124,26 @@ if __name__ == "__main__":
max_steps = args.max_steps
model = get_model()
model = get_model(args.network, args.unet_blocks, args.deepsupervision)
# print("U-Net: number of blocks: {}".format(model.n_blocks))
# print(model)
model.to(device)
example_input = torch.randn(1, 4, 128, 256)
enable_save_trace = False if args.jit == 'disabled' else True
if args.jit == 'onsave':
# Make sure that tracing works
tracedmodel = torch.jit.trace(model, example_input.to(device))
elif args.jit == 'train':
if getattr(model, 'checkpointing', False):
raise NotImplementedError(
'Traced models with checkpointing currently don\'t '
'work, so either run with --disable-trace or disable '
'checkpointing.')
tracedmodel = torch.jit.trace(model, example_input.to(device))
model = tracedmodel
if not args.deepsupervision and args.network != 'rcnn-attention-unet':
# example_input = torch.randn(1, 4, 128, 256)
enable_save_trace = False if args.jit == 'disabled' else True
if args.jit == 'onsave':
# Make sure that tracing works
tracedmodel = torch.jit.trace(model, example_input.to(device))
elif args.jit == 'train':
if getattr(model, 'checkpointing', False):
raise NotImplementedError(
'Traced models with checkpointing currently don\'t '
'work, so either run with --disable-trace or disable '
'checkpointing.')
tracedmodel = torch.jit.trace(model, example_input.to(device))
model = tracedmodel
optimizer_state_dict = None
lr_sched_state_dict = None
......@@ -209,7 +223,8 @@ if __name__ == "__main__":
# valid_dataset = ModMultiviewData(train=False, transform=transform, base_dir=global_params.config['compartments']['gt_path_axonseg'])
# criterion = LovaszLoss().to(device)
criterion = DiceLoss().to(device)
# criterion = DiceLoss().to(device)
criterion = DiceLoss(apply_softmax=True, weight=torch.tensor([0.4, 0.5, 0.1]).to(device)).to(device)
valid_metrics = {
# 'val_accuracy': metrics.bin_accuracy,
......
......@@ -256,6 +256,7 @@ def triangulation(pts, downsampling=(1, 1, 1), n_closings=0, single_cc=False,
assert (pts.ndim == 2 and pts.shape[1] == 3) or pts.ndim == 3, \
"Point cloud used for mesh generation has wrong shape."
if pts.ndim == 2:
print("pts.ndim == 2")
if np.max(pts) <= 1:
msg = "Currently this function only supports point " \
"clouds with coordinates >> 1."
......@@ -294,17 +295,23 @@ def triangulation(pts, downsampling=(1, 1, 1), n_closings=0, single_cc=False,
np.float32)
n_dilations += 1
else:
volume = volume.astype(np.float32)
print("TEST: n_closing <= 0 ")
# volume = volume.astype(np.float32)
print("TEST: volume")
if single_cc:
print("TEST: single_cc")
labeled, nb_cc = ndimage.label(volume)
cnt = Counter(labeled[labeled != 0])
l, occ = cnt.most_common(1)[0]
volume = np.array(labeled == l, dtype=np.float32)
# InterpixelBoundary, OuterBoundary, InnerBoundary
print("TEST: boundaryDistanceTransform")
dt = boundaryDistanceTransform(volume, boundary="InterpixelBoundary")
dt[volume == 1] *= -1
print("TEST: gaussianSmoothing, 1")
volume = gaussianSmoothing(dt, 1)
if np.sum(volume < 0) == 0 or np.sum(volume > 0) == 0: # less smoothing
print("TEST: gaussianSmoothing, 0.5")
volume = gaussianSmoothing(dt, 0.5)
try:
verts, ind, norm, _ = measure.marching_cubes_lewiner(
......@@ -1138,10 +1145,11 @@ def mesh2obj_file(dest_path, mesh, color=None, center=None, scale=None):
# options += openmesh.Options.Binary
mesh_obj = openmesh.TriMesh()
ind, vert, norm = mesh
if vert.ndim == 1:
vert = vert.reshape(-1 ,3)
vert = vert.reshape(-1, 3)
if ind.ndim == 1:
ind = ind.reshape(-1 ,3)
ind = ind.reshape(-1, 3)
if center is not None:
vert -= center
if scale is not None:
......
......@@ -2157,7 +2157,7 @@ def semseg_of_sso_nocache(sso, model, semseg_key: str, ws: Tuple[int, int],
be predicted with the given `model` and maps prediction results onto mesh.
Vertex labels are stored on file system and can be accessed via
`sso.label_dict('vertex')[semseg_key]`.
If sso._sample_locations is None it `generate_rendering_locs(verts, comp_window / 3)`
If sso._sample_locations is None it `semseg_of_sso_nocache(verts, comp_window / 3)`
will be called to generate rendering locations.
Examples:
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment