From 88b0672fa8f3531c3f8b471e858336eb51a20ecb Mon Sep 17 00:00:00 2001
From: arother <rother@neuro.mpg.de>
Date: Tue, 25 Apr 2023 17:22:16 +0200
Subject: [PATCH] Add glia cell classes to str2int/int2str converters; log synapse GT label counts in create_syn_rfc

---
 syconn/extraction/cs_processing_steps.py | 14 ++++++++++++++
 syconn/handler/prediction.py             |  9 +++++++++
 2 files changed, 23 insertions(+)

diff --git a/syconn/extraction/cs_processing_steps.py b/syconn/extraction/cs_processing_steps.py
index 6108f789..0f202f01 100755
--- a/syconn/extraction/cs_processing_steps.py
+++ b/syconn/extraction/cs_processing_steps.py
@@ -1338,17 +1338,23 @@ def create_syn_rfc(sd_syn_ssv: 'segmentation.SegmentationDataset', path2file: st
         df = pandas.read_excel(path2file, header=0, names=[
             'ixs', 'coord', 'pre', 'post', 'syn', 'doublechecked', 'triplechecked', '?', 'comments']).values
         df = df[:, :7]
+        synaptic = 0
+        non_synaptic = 0
         for ix in range(df.shape[0]):
             c_orig = df[ix, 5]
             c = df[ix, 6]
             if type(c) != float and 'yes' in c:
                 unified_comment = 'synaptic'
+                synaptic += 1
             elif type(c) != float and 'no' in c:
                 unified_comment = 'non-synaptic'
+                non_synaptic += 1
             elif 'yes' in c_orig:
                 unified_comment = 'synaptic'
+                synaptic += 1
             elif 'no' in c_orig:
                 unified_comment = 'non-synaptic'
+                non_synaptic += 1
             else:
                 log.warn(f'Did not understand GT comment "{c}". Skipping')
                 continue
@@ -1357,6 +1363,7 @@ def create_syn_rfc(sd_syn_ssv: 'segmentation.SegmentationDataset', path2file: st
 
     labels = np.array(labels)
     label_coords = np.array(label_coords)
+    log.info(f'Before filtering: {synaptic} synaptic labels and {non_synaptic} non_synaptic labels')
 
     # get deterministic order by sorting by coordinate first and then seeded shuffling
     ixs = [i[0] for i in sorted(enumerate(label_coords),
@@ -1389,8 +1396,15 @@ def create_syn_rfc(sd_syn_ssv: 'segmentation.SegmentationDataset', path2file: st
     if np.sum(mask) == 0:
         raise ValueError
     synssv_ids = synssv_ids[mask]
+    not_mapped_labels = labels[mask == 0]
+    not_mapped_syn = not_mapped_labels[not_mapped_labels == 'synaptic']
+    not_mapped_nonsyn = not_mapped_labels[not_mapped_labels == 'non-synaptic']
     labels = labels[mask]
+    mapped_syn = labels[labels == 'synaptic']
+    mapped_nonsyn = labels[labels == 'non-synaptic']
     log.info(f'Found {np.sum(mask)}/{len(mask)} samples with a distance < {max_dist_vx} vx to the target.')
+    log.info(f'Excluding {len(not_mapped_syn)} synaptic labels and {len(not_mapped_nonsyn)} non_synaptic labels after filtering')
+    log.info(f'Training with {len(mapped_syn)} synaptic labels and {len(mapped_nonsyn)} non_synaptic labels after filtering')
 
     log.info(f'Synapse features will now be generated.')
     features = []
diff --git a/syconn/handler/prediction.py b/syconn/handler/prediction.py
index caaf35b8..761081e4 100755
--- a/syconn/handler/prediction.py
+++ b/syconn/handler/prediction.py
@@ -1324,6 +1324,10 @@ def str2int_converter(comment: str, gt_type: str) -> int:
         str2int_label = dict(STN=0, DA=1, MSN=2, LMAN=3, HVC=4, TAN=5, GPe=6, GPi=7,
                              FS=8, LTS=9, NGF=10)
         return str2int_label[comment]
+    elif gt_type == 'ctgt_j0251_v3':
+        str2int_label = dict(STN=0, DA=1, MSN=2, LMAN=3, HVC=4, TAN=5, GPe=6, GPi=7,
+                             FS=8, LTS=9, NGF=10, ASTRO=11, OLIGO=12, MICRO=13, FRAG=14)
+        return str2int_label[comment]
     else:
         raise ValueError("Given groundtruth type is not valid.")
 
@@ -1400,5 +1404,10 @@ def int2str_converter(label: int, gt_type: str) -> str:
                              FS=8, LTS=9, NGF=10)
         int2str_label = {v: k for k, v in str2int_label.items()}
         return int2str_label[label]
+    elif gt_type == 'ctgt_j0251_v3':
+        str2int_label = dict(STN=0, DA=1, MSN=2, LMAN=3, HVC=4, TAN=5, GPe=6, GPi=7,
+                             FS=8, LTS=9, NGF=10, ASTRO=11, OLIGO=12, MICRO=13, FRAG=14)
+        int2str_label = {v: k for k, v in str2int_label.items()}
+        return int2str_label[label]
     else:
         raise ValueError("Given ground truth type is not valid.")
-- 
GitLab