Commit d2f9aeed authored by Philipp

#30 add `calc_contact_syn_mesh`: voxel-based mesh generation for syn/syn_ssv/cs objects

parent d47a3c97
Pipeline #98235 passed in 2 minutes and 25 seconds
......@@ -19,12 +19,9 @@ build-full:
- conda env create -f environment.yml -n pysyintegration_m python --force
- source ~/.bashrc
- conda activate pysyintegration_m
only:
- master
- merge_requests
except:
variables:
- $CI_MERGE_REQUEST_TARGET_BRANCH_NAME != "master"
rules:
- if: '$CI_COMMIT_BRANCH == "master"'
- if: '$CI_MERGE_REQUEST_TARGET_BRANCH_NAME == "master"'
test-full:
stage: test
......@@ -33,23 +30,17 @@ test-full:
- conda activate pysyintegration_m
- pip install --upgrade --no-deps -v -e .
- python -m pytest -c tests/full_run.ini
only:
- master
- merge_requests
except:
variables:
- $CI_MERGE_REQUEST_TARGET_BRANCH_NAME != "master"
rules:
- if: '$CI_COMMIT_BRANCH == "master"'
- if: '$CI_MERGE_REQUEST_TARGET_BRANCH_NAME == "master"'
cleanup:
stage: cleanup
script:
- conda remove --yes -n pysyintegration_m --all
when: always
only:
- master
- merge_requests
except:
variables:
- $CI_MERGE_REQUEST_TARGET_BRANCH_NAME != "master"
rules:
- if: '$CI_COMMIT_BRANCH == "master"'
- if: '$CI_MERGE_REQUEST_TARGET_BRANCH_NAME == "master"'
pylint:
stage: test
......@@ -66,9 +57,6 @@ pylint:
- ./pylint/
expire_in: 2 yrs
allow_failure: true
only:
- master
- merge_requests
except:
variables:
- $CI_MERGE_REQUEST_TARGET_BRANCH_NAME != "master"
\ No newline at end of file
rules:
- if: '$CI_COMMIT_BRANCH == "master"'
- if: '$CI_MERGE_REQUEST_TARGET_BRANCH_NAME == "master"'
\ No newline at end of file
......@@ -49,7 +49,7 @@ used for our project can be found
SyConn uses the packages [zmesh](https://github.com/seung-lab/zmesh) for mesh and [kimimaro](https://github.com/seung-lab/kimimaro)
for skeleton generation, both developed in the Seung Lab.
Thanks to Julia Kuhl (see http://somedonkey.com/ for more beautiful
work) for designing and creating the logo! and
work) for designing and creating the logo!
Publications
......
......@@ -55,8 +55,8 @@ dependencies:
# From -c menpo
- osmesa
# From -c open3d-admin; 0.10.0 throws /lib64/libm.so.6: version `GLIBC_2.27' not found
- open3d =0.9.0
# From -c open3d-admin; >0.9.0 throws /lib64/libm.so.6: version `GLIBC_2.27' not found
- open3d >=0.9.0
# For tests (optional):
- pytest
......@@ -80,5 +80,10 @@ dependencies:
- git+https://github.com/StructuralNeurobiologyLab/MorphX.git@master#egg=morphx
- git+https://github.com/StructuralNeurobiologyLab/kimimaro.git
# training scripts
- git+https://github.com/StructuralNeurobiologyLab/NeuronX.git
# PointConv definitions
- git+https://github.com/valeoai/LightConvPoint.git
# for skeletonisation
- fill-voids
......@@ -8,7 +8,11 @@ from syconn.proc.meshes import write_mesh2kzip
from morphx.classes.hybridcloud import HybridCloud
from utils import anno_skeleton2np, sso2kzip, nxGraph2kzip, map_myelin
from syconn.reps.super_segmentation import SuperSegmentationDataset
from syconn.mp.mp_utils import start_multiprocess_imap
try:
import open3d as o3d
except ImportError:
pass # for sphinx build
col_lookup = {0: (76, 92, 158, 255), 1: (255, 125, 125, 255), 2: (125, 255, 125, 255), 3: (113, 98, 227, 255),
4: (255, 255, 125, 255), 5: (125, 255, 255, 255), 6: (255, 125, 255, 255), 7: (168, 0, 20, 255),
......@@ -16,7 +20,24 @@ col_lookup = {0: (76, 92, 158, 255), 1: (255, 125, 125, 255), 2: (125, 255, 125,
12: (255, 127, 15, 255)}
def process_file(file: str, o_path: str, ctype: str, ssd: SuperSegmentationDataset, convert_to_morphx: bool = False):
def voxelize_points(pts, voxel_size):
pcd = o3d.geometry.PointCloud()
pcd.points = o3d.utility.Vector3dVector(pts)
pcd = pcd.voxel_down_sample(voxel_size=voxel_size)
return np.asarray(pcd.points)
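# Illustrative usage (not part of the pipeline): on a nm-scaled surface point
# cloud, voxelize_points(vertices, voxel_size=100) keeps one representative
# point per occupied 100 nm voxel, shrinking the vertex count substantially
# while preserving the overall shape.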
def _process_file(args):
return process_file(*args)
def process_file(file: str, o_path: str, ctype: str, convert_to_morphx: bool = False):
print(f'Processing: {file}')
ssd = SuperSegmentationDataset(working_dir='/ssdscratch/pschuber/songbird/j0251/rag_flat_Jan2019_v3/')
# Point cloud reduction
voxel_sizes = dict(sv=100, mi=120, sy=120, vc=120)
sso_id = int(max(re.findall(r'(\d+)', file.replace(a_path, '')), key=len))
kzip_path = o_path + f'{ctype}_{sso_id}.k.zip'
......@@ -26,16 +47,15 @@ def process_file(file: str, o_path: str, ctype: str, ssd: SuperSegmentationDatas
sso = ssd.get_super_segmentation_object(sso_id)
scaling = sso.scaling
# if 'DA_' in file or 'HVC_' in file or '10074977' in file:
# scaling = np.array([10, 10, 20])
# if 'HVC_53854647' in file:
# scaling = sso.scaling
a_coords, a_edges, a_labels, a_labels_raw, graph = anno_skeleton2np(file, scaling, verbose=True, convert_to_morphx=convert_to_morphx)
a_coords, a_edges, a_labels, a_labels_raw, graph = anno_skeleton2np(file, scaling, verbose=False,
convert_to_morphx=convert_to_morphx)
indices, vertices, normals = sso.mesh
indices, vertices, _ = sso.mesh
vertices = vertices.reshape((-1, 3))
# TODO: voxelization requires adaptation of indices in the case of convert_to_morphx=False
if convert_to_morphx:
vertices = voxelize_points(vertices, voxel_size=voxel_sizes['sv'])
labels = np.ones((len(vertices), 1)) * -1
indices = indices.reshape((-1, 3))
cell = HybridCloud(vertices=vertices, labels=labels, nodes=a_coords, edges=a_edges, node_labels=a_labels)
# map labels from nodes to vertices
cell.nodel2vertl()
......@@ -69,15 +89,16 @@ def process_file(file: str, o_path: str, ctype: str, ssd: SuperSegmentationDatas
label_map = [20, 21, 22]
clouds = {}
for ix, mesh in enumerate(meshes):
indices, vertices, normals = mesh
_, vertices, _ = mesh
vertices = vertices.reshape((-1, 3))
vertices = voxelize_points(vertices, voxel_size=voxel_sizes[organelles[ix]])
labels = np.ones((len(vertices), 1)) * label_map[ix]
organelle = HybridCloud(vertices=vertices, labels=labels)
organelle.set_encoding({organelles[ix]: label_map[ix]})
clouds[organelles[ix]] = organelle
# --- add myelin to main cell and merge main cell with organelles ---
# hc = map_myelin(sso, hc)
cell = map_myelin(sso, cell)
ce = CloudEnsemble(clouds, cell, no_pred=organelles)
ce.save2pkl(f'{o_path}/sso_{sso.id}.pkl')
......@@ -87,16 +108,17 @@ if __name__ == '__main__':
o_path = '/wholebrain/scratch/pschuber/compartments_j0251/hybrid_clouds_refined01/'
if not os.path.exists(o_path):
os.makedirs(o_path)
ssd = SuperSegmentationDataset(working_dir='/ssdscratch/pschuber/songbird/j0251/rag_flat_Jan2019_v2/')
files = os.listdir(a_path)
args = []
for file in files:
print(f'Processing: {file}')
if os.path.isdir(a_path + file):
kzips = glob.glob(a_path + file + '/*k.zip')
for kzip in tqdm(kzips):
print(f'Processing: {kzip}')
process_file(kzip, o_path, file[:3], ssd)
args.append([kzip, o_path, file[:3]])
else:
# set convert_to_morphx = False to only generate new colorings of kzips
process_file(a_path + file, o_path, file[:3], ssd, convert_to_morphx=True)
args.append([a_path + file, o_path, file[:3], True])
start_multiprocess_imap(_process_file, args, nb_cpus=10)
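`start_multiprocess_imap` hands each worker exactly one argument, which is why the per-file parameters are packed into lists above and unpacked again by `_process_file`. A serial fallback for debugging (a sketch, equivalent apart from the worker pool) would be:
for a in args:
    _process_file(a)  # unpacks [path, o_path, ctype(, convert_to_morphx)] into process_file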
......@@ -26,7 +26,7 @@ if __name__ == '__main__':
name = today + '_{}'.format(chunk_size) + '_{}'.format(sample_num)
argscont = ArgsContainer(save_root='/wholebrain/scratch/pschuber/compartments_j0251/models_refined01/',
train_path='/wholebrain/scratch/pschuber/compartments_j0251/hybrid_clouds_refined01/',
train_path='/wholebrain/scratch/pschuber/compartments_j0251/hybrid_clouds_refined01/train/',
sample_num=sample_num,
name=name + f'_{i}',
random_seed=i,
......@@ -40,8 +40,8 @@ if __name__ == '__main__':
batch_size=batch_size,
input_channels=1,
use_val=True,
architecture='randla_net',
val_path='/wholebrain/scratch/pschuber/compartments_j0251/hybrid_clouds_refined01/',
model='randla_net',
val_path='/wholebrain/scratch/pschuber/compartments_j0251/hybrid_clouds_refined01/test/',
val_freq=30,
features={'hc': np.array([1])},
chunk_size=chunk_size,
......
from syconn.handler import basics, training
from syconn.mp.batchjob_utils import batchjob_script
if __name__ == '__main__':
nfold = 10
params = []
cnn_script = '/wholebrain/u/pschuber/devel/SyConn/syconn/cnn/cnn_celltype_ptcnv_j0251.py'
for npoints, ctx in ([25000, 15000], ):
scale = int(ctx / 10)
save_root = f'/wholebrain/scratch/pschuber/e3_trainings_randla_celltypes_j0251/' \
f'celltype_pts{npoints}_ctx{ctx}_allGT/'
params.append([cnn_script, dict(sr=save_root, sp=npoints, cval=-1, seed=0, ctx=ctx,
scale_norm=scale)])
params = list(basics.chunkify_successive(params, 1))
batchjob_script(params, 'launch_trainer', n_cores=10, additional_flags='--time=7-0 --qos=720h --gres=gpu:1',
disable_batchjob=False,
batchjob_folder=f'/wholebrain/scratch/pschuber/batchjobs'
f'/launch_trainer_celltypes_j0251_allGT_randla/',
remove_jobfolder=False, overwrite=True, exclude_nodes=[])
......@@ -274,12 +274,15 @@ class VoxelStorageDyn(CompressedStorage):
return super().__setitem__(key, value)
def __getitem__(self, item: int):
return self.get_voxelmask_offset(item)
def get_voxelmask_offset(self, item: int, overlap: int = 0):
if self.voxel_mode:
res = []
bbs = super().__getitem__(item)
for bb in bbs: # iterate over all bounding boxes
size = bb[1] - bb[0]
off = bb[0]
size = bb[1] - bb[0] + 2 * overlap
off = bb[0] - overlap
curr_mask = self.voxeldata.load_seg(size=size, offset=off, mag=1) == item
res.append(curr_mask.swapaxes(0, 2))
return res, bbs[:, 0] # (N, 3) --> all offsets
......
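The padded loading introduced above can be checked with a small worked example (illustrative numbers, not taken from a dataset):
import numpy as np
bb = np.array([[10, 10, 5], [20, 20, 9]])  # one stored bounding box (voxel coords)
overlap = 1
size = bb[1] - bb[0] + 2 * overlap  # -> array([12, 12, 6])
off = bb[0] - overlap               # -> array([ 9,  9, 4])
# load_seg(size=size, offset=off) thus returns a cube with a 1-voxel margin around
# the original bounding box, matching the 1-voxel overlap expected by the
# voxel-mask meshing helper further down.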
# ELEKTRONN3 - Neural Network Toolkit
#
# Copyright (c) 2019 - now
# Max Planck Institute of Neurobiology, Munich, Germany
# Authors: Philipp Schubert
from syconn.cnn.TrainData import CellCloudDataJ0251
import os
import torch
import argparse
import random
import numpy as np
# Don't move this stuff, it needs to be run this early to work
import elektronn3
elektronn3.select_mpl_backend('Agg')
import morphx.processing.clouds as clouds
from torch import nn
from elektronn3.models.randla_net import RandLANetClassification
from elektronn3.training import Trainer3d, Backup, metrics
import distutils.util
# PARSE PARAMETERS #
parser = argparse.ArgumentParser(description='Train a network.')
parser.add_argument('--na', type=str, help='Experiment name',
default=None)
parser.add_argument('--sr', type=str, help='Save root', default=None)
parser.add_argument('--bs', type=int, default=10, help='Batch size')
parser.add_argument('--sp', type=int, default=50000, help='Number of sample points')
parser.add_argument('--scale_norm', type=int, default=2000, help='Scale factor for normalization')
parser.add_argument('--co', action='store_true', help='Disable CUDA')
parser.add_argument('--seed', default=0, help='Random seed', type=int)
parser.add_argument('--ctx', default=20000, help='Context size in nm', type=int)
parser.add_argument('--use_syntype', default=1, help='Use synapse type',
type=distutils.util.strtobool)
parser.add_argument('--cellshape_only', default=0, help='Use only cell surface points',
type=distutils.util.strtobool)
parser.add_argument(
'-j', '--jit', metavar='MODE', default='disabled', # TODO: does not work
choices=['disabled', 'train', 'onsave'],
help="""Options:
"disabled": Completely disable JIT tracing;
"onsave": Use regular Python model for training, but trace it on-demand for saving training state;
"train": Use traced model for training and serialize it on disk"""
)
parser.add_argument('--cval', default=None, help='Cross-validation split indicator.', type=int)
args = parser.parse_args()
# SET UP ENVIRONMENT #
random_seed = args.seed
torch.manual_seed(random_seed)
np.random.seed(random_seed)
random.seed(random_seed)
# define parameters
use_cuda = not args.co
name = args.na
batch_size = args.bs
npoints = args.sp
scale_norm = args.scale_norm
save_root = args.sr
cval = args.cval
ctx = args.ctx
use_syntype = args.use_syntype
cellshape_only = args.cellshape_only
lr = 5e-4
lr_stepsize = 100
lr_dec = 0.99
max_steps = 500000
# celltype specific
eval_nr = random_seed # repetition number
dr = 0.3
num_classes = 11
onehot = True
act = 'relu'
use_myelin = True
if name is None:
name = f'celltype_pts_randla_j0251v2_scale{scale_norm}_nb{npoints}_ctx{ctx}_{act}'
if cellshape_only:
name += '_cellshapeOnly'
else:
if not use_syntype:
name += '_noSyntype'
if use_myelin:
name += '_myelin'
if onehot:
input_channels = 4
if use_syntype:
input_channels += 1
if use_myelin:
input_channels += 1
else:
input_channels = 1
name += '_flatinp'
if cellshape_only:
input_channels = 1
if use_cuda:
device = torch.device('cuda')
else:
device = torch.device('cpu')
print(f'Running on device: {device}')
# set paths
if save_root is None:
save_root = '~/e3_trainings_convpoint_celltypes_j0251_randla/'
save_root = os.path.expanduser(save_root)
# CREATE NETWORK AND PREPARE DATA SET
# Model selection
model = RandLANetClassification(input_channels, num_classes, dropout_p=dr)
if cval is not None:
name += f'_CV{cval}'
else:
name += f'_AllGT'
name += f'_eval{eval_nr}'
model = nn.DataParallel(model)
if use_cuda:
model.to(device)
example_input = (torch.ones(batch_size, npoints, input_channels).to(device),
torch.ones(batch_size, npoints, 3).to(device))
enable_save_trace = args.jit != 'disabled'
if args.jit == 'onsave':
# Make sure that tracing works
tracedmodel = torch.jit.trace(model, example_input)
elif args.jit == 'train':
if getattr(model, 'checkpointing', False):
raise NotImplementedError(
'Traced models with checkpointing currently don\'t '
'work, so either run with --jit disabled or disable '
'checkpointing.')
tracedmodel = torch.jit.trace(model, example_input)
model = tracedmodel
# Transformations to be applied to samples before feeding them to the network
train_transform = clouds.Compose([clouds.RandomVariation((-40, 40), distr='normal'), # in nm
clouds.Center(),
clouds.Normalization(scale_norm),
clouds.RandomRotate(apply_flip=True),
clouds.ElasticTransform(res=(40, 40, 40), sigma=6),
clouds.RandomScale(distr_scale=0.1, distr='uniform')])
valid_transform = clouds.Compose([clouds.Center(), clouds.Normalization(scale_norm)])
train_ds = CellCloudDataJ0251(npoints=npoints, transform=train_transform, cv_val=cval,
cellshape_only=cellshape_only, use_syntype=use_syntype,
onehot=onehot, batch_size=batch_size, ctx_size=ctx, map_myelin=use_myelin)
# valid_ds = CellCloudDataJ0251(npoints=npoints, transform=valid_transform, train=False,
# cv_val=cval, cellshape_only=cellshape_only,
# use_syntype=use_syntype, onehot=onehot, batch_size=batch_size,
# ctx_size=ctx, map_myelin=use_myelin)
valid_ds = None
# PREPARE AND START TRAINING #
# set up optimization
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
# optimizer = torch.optim.SGD(
# model.parameters(),
# lr=lr, # Learning rate is set by the lr_sched below
# momentum=0.9,
# weight_decay=0.5e-5,
# )
# optimizer = SWA(optimizer) # Enable support for Stochastic Weight Averaging
lr_sched = torch.optim.lr_scheduler.StepLR(optimizer, lr_stepsize, lr_dec)
# extra weight for HVC and LMAN
# STN=0, DA=1, MSN=2, LMAN=3, HVC=4, GP=5, TAN=6, INT=7
criterion = torch.nn.CrossEntropyLoss() # weight=torch.Tensor([1]*num_classes))
if use_cuda:
criterion.cuda()
valid_metrics = { # mean metrics
'val_accuracy_mean': metrics.Accuracy(),
'val_precision_mean': metrics.Precision(),
'val_recall_mean': metrics.Recall(),
'val_DSC_mean': metrics.DSC(),
'val_IoU_mean': metrics.IoU(),
}
# Create trainer
# it seems pytorch 1.1 does not support batch_size=None to enable batched dataloader, instead
# using batch size 1 with custom collate_fn
trainer = Trainer3d(
model=model,
criterion=criterion,
optimizer=optimizer,
device=device,
train_dataset=train_ds,
valid_dataset=valid_ds,
batchsize=1,
num_workers=20,
valid_metrics=valid_metrics,
save_root=save_root,
enable_save_trace=enable_save_trace,
exp_name=name,
schedulers={"lr": lr_sched},
num_classes=num_classes,
# example_input=example_input,
dataloader_kwargs=dict(collate_fn=lambda x: x[0]),
nbatch_avg=10,
)
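# Note on batchsize=1 + collate_fn (see comment above): CellCloudDataJ0251 is
# constructed with `batch_size`, so each dataset item is presumably already a
# full batch of shape (batch_size, npoints, ...). The DataLoader with batch
# size 1 merely wraps that item in a one-element list, and `lambda x: x[0]`
# unwraps it again, e.g. (illustrative):
#   batch = (lambda x: x[0])([train_ds[0]])  # -> the pre-batched sample itself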
# Archiving training script, src folder, env info
bk = Backup(script_path=__file__,
save_path=trainer.save_path).archive_backup()
# Start training
trainer.run(max_steps)
......@@ -136,11 +136,11 @@ cell_objects:
thresh_mi_bbd_mapping: 25000 # bounding box diagonal in nm
# --------- CONTACT SITE AND SYNAPSE PARAMETERS
cs_filtersize: [13, 13, 7]
cs_nclosings: 7 # TODO: bind to cs_filtersize
# used for agglomerating 'syn' objects (cell supervoxel-based synapse fragments)
# into 'syn_ssv'
cs_gap_nm: 250
cs_filtersize: [13, 13, 7]
cs_nclosings: 7 # TODO: bind to cs_filtersize
# Parameters of agglomerated synapses 'syn_ssv'
# mapping parameters in 'map_objects_to_synssv'; assignment of cellular
# organelles to syn_ssv
......@@ -171,11 +171,23 @@ meshes:
mesh_min_obj_vx: 100 # adapt to size threshold
# used for mitochondria, vesicle clouds, and cell supervoxel meshing
meshing_props:
normals: False
simplification_factor: 500
max_simplification_error: 40 # in nm
# used for cs and syn_ssv
meshing_props_points:
cs:
depth: 11 # detail level (octree depth) used in open3d create_from_point_cloud_poisson; see the sketch below
vertex_size: 80 # in nm; used for mesh simplification
min_num_vert: 200 # minimum number of vertices in a single connected component mesh
syn_ssv:
depth: 11
vertex_size: 80
min_num_vert: 200
skeleton:
# If True, allow cell skeleton generation from rendering locations (inaccurate).
allow_ssv_skel_gen: True
......
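The `meshing_props_points` entries above correspond to open3d's point-cloud meshing API. A minimal sketch of that mapping, assuming open3d >= 0.9 (`surface_points_nm` and `normals` stand in for the boundary points and gradient-derived normals computed by `_gen_mesh_voxelmask` below; this is an illustration, not the actual SyConn implementation):
import numpy as np
import open3d as o3d

def poisson_mesh_sketch(surface_points_nm, normals, depth=11, vertex_size=80):
    # build a point cloud from surface points (nm) and their outward-pointing normals
    pcd = o3d.geometry.PointCloud()
    pcd.points = o3d.utility.Vector3dVector(surface_points_nm.astype(np.float64))
    pcd.normals = o3d.utility.Vector3dVector(normals.astype(np.float64))
    # 'depth' sets the octree resolution of the Poisson reconstruction
    mesh, _ = o3d.geometry.TriangleMesh.create_from_point_cloud_poisson(pcd, depth=depth)
    # 'vertex_size' (in nm) drives the vertex-clustering simplification
    mesh = mesh.simplify_vertex_clustering(voxel_size=vertex_size)
    return np.asarray(mesh.triangles), np.asarray(mesh.vertices)
A connected-component filter (e.g. via open3d's `cluster_connected_triangles`) could then drop components with fewer than `min_num_vert` vertices; that step is omitted in this sketch.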
......@@ -7,6 +7,7 @@
import itertools
import warnings
import copy
from collections import Counter
from typing import Optional, List, Tuple, Dict, Union, Iterable, TYPE_CHECKING
......@@ -16,10 +17,12 @@ from numba import jit
from plyfile import PlyData, PlyElement
from scipy import spatial, ndimage
from scipy.ndimage import zoom
from scipy.ndimage.morphology import binary_closing, binary_dilation
from scipy.ndimage.morphology import binary_closing, binary_dilation, binary_erosion
from skimage import measure
from sklearn.decomposition import PCA
from zmesh import Mesher
import open3d as o3d
from vigra.filters import gaussianGradient
from .image import apply_pca
from .. import global_params
......@@ -57,14 +60,13 @@ try:
from .in_bounding_boxC import in_bounding_box
except ImportError:
from .in_bounding_box import in_bounding_box
log_proc.error('ImportError. Could not import `in_bounding_box` from '
'`syconn.proc.in_bounding_boxC`. Fallback to numba jit.')
if TYPE_CHECKING:
from ..reps import segmentation
from ..reps import super_segmentation_object
__all__ = ['MeshObject', 'get_object_mesh', 'merge_meshes', 'triangulation',
__all__ = ['MeshObject', 'get_object_mesh', 'merge_meshes', 'triangulation', 'calc_contact_syn_mesh',
'get_random_centered_coords', 'write_mesh2kzip', 'write_meshes2kzip',
'compartmentalize_mesh', 'mesh_chunk', 'mesh_creator_sso', 'merge_meshes_incl_norm',
'mesh_area_calc', 'mesh2obj_file', 'calc_rot_matrices', 'merge_someshes', 'find_meshes']
......@@ -1184,3 +1186,105 @@ def mesh_area_calc(mesh):
"""
return mesh_surface_area(mesh[1].reshape(-1, 3),
mesh[0].reshape(-1, 3)) / 1e6
def _gen_mesh_voxelmask(mask_list: List[np.ndarray], offset_list: List[np.ndarray], scale: np.ndarray,
vertex_size: float = 80, boundary_struct: Optional[np.ndarray] = None,
depth: int = 11, compute_connected_components: bool = True,
min_vert_num: int = 200, overlap: int = 1) \
-> Union[List[np.ndarray], List[List[np.ndarray]]]:
"""
Args:
mask_list: Binary voxel masks, list of 3D cubes.
offset_list: Cube offsets (in voxels), for each cube in `mask_list`.
scale: Size of voxels in `mask_list` in nm (x, y, z).
vertex_size: In nm. Resolution used to simplify mesh.
boundary_struct: Connectivity kernel used to determine the boundary voxels.
depth: Octree depth used in open3d's Poisson surface reconstruction, see
http://www.open3d.org/docs/latest/tutorial/Advanced/surface_reconstruction.html#Poisson-surface-reconstruction :
"An important parameter of the function is depth that defines the depth of the octree
used for the surface reconstruction and hence implies the resolution of the resulting
triangle mesh. A higher depth value means a mesh with more details."
compute_connected_components: Compute connected components of the mesh. Return a list of meshes.
min_vert_num: Minimum number of vertices of the connected component meshes (only applied if
`compute_connected_components=True`).
overlap: Overlap between adjacent masks in `mask_list`.
Notes: Provide the cubes in `mask_list` with a 1-voxel overlap to guarantee that boundaries aligned with
the 3D array border are identified correctly.
Returns:
Indices, vertices, normals of the mesh. List[ind, vert, norm] if `compute_connected_components=True`.
"""
vertex_size = np.array(vertex_size)
if boundary_struct is None:
# 26-connected
boundary_struct = np.ones((3, 3, 3))
pts, norm = [], []
for m, off in zip(mask_list, offset_list):
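# boundary shell: subtracting the eroded mask leaves non-zero values only on
# voxels at the object surface; these are the points fed to the reconstruction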
bndry = m.astype(np.float32) - binary_erosion(m, boundary_struct, iterations=1)
m = m[overlap:-overlap, overlap:-overlap, overlap:-overlap]
bndry = bndry[overlap:-overlap, overlap:-overlap, overlap:-overlap]
try:
grad = gaussianGradient(m.astype(np.float32), 1) # sigma=1
except RuntimeError: # PreconditionViolation (current mask cube is smaller than kernel)
m = np.pad(m, 5)
grad = gaussianGradient(m.astype(np.float32), 1)[5:-5, 5:-5, 5:-5]
# mult. by -1 to make normals point outwards
mag = -np.linalg.norm(grad, axis=-1)
grad[mag != 0] /= mag[mag != 0][..., None]
nonzero_mask = np.nonzero(bndry)
if np.abs(mag[nonzero_mask]).min() == 0: