Commit 3ab059fb authored by Philipp Schubert

#30 add new contact site detection method `detect_cs_64bit` for 64 bit

#TODO add downstream handling of contact sites with tuple identifier
parent f22067c2
Pipeline #99949 failed with stage in 2 minutes and 10 seconds
......@@ -12,7 +12,7 @@ Refactored version of SyConn for automated synaptic connectivity inference based
Current features:
- introduction of classes for handling of supervoxels (e.g. cell fragments, predicted cellular
organelles like mitochondria, vesicle clouds etc.) and agglomerated supervoxels
- prediction of sub-cellular structures, supervoxel extraction and mesh
- distributed prediction of sub-cellular structures, supervoxel extraction and mesh
generation
- (sub-) cellular compartment (spines, boutons and axon/dendrite/soma) and cell type classification with multiview- [[2]](https://www.nature.com/articles/s41467-019-10836-3) and with skeleton-based approaches [[1]](https://www.nature.com/articles/nmeth.4206)
- glia identification and separation [[2]](https://www.nature.com/articles/s41467-019-10836-3)
......
......@@ -64,9 +64,11 @@ if __name__ == "__main__":
    pred_key_appendix2 = 'celltype_CV{}/celltype_cmn_j0251v2_adam_nbviews20_longRUN_2ratios_BIG_bs40_10fold_CV{}_eval0'.format(cv, cv)
    print('Loading cv-{}-data of model {}'.format(cv, pred_key_appendix2))
    m_path = base_dir + pred_key_appendix2
    pred_key_appendix2 += '_cmn'
    m = InferenceModel(m_path, bs=80)
    for ssv_id in ssv_ids:
        ssv = ssd.get_super_segmentation_object(ssv_id)
        ssv.load_attr_dict()
        # predict
        ssv.nb_cpus = 20
        ssv._view_caching = True
......@@ -74,6 +76,7 @@ if __name__ == "__main__":
                                view_props={'use_syntype': True, 'nb_views': 20}, overwrite=False,
                                save_to_attr_dict=False, verbose=True,
                                model_props={'n_classes': nclasses, 'da_equals_tan': False})
        ssv.save_attr_dict()
        # GT
        curr_l = ssv_label_dc[ssv.id]
        gt_l.append(curr_l)
......@@ -90,7 +93,7 @@ if __name__ == "__main__":
        print(f'{pred_l[-1]}\t{gt_l[-1]}\t{ssv.id}\t{major_dec}')
        certainty.append(ssv.certainty_celltype("celltype_cnn_e3{}_probas".format(pred_key_appendix2)))
    assert set(loaded_ssv_ids) == len(ssv_label_dc)
    assert len(set(loaded_ssv_ids)) == len(ssv_label_dc)
    # # WRITE OUT COMBINED RESULTS
    pred_proba = np.array(pred_proba)
    certainty = np.array(certainty)
......
......@@ -16,6 +16,10 @@ from logging import Logger
from typing import Optional, Dict, List, Tuple, Union
from multiprocessing import Process
import numba
from numba import typed
import numpy as np
import scipy.ndimage
import tqdm
......@@ -417,7 +421,14 @@ def _contact_site_extraction_thread(args: Union[tuple, list]) \
    start = time.time()
    # contacts has size as given with `size`, because it performs valid conv.
    # -> contacts result is cropped by stencil_offset on each side
    contacts = np.asarray(detect_cs(data))
    # TODO: use new detect_cs after verification
    # contacts = np.asarray(detect_cs(data))
    contacts = np.asarray(detect_cs_64bit(data))
    # pack the partner pair (lower ID, higher ID) into a single uint64 per voxel
    res = np.zeros(contacts.shape[:3], dtype=np.uint64)
    mask = contacts[..., 0] != 0
    res[mask] = (contacts[mask][..., 0] << 32) + contacts[mask][..., 1]
    contacts = res
    cum_dt_proc += time.time() - start
    start = time.time()
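
The block above packs the two 32-bit partner IDs returned by `detect_cs_64bit` into a single uint64 key: the lower partner ID occupies the high 32 bits, the higher one the low 32 bits. A minimal round-trip sketch of this encoding (helper names are illustrative, not part of the commit):

import numpy as np

def pack_partner_ids(lo_id, hi_id):
    # mirrors `(contacts[..., 0] << 32) + contacts[..., 1]` above:
    # lower partner ID -> high 32 bits, higher partner ID -> low 32 bits
    return (np.uint64(lo_id) << np.uint64(32)) + np.uint64(hi_id)

def unpack_partner_ids(key):
    return int(key >> np.uint64(32)), int(key & np.uint64(0xFFFFFFFF))

assert unpack_partner_ids(pack_partner_ids(4, 5)) == (4, 5)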
......@@ -771,3 +782,114 @@ def detect_cs(arr: np.ndarray) -> np.ndarray:
    cs_seg = process_block_nonzero(
        edges, arr, global_params.config['cell_objects']['cs_filtersize'])
    return cs_seg


def detect_cs_64bit(arr: np.ndarray) -> np.ndarray:
    """
    Uses :func:`detect_seg_boundaries` to generate the initial boundary mask and
    :func:`detect_contact_partners` to assign each boundary voxel its contact partner pair.

    Args:
        arr: 3D segmentation array.

    Returns:
        4D contact site segmentation array (XYZC; with C=2, holding the lower and
        higher partner ID per contact voxel).
    """
    # first identify boundary voxels
    bdry = detect_seg_boundaries(arr)
    stencil = np.array(global_params.config['cell_objects']['cs_filtersize'])
    assert np.sum(stencil % 2) == 3, 'Stencil size must be odd along all three axes.'
    offset = stencil // 2
    # extract adjacent majority ID on sparse boundary voxels
    offset = np.array([(-offset[0], offset[0]), (-offset[1], offset[1]), (-offset[2], offset[2])])
    cs_seg = detect_contact_partners(arr, bdry, offset)
    return cs_seg
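
A toy usage sketch, mirroring the unit tests further below (assumes `global_params.config['cell_objects']['cs_filtersize']` is configured, e.g. as [7, 7, 7]; the cube IDs and coordinates are illustrative):

sample = np.zeros((20, 20, 20), dtype=np.uint64)
sample[4:9, 4:9, 4:9] = 4         # cell 1
sample[4:9, 10:15, 4:9] = 5       # cell 2, one background voxel apart along y
cs = detect_cs_64bit(sample)      # shape (14, 14, 14, 2) for a 7x7x7 stencil
lower, higher = cs[..., 0], cs[..., 1]  # partner IDs per contact voxel (0 = no contact)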


@numba.jit(nopython=True)
def detect_contact_partners(seg_arr: np.ndarray, edge_arr: np.ndarray, offset: np.ndarray) -> np.ndarray:
    """
    For every boundary voxel, find the most common foreign ID within `offset` and store the
    sorted partner pair. The resulting array is shrunk by the offset extent on each side
    relative to `seg_arr` ("valid convolution").

    Args:
        seg_arr: Segmentation volume (XYZ).
        edge_arr: Boundary/edge mask array (XYZ). Locations != 0 are inspected, locations == 0 are skipped.
        offset: Offset for all spatial axes. Must have shape (3, 2). E.g. [(-1, 1), (-1, 1), (-1, 1)]
            will check a 3x3x3 cube around every voxel.

    Returns:
        Contact partner array (XYZC; with C=2, holding the lower and higher partner ID).
        Spatial axes are ``2*offset`` smaller.
    """
    nx, ny, nz = seg_arr.shape[:3]
    contact_partners = np.zeros((nx+offset[0, 0]-offset[0, 1],
                                 ny+offset[1, 0]-offset[1, 1],
                                 nz+offset[2, 0]-offset[2, 1], 2
                                 ), dtype=np.uint64)
    for xx in range(-offset[0, 0], nx-offset[0, 1]):
        for yy in range(-offset[1, 0], ny-offset[1, 1]):
            for zz in range(-offset[2, 0], nz-offset[2, 1]):
                if edge_arr[xx, yy, zz] == 0:
                    continue
                center_id = seg_arr[xx, yy, zz]
                # voxel-wise counts of adjacent foreign IDs
                d = typed.Dict.empty(
                    key_type=numba.uint64,
                    value_type=numba.uint64,
                )
                # inspect cube around center voxel
                for neigh_x in range(offset[0, 0], offset[0, 1]):
                    for neigh_y in range(offset[1, 0], offset[1, 1]):
                        for neigh_z in range(offset[2, 0], offset[2, 1]):
                            neigh_id = seg_arr[xx + neigh_x, yy + neigh_y, zz + neigh_z]
                            if (neigh_id == 0) or (neigh_id == center_id):
                                continue
                            if neigh_id in d:
                                d[neigh_id] += 1
                            else:
                                d[neigh_id] = 1
                if len(d) != 0:
                    # get most common ID
                    most_comm = 0
                    most_comm_cnt = 0
                    for k, v in d.items():
                        if most_comm_cnt < v:
                            most_comm = k
                            most_comm_cnt = v
                    # store pair sorted ascending: (lower ID, higher ID)
                    partners = [most_comm, center_id] if center_id > most_comm else [center_id, most_comm]
                    contact_partners[xx+offset[0, 0], yy+offset[1, 0], zz+offset[2, 0]] = partners
    return contact_partners
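
A self-contained sketch of the two-step pipeline with an explicit offset, bypassing the config lookup in `detect_cs_64bit` (toy IDs and shapes):

seg = np.zeros((12, 12, 12), dtype=np.uint64)
seg[2:5, 2:5, 2:5] = 4
seg[2:5, 6:9, 2:5] = 5
bdry = detect_seg_boundaries(seg)
offset = np.array([(-1, 1), (-1, 1), (-1, 1)])  # 3x3x3 inspection cube
partners = detect_contact_partners(seg, bdry, offset)
assert partners.shape == (10, 10, 10, 2)  # two voxels smaller per spatial axis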


@numba.jit(nopython=True)
def detect_seg_boundaries(arr: np.ndarray) -> np.ndarray:
    """
    Identify whether IDs differ within the 6-neighborhood and return the boundary mask.
    Background (ID 0) is skipped.

    Args:
        arr: Segmentation volume (XYZ).

    Returns:
        Boundary mask (XYZ, bool).
    """
    nx, ny, nz = arr.shape[:3]
    boundary = np.zeros((nx, ny, nz), dtype=np.bool_)
    for xx in range(nx):
        for yy in range(ny):
            for zz in range(nz):
                center_id = arr[xx, yy, zz]
                # no need to flag background
                if center_id == 0:
                    continue
                for neigh_x, neigh_y, neigh_z in [(-1, 0, 0), (1, 0, 0), (0, -1, 0), (0, 1, 0),
                                                  (0, 0, -1), (0, 0, 1)]:
                    # skip out-of-bounds neighbors
                    if (xx + neigh_x < 0) or (xx + neigh_x >= nx):
                        continue
                    if (yy + neigh_y < 0) or (yy + neigh_y >= ny):
                        continue
                    if (zz + neigh_z < 0) or (zz + neigh_z >= nz):
                        continue
                    neigh_id = arr[xx + neigh_x, yy + neigh_y, zz + neigh_z]
                    boundary[xx, yy, zz] = (neigh_id != center_id) or boundary[xx, yy, zz]
    return boundary
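
For cross-checking the jitted loop, the same 6-neighborhood test can be expressed vectorized in NumPy (a hypothetical helper, not part of the commit):

def seg_boundaries_numpy(arr: np.ndarray) -> np.ndarray:
    bdry = np.zeros(arr.shape, dtype=bool)
    for axis in range(3):
        diff = np.diff(arr, axis=axis) != 0
        lo = [slice(None)] * 3
        hi = [slice(None)] * 3
        lo[axis] = slice(None, -1)
        hi[axis] = slice(1, None)
        # a differing face flags both adjacent voxels
        bdry[tuple(lo)] |= diff
        bdry[tuple(hi)] |= diff
    return bdry & (arr != 0)  # background is never flagged

np.array_equal(seg_boundaries_numpy(seg), detect_seg_boundaries(seg)) should then hold for any segmentation `seg`.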
......@@ -505,6 +505,8 @@ def _combine_and_split_syn_thread(args):
            attr_dc.push()
            mesh_dc.push()
            cur_path_id += 1
            if len(voxel_rel_paths) == cur_path_id:
                raise ValueError(f'Worker ran out of possible storage paths for storing {sd_syn_ssv.type}.')
            n_items_for_path = 0
            id_chunk_cnt = 0
            base_id = ix_from_subfold(voxel_rel_paths[cur_path_id], sd_syn.n_folders_fs)
......@@ -677,7 +679,7 @@ def _combine_and_split_cs_thread(args):
    ccs = gen_mesh_voxelmask(chain(*vxl_iter_lst), scale=scaling, **meshing_kws)
    for mesh_cc in ccs:
        abs_offset = np.min(mesh_cc[1], axis=0) // scaling
        abs_offset = np.min(mesh_cc[1].reshape((-1, 3)), axis=0) // scaling
        cs_ssv = sd_cs_ssv.get_segmentation_object(cs_ssv_id)
        if (os.path.abspath(cs_ssv.attr_dict_path)
                != os.path.abspath(base_dir + "/attr_dict.pkl")):
......@@ -691,7 +693,7 @@ def _combine_and_split_cs_thread(args):
csssv_attr_dc["mesh_bb"] = cs_ssv.mesh_bb
csssv_attr_dc["mesh_area"] = cs_ssv.mesh_area
csssv_attr_dc["bounding_box"] = cs_ssv.mesh_bb // scaling
csssv_attr_dc["rep_coord"] = mesh_cc[1][0] // scaling # take first vertex coordinate
csssv_attr_dc["rep_coord"] = mesh_cc[1].reshape((-1, 3))[0] // scaling # take first vertex coordinate
# create open3d mesh instance to compute volume
# # TODO: add this as soon open3d >= 0.11 is supported (glibc error on cluster prevents upgrade)
......@@ -723,6 +725,8 @@ def _combine_and_split_cs_thread(args):
            attr_dc.push()
            mesh_dc.push()
            cur_path_id += 1
            if len(voxel_rel_paths) == cur_path_id:
                raise ValueError(f'Worker ran out of possible storage paths for storing {sd_cs_ssv.type}.')
            n_items_for_path = 0
            id_chunk_cnt = 0
            base_id = ix_from_subfold(voxel_rel_paths[cur_path_id], sd_cs.n_folders_fs)
......
......@@ -1215,7 +1215,8 @@ def gen_mesh_voxelmask(voxel_iter: Iterator[Tuple[np.ndarray, np.ndarray]], scal
the 3D array border are identified correctly.
    Returns:
        Indices, vertices, normals of the mesh. List[ind, vert, norm] if `compute_connected_components=True`.
        Flat index/triangle, vertex and normal arrays of the mesh. List[ind, vert, norm] if
        `compute_connected_components=True`.
    """
    vertex_size = np.array(vertex_size)
    if boundary_struct is None:
......@@ -1281,7 +1282,11 @@ def gen_mesh_voxelmask(voxel_iter: Iterator[Tuple[np.ndarray, np.ndarray]], scal
        mesh = mesh_
    else:
        mesh = [mesh]
    mesh = [[np.asarray(m.triangles), np.asarray(m.vertices), np.asarray(m.vertex_normals)] for m in mesh]
    for ii in range(len(mesh)):
        m = mesh[ii]
        verts = np.asarray(m.vertices).flatten()
        # clip negative vertex coordinates to the volume boundary
        verts[verts < 0] = 0
        mesh[ii] = [np.asarray(m.triangles).flatten(), verts, np.asarray(m.vertex_normals).flatten()]
    return mesh
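
Since `gen_mesh_voxelmask` now returns flat arrays, consumers reshape the vertex array before per-vertex arithmetic, which is exactly what the `_combine_and_split_cs_thread` changes above do. A minimal consumption sketch (`scaling` as in the calling code):

ind, vert, norm = mesh_cc                        # flat arrays as returned above
vert3 = vert.reshape((-1, 3))                    # recover per-vertex XYZ
abs_offset = np.min(vert3, axis=0) // scaling    # bounding-box origin in voxel space
rep_coord = vert3[0] // scaling                  # representative coordinate: first vertex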
......
......@@ -3,7 +3,7 @@
# All rights reserved
from syconn.reps.rep_helper import find_object_properties
from syconn.extraction.cs_extraction_steps import detect_cs
from syconn.extraction.cs_extraction_steps import detect_cs, detect_seg_boundaries, detect_cs_64bit
import numpy as np
from syconn.global_params import config
from syconn.handler.basics import chunkify_weighted
......@@ -47,9 +47,10 @@ def test_find_object_properties():
"Bounding box dictionary mismatch."
def _helpertest_detect_cs(distance_between_cube, stencil, cube_size):
def _helpertest_detect_cs(distance_between_cube, stencil, cube_size, test_func=detect_cs):
    """
    Assert statement fails if `test_func` does not work properly
    (default: :func:`~syconn.extraction.cs_extraction_steps.detect_cs`).

    Args:
        distance_between_cube: Distance between cubes of two different IDs
......@@ -59,10 +60,40 @@ def _helpertest_detect_cs(distance_between_cube, stencil, cube_size):
    Returns:

    """
    sample, expected_ids_low, expected_ids_high = _gen_sample_seg(distance_between_cube, stencil, cube_size)
    edge_id_output_sample = test_func(sample)
    higher_id_array = np.asarray(edge_id_output_sample, np.uint32)  # extracts 32 bit cell ID of higher value
    lower_id_array = np.asarray(edge_id_output_sample, np.uint64) // (2 ** 32)  # extracts 32 bit cell ID of lower value
    assert np.array_equal(np.array(expected_ids_high, np.uint32), np.array(higher_id_array, np.uint32)), \
        "higher value cell ID arrays do not match"
    assert np.array_equal(np.array(expected_ids_low, np.uint32), np.array(lower_id_array, np.uint32)), \
        "lower value cell ID arrays do not match"


def _helpertest_detect_cs_64bit(distance_between_cube, stencil, cube_size):
    """
    Assert statement fails if
    :func:`~syconn.extraction.cs_extraction_steps.detect_cs_64bit` does not work properly.

    Args:
        distance_between_cube: Distance between cubes of two different IDs
        stencil: Generic stencil size
        cube_size: Generic cube size of two different IDs

    Returns:

    """
    sample, expected_ids_low, expected_ids_high = _gen_sample_seg(distance_between_cube, stencil, cube_size)
    cs = detect_cs_64bit(sample)
    assert np.array_equal(np.array(expected_ids_high, np.uint32), cs[..., 1]), \
        "higher value cell ID arrays do not match"
    assert np.array_equal(np.array(expected_ids_low, np.uint32), cs[..., 0]), \
        "lower value cell ID arrays do not match"


def _gen_sample_seg(distance_between_cube, stencil, cube_size):
    assert (np.amax(distance_between_cube) > cube_size), "Distance between cubes should be greater than cube size"
    offset = stencil // 2  # output offset adjustment due to stencil size
    a = np.amax(offset + 1)  # coordinate of topmost corner of first cube
    edge_s = np.amax(stencil + distance_between_cube + cube_size)  # data cube size
......@@ -71,26 +102,20 @@ def _helpertest_detect_cs(distance_between_cube, stencil, cube_size):
    d = distance_between_cube  # dummy variable
    sample[a:a+c, a:a+c, a:a+c] = 4  # cell ID cube 1
    sample[a+d[0]:a+d[0]+c, a+d[1]:a+d[1]+c, a+d[2]:a+d[2]+c] = 5  # cell ID cube 2
    edge_id_output_sample = detect_cs(sample)
    higher_id_array = np.asarray(edge_id_output_sample, np.uint32)  # extracts 32 bit cell ID of higher value
    lower_id_array = np.asarray(edge_id_output_sample, np.uint64) // (2 ** 32)  # extracts 32 bit cell ID of lower value
    counter = d - offset  # checks if distance between cubes is longer than stencil size
    output_offset = np.maximum(0, counter)  # adjusts output offset accordingly
    output_shape = np.array(sample.shape + np.array([1, 1, 1]) - stencil)
    o_o = output_offset  # dummy variable for output cube size
    o = offset  # dummy variable for offset due to stencil size
    output_id = np.zeros((output_shape[0], output_shape[1], output_shape[2]), dtype=np.uint32)
    output_mask = np.zeros((output_shape[0], output_shape[1], output_shape[2]), dtype=np.uint32)
    output_id[a-o[0]+o_o[0]:a+c-o[0], a-o[1]+o_o[1]:a+c-o[1], a-o[2]+o_o[2]:a+c-o[2]] = 1
    output_id[a+d[0]-o[0]:a+d[0]+c-o[0]-o_o[0], a+d[1]-o[1]:a+d[1]+c-o[1]-o_o[1], a+d[2]-o[2]:a+d[2]+c-o[2]-o_o[2]] = 1
    output_id[a-o[0]+1:a+c-o[0]-1, a-o[1]+1:a+c-o[1]-1, a-o[2]+1:a+c-o[2]-1] = 0
    output_id[a+d[0]-o[0]+1:a+d[0]+c-o[0]-1, a+d[1]-o[1]+1:a+d[1]+c-o[1]-1, a+d[2]-o[2]+1:a+d[2]+c-o[2]-1] = 0
    assert np.array_equal(np.array(5*output_id, np.uint32), np.array(higher_id_array, np.uint32)), \
        "higher value cell ID arrays do not match"
    assert np.array_equal(np.array(4*output_id, np.uint32), np.array(lower_id_array, np.uint32)), \
        "lower value cell ID arrays do not match"
    output_mask[a-o[0]+o_o[0]:a+c-o[0], a-o[1]+o_o[1]:a+c-o[1], a-o[2]+o_o[2]:a+c-o[2]] = 1
    output_mask[a+d[0]-o[0]:a+d[0]+c-o[0]-o_o[0], a+d[1]-o[1]:a+d[1]+c-o[1]-o_o[1], a+d[2]-o[2]:a+d[2]+c-o[2]-o_o[2]] = 1
    output_mask[a-o[0]+1:a+c-o[0]-1, a-o[1]+1:a+c-o[1]-1, a-o[2]+1:a+c-o[2]-1] = 0
    output_mask[a+d[0]-o[0]+1:a+d[0]+c-o[0]-1, a+d[1]-o[1]+1:a+d[1]+c-o[1]-1, a+d[2]-o[2]+1:a+d[2]+c-o[2]-1] = 0
    return sample, 4 * output_mask, 5 * output_mask
def test_detect_cs():
......@@ -102,6 +127,22 @@ def test_detect_cs():
                            np.array(config['cell_objects']['cs_filtersize'], dtype=np.int32), 5)


def test_detect_cs_64bit():
    _helpertest_detect_cs_64bit(np.array([0, 6, 0]),
                                np.array(config['cell_objects']['cs_filtersize'], dtype=np.int32), 5)
    _helpertest_detect_cs_64bit(np.array([6, 0, 0]),
                                np.array(config['cell_objects']['cs_filtersize'], dtype=np.int32), 5)
    _helpertest_detect_cs_64bit(np.array([0, 0, 6]),
                                np.array(config['cell_objects']['cs_filtersize'], dtype=np.int32), 5)


def test_boundary_gen():
    bdry = detect_seg_boundaries(np.arange(1000).reshape((10, 10, 10)))
    assert np.all(bdry)
    bdry = detect_seg_boundaries(np.zeros((10, 10, 10)))
    assert not np.all(bdry)


def test_chunk_weighted():
    sample_array = np.array([0, 1, 2, 3, 4, 5, 6, 7], np.uint64)
    weights = np.array([3, 1, 2, 7, 5, 8, 0, 8], np.uint64)
......@@ -142,8 +183,4 @@ def test_colorcode_vertices(grid_size=5, number_of_test_vertices=50):
if __name__ == '__main__':
    test_chunk_weighted()
    test_colorcode_vertices(5, 50)
    _helpertest_detect_cs(np.array([0, 6, 0]), np.array(config['cell_objects']['cs_filtersize'], dtype=np.int32), 5)
    _helpertest_detect_cs(np.array([6, 0, 0]), np.array(config['cell_objects']['cs_filtersize'], dtype=np.int32), 5)
    _helpertest_detect_cs(np.array([0, 0, 6]), np.array(config['cell_objects']['cs_filtersize'], dtype=np.int32), 5)
    test_detect_cs_64bit()