Spaces:
Runtime error
Runtime error
| # Copyright (c) OpenMMLab. All rights reserved. | |
| import argparse | |
| import os.path as osp | |
| import nibabel as nib | |
| import numpy as np | |
| from mmengine.utils import mkdir_or_exist | |
| from PIL import Image | |
def read_files_from_txt(txt_path):
    """Read the entries listed in a text file, one entry per line.

    Args:
        txt_path (str): Path of the text file to read.

    Returns:
        list[str]: The lines of the file with surrounding whitespace
        removed. Blank lines are skipped so that a trailing newline or an
        empty row never yields an empty entry.
    """
    with open(txt_path) as f:
        stripped = (line.strip() for line in f)
        # Drop empty entries: callers build file names from every entry,
        # and an empty one would point at a file that does not exist.
        return [entry for entry in stripped if entry]
def read_nii_file(nii_path):
    """Load a NIfTI file and return its voxel data.

    Args:
        nii_path (str): Path of the NIfTI file (e.g. ``*.nii.gz``).

    Returns:
        The array produced by ``get_fdata()`` on the loaded image.
    """
    nii_image = nib.load(nii_path)
    return nii_image.get_fdata()
def split_3d_image(img):
    """Split a 3D array into a list of 2D slices along its first axis.

    Args:
        img: Array with exactly three dimensions.

    Returns:
        list: ``img[i, :, :]`` for every index ``i`` of the first axis.
    """
    # Unpacking three values rejects inputs that are not 3-dimensional.
    num_slices, _, _ = img.shape
    return [img[index, :, :] for index in range(num_slices)]
def label_mapping(label):
    """Remap dataset labels to the 9-class setting of the TransUNet paper.

    The resulting classes are 'background', 'aorta', 'gallbladder',
    'left_kidney', 'right_kidney', 'liver', 'pancreas', 'spleen' and
    'stomach', in that order. Every other foreground class of the original
    dataset is mapped to background.

    More details could be found here: https://arxiv.org/abs/2102.04306

    Args:
        label (np.ndarray): Array of original label ids.

    Returns:
        np.ndarray: Array of the same shape and dtype holding the remapped
        ids (0 for everything that is not listed in the table below).
    """
    # (original id, remapped id) pairs; ids not listed stay 0.
    id_pairs = (
        (8, 1),
        (4, 2),
        (3, 3),
        (2, 4),
        (6, 5),
        (11, 6),
        (1, 7),
        (7, 8),
    )
    maped_label = np.zeros_like(label)
    for source_id, target_id in id_pairs:
        # Masks are computed on the untouched input, so order is irrelevant.
        maped_label[label == source_id] = target_id
    return maped_label
def pares_args():
    """Parse the command line options of the conversion script.

    Returns:
        argparse.Namespace: Holds ``dataset_path`` (``None`` when the option
        is not given) and ``save_path`` (defaults to ``'data/synapse'``).
    """
    arg_parser = argparse.ArgumentParser(
        description='Convert synapse dataset to mmsegmentation format')
    arg_parser.add_argument(
        '--dataset-path', type=str, help='synapse dataset path.')
    save_path_options = dict(
        default='data/synapse', type=str, help='save path of the dataset.')
    arg_parser.add_argument('--save-path', **save_path_options)
    return arg_parser.parse_args()
def _convert_split(dataset_path, save_path, case_ids, split):
    """Convert the 3D volumes of one split into 2D image / label slices.

    Args:
        dataset_path (str): Root of the raw dataset containing the ``img``
            and ``label`` folders.
        save_path (str): Root of the converted dataset.
        case_ids (list[str]): Case identifiers used to build the
            ``img<id>.nii.gz`` and ``label<id>.nii.gz`` file names.
        split (str): Sub-folder name under ``img_dir`` and ``ann_dir``
            that receives the slices (``'train'`` or ``'val'``).
    """
    img_dir = osp.join(save_path, 'img_dir/' + split)
    ann_dir = osp.join(save_path, 'ann_dir/' + split)

    for idx in case_ids:
        img_3d = read_nii_file(
            osp.join(dataset_path, 'img', 'img' + idx + '.nii.gz'))
        label_3d = read_nii_file(
            osp.join(dataset_path, 'label', 'label' + idx + '.nii.gz'))

        # Clip intensities to [-125, 275], then rescale linearly to [0, 255].
        img_3d = np.clip(img_3d, -125, 275)
        img_3d = (img_3d + 125) / 400
        img_3d *= 255

        # Move the last axis to the front so that slices are indexed along
        # axis 0, then mirror along the last axis. Image and label get the
        # same treatment to stay aligned.
        img_3d = np.transpose(img_3d, [2, 0, 1])
        img_3d = np.flip(img_3d, 2)
        label_3d = np.transpose(label_3d, [2, 0, 1])
        label_3d = np.flip(label_3d, 2)
        label_3d = label_mapping(label_3d)

        case_name = 'case' + idx.zfill(4)
        for c in range(img_3d.shape[0]):
            slice_name = case_name + '_slice' + str(c).zfill(3)
            img = Image.fromarray(img_3d[c]).convert('RGB')
            label = Image.fromarray(label_3d[c]).convert('L')
            img.save(osp.join(img_dir, slice_name + '.jpg'))
            label.save(osp.join(ann_dir, slice_name + '.png'))


def main():
    """Convert the synapse dataset into the mmsegmentation folder layout.

    Reads the case lists from ``train.txt`` and ``val.txt`` inside the
    dataset path, then writes every slice of every listed volume as a
    ``.jpg`` image under ``img_dir/<split>`` and a ``.png`` label under
    ``ann_dir/<split>`` of the save path.

    Raises:
        ValueError: If ``--dataset-path`` is missing or does not exist.
        FileNotFoundError: If the ``img`` or ``label`` folder is missing.
    """
    args = pares_args()
    dataset_path = args.dataset_path
    save_path = args.save_path

    # ``--dataset-path`` has no default, so it is None when omitted; check
    # that first because ``osp.exists(None)`` raises an unrelated TypeError.
    if dataset_path is None or not osp.exists(dataset_path):
        raise ValueError('The dataset path does not exist. '
                         'Please enter a correct dataset path.')
    if not osp.exists(osp.join(dataset_path, 'img')) \
            or not osp.exists(osp.join(dataset_path, 'label')):
        raise FileNotFoundError('The dataset structure is incorrect. '
                                'Please check your dataset.')

    # Characters 3..6 of every entry are kept as the case id.
    train_id = read_files_from_txt(osp.join(dataset_path, 'train.txt'))
    train_id = [idx[3:7] for idx in train_id]

    test_id = read_files_from_txt(osp.join(dataset_path, 'val.txt'))
    test_id = [idx[3:7] for idx in test_id]

    mkdir_or_exist(osp.join(save_path, 'img_dir/train'))
    mkdir_or_exist(osp.join(save_path, 'img_dir/val'))
    mkdir_or_exist(osp.join(save_path, 'ann_dir/train'))
    mkdir_or_exist(osp.join(save_path, 'ann_dir/val'))

    # It follows data preparation pipeline from here:
    # https://github.com/Beckschen/TransUNet/tree/main/datasets
    _convert_split(dataset_path, save_path, train_id, 'train')
    _convert_split(dataset_path, save_path, test_id, 'val')
# Run the conversion only when this file is executed as a script.
if __name__ == '__main__':
    main()