Commit cb8e57c7 authored by mdraw's avatar mdraw
Browse files

Support custom pre-inference normalization params

Don't hardcode this to (0, 255) because not all models are trained with
normalized inputs in the same range [0, 1].
parent 8789b04a
Pipeline #139029 passed with stage
in 7 minutes
......@@ -599,7 +599,9 @@ def predict_dense_to_kd(kd_path: str, target_path: str, model_path: str,
overlap_shape_tiles: Tuple[int, int, int] = (40, 40, 20),
cube_of_interest: Optional[Tuple[np.ndarray]] = None,
overwrite: bool = False,
cube_shape_kd: Optional[Tuple[int]] = None):
cube_shape_kd: Optional[Tuple[int]] = None,
traindata_mean: float = 0.,
traindata_std: float = 255.):
"""
Helper function for dense dataset prediction. Runs predictions on the whole
knossos dataset located at `kd_path`.
......@@ -648,6 +650,11 @@ def predict_dense_to_kd(kd_path: str, target_path: str, model_path: str,
coordinate in voxels in the respective magnification (see kwarg `mag`).
overwrite: Overwrite existing KDs.
cube_shape_kd: Cube shape used to store sub-volumes in KnossosDataset on the file system.
traindata_mean: Mean value for pre-inference normalization. Will be subtracted from raw data.
Choose the value that the model was trained with. Default: 0.
traindata_std: Standard deviation value for pre-inference normalization
Raw data will be divided by this value. Default: 255.
Choose the value that the model was trained with.
"""
if log is None:
......@@ -709,7 +716,7 @@ def predict_dense_to_kd(kd_path: str, target_path: str, model_path: str,
multi_params = chunkify(multi_params, global_params.config.ngpu_total)
multi_params = [(ch_ids, kd_path, target_path, model_path, overlap_shape,
overlap_shape_tiles, tile_shape, chunk_size, n_channel, target_channels,
target_kd_path_list, channel_thresholds, mag, cube_of_interest)
target_kd_path_list, channel_thresholds, mag, cube_of_interest, traindata_mean, traindata_std)
for ch_ids in multi_params]
log.info('Started dense prediction of {} in {:d} chunk(s).'.format(", ".join(target_names), len(chunk_ids)))
n_cores_per_job = global_params.config['ncores_per_node'] // global_params.config['ngpus_per_node'] if \
......@@ -747,7 +754,7 @@ def dense_predictor(args):
# TODO: clean up (e.g. redundant chunk sizes, ...)
#
chunk_ids, kd_p, target_p, model_p, overlap_shape, overlap_shape_tiles, tile_shape, chunk_size, n_channel, \
target_channels, target_kd_path_list, channel_thresholds, mag, cube_of_interest = args
target_channels, target_kd_path_list, channel_thresholds, mag, cube_of_interest, traindata_mean, traindata_std = args
# init KnossosDataset:
kd = KnossosDataset()
......@@ -768,15 +775,20 @@ def dense_predictor(args):
# init Predictor
from elektronn3.inference import Predictor
from elektronn3.data import transforms
normalize_transform = transforms.Normalize(mean=traindata_mean, std=traindata_std)
ix = 0
tile_shape = np.array(tile_shape)
while True:
try:
out_shape = (chunk_size + 2 * np.array(overlap_shape)).astype(np.int32)[::-1] # ZYX
out_shape = np.insert(out_shape, 0, n_channel) # output must equal chunk size
# TODO: float16 inference
predictor = Predictor(model_p, strict_shapes=True, tile_shape=tile_shape[::-1],
out_shape=out_shape, overlap_shape=overlap_shape_tiles[::-1],
apply_softmax=True)
apply_softmax=True, transform=normalize_transform)
predictor.model.ae = False
_ = predictor.predict(np.zeros(out_shape[1:])[None, None])
break
......@@ -803,9 +815,9 @@ def dense_predictor(args):
coords = np.array(np.array(ch.coordinates) - np.array(ol),
dtype=np.int32)
raw = kd.load_raw(size=size * mag, offset=coords * mag, mag=mag)
raw = kd.load_raw(size=size * mag, offset=coords * mag, mag=mag).astype(np.float32)
pred = dense_predicton_helper(raw.astype(np.float32) / 255., predictor,
pred = dense_predicton_helper(raw, predictor,
is_zyx=True, return_zyx=True)
# slice out the original input volume along ZYX, i.e. the last three axes
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment