This repository contains the code for the paper High-Resolution Swin Transformer for Automatic Medical Image Segmentation.
- Python 3.8
- PyTorch 1.8.1
- MONAI 0.9.0rc2+2.gbbc628d9 (results in the paper submitted to Hindawi)
- MONAI 0.9.1 (new results)
- nnU-Net: `pip install nnunet`
- `pip install axial-attention==0.5.0`
- `pip install mmcv-full`
- `pip install einops`
- `pip install SimpleITK`
- `pip install tensorboardX`
- `pip install omegaconf`
- `pip install fvcore`
- Ubuntu 18.04
- CUDA 10.2
- cuDNN 8.1.1
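
A quick optional sanity check (a sketch, not part of the repository) that the installed versions match the ones pinned above:

```python
# Print the versions of the key dependencies to compare against the list above.
import torch
import monai

print("torch:", torch.__version__, "| cuda:", torch.version.cuda)
print("monai:", monai.__version__)
```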
- Modify the dataset path parameters in the `configs/dataset_path.yaml` file.
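
Once the paths are filled in, one way to confirm they resolve is to load the config with omegaconf (a listed dependency). This is a minimal sketch; the actual key names inside `dataset_path.yaml` are not assumed, so the loop simply reports every string-valued top-level entry:

```python
# Sanity-check sketch: load configs/dataset_path.yaml with omegaconf and verify
# that each string-valued entry points at an existing path.
import os
from omegaconf import OmegaConf

cfg = OmegaConf.to_container(OmegaConf.load("configs/dataset_path.yaml"), resolve=True)
for key, value in cfg.items():
    if isinstance(value, str):
        print(f"{key}: {value} -> exists: {os.path.exists(value)}")
```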
- The config file for training on the BraTS 2021 dataset is `configs/hr_trans_brats_2021_seg.yaml`.
- The following parameters in the above config file should be specified before training:
  - `MODEL.HR_TRANS.STAGE_NUM`: the number of stages in the HRSTNet.
  - `DATALOADER.TRAIN_WORKERS`: the number of workers of the PyTorch dataloader for training. Set this parameter according to your machine configuration.
  - `SOLVER.EPOCHS`: the number of training epochs.
  - `OUTPUT_DIR`: the path where the training log and model files are saved.
  - `INPUT.RAND_CROP.SAMPLES`: the number of random crop samples taken from each volume; the effective batch size equals `INPUT.RAND_CROP.SAMPLES * SOLVER.BATCH_SIZE` (see the sketch after this list).
- Use the `tools_train.py` file to train the model.
- Modify the following parameters in the `tools_train.py` file before training:
  - `--config_file`: the path to the config file used to train the model.
  - `--num-gpus`: the number of GPUs used for training.
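
For intuition about the effective batch size, the sketch below uses a MONAI random-crop transform (which transform the repository actually uses is an assumption) to show how `num_samples` patches per volume multiply `SOLVER.BATCH_SIZE`:

```python
# Illustrative sketch: MONAI's random-crop transforms return `num_samples`
# patches per input volume, and the dataloader then batches volumes, so the
# number of patches per optimizer step is SAMPLES * BATCH_SIZE.
import torch
from monai.transforms import RandSpatialCropSamplesd

samples_per_volume = 2   # INPUT.RAND_CROP.SAMPLES
batch_size = 2           # SOLVER.BATCH_SIZE

crop = RandSpatialCropSamplesd(
    keys=["image"], roi_size=(96, 96, 96),
    num_samples=samples_per_volume, random_size=False,
)
volume = {"image": torch.rand(1, 128, 128, 128)}  # one channel-first 3D volume
patches = crop(volume)                            # list of `samples_per_volume` crops
print(len(patches) * batch_size)                  # effective batch size: 4
```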
- The config file for training on the MSD Liver dataset is `configs/hr_trans_liver_seg.yaml`.
- The following parameters in the above config file should be specified before training:
  - `MODEL.HR_TRANS.STAGE_NUM`: the number of stages in the HRSTNet.
  - `DATALOADER.TRAIN_WORKERS`: the number of workers of the PyTorch dataloader for training. Set this parameter according to your machine configuration.
  - `SOLVER.EPOCHS`: the number of training epochs.
  - `OUTPUT_DIR`: the path where the training log and model files are saved.
  - `INPUT.RAND_CROP.SAMPLES`: the number of random crop samples taken from each volume; the effective batch size equals `INPUT.RAND_CROP.SAMPLES * SOLVER.BATCH_SIZE`.
- Use the `tools_train.py` file to train the model.
- Modify the following parameters in the `tools_train.py` file before training:
  - `--config_file`: the path to the config file used to train the model.
  - `--num-gpus`: the number of GPUs used for training.
- The config file for training on the Abdomen dataset (Multi-Atlas Labeling Beyond the Cranial Vault challenge) is `configs/hr_trans_abdomen_seg.yaml`.
- The following parameters in the above config file should be specified before training:
  - `MODEL.HR_TRANS.STAGE_NUM`: the number of stages in the HRSTNet.
  - `DATALOADER.TRAIN_WORKERS`: the number of workers of the PyTorch dataloader for training. Set this parameter according to your machine configuration.
  - `SOLVER.EPOCHS`: the number of training epochs.
  - `OUTPUT_DIR`: the path where the training log and model files are saved.
  - `INPUT.RAND_CROP.SAMPLES`: the number of random crop samples taken from each volume; the effective batch size equals `INPUT.RAND_CROP.SAMPLES * SOLVER.BATCH_SIZE`.
- Use the `tools_train.py` file to train the model.
- Modify the following parameters in the `tools_train.py` file before training:
  - `--config_file`: the path to the config file used to train the model.
  - `--num-gpus`: the number of GPUs used for training.
After training the models with the `tools_train.py` file, their performance is evaluated with the `tools_inference.py` file.
If running the `evaluate_*.py` scripts produces the following error:

```
ITK ERROR: ITK only supports orthonormal direction cosines. No orthonormal definition found!
```

run the `evaluate/fix_simpleitk_read_error.py` script to repair the broken file, setting the parameter `pred_file_path` to the path of the broken file and `original_img_path` to the path of the corresponding image.
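
Below is a minimal sketch of one way such a file can be repaired, assuming its voxel data is still loadable with nibabel (not in the dependency list above; SimpleITK itself refuses non-orthonormal direction cosines) and that the mask shares the original image's voxel grid. The repository's `evaluate/fix_simpleitk_read_error.py` is the authoritative fix; the file names below are hypothetical:

```python
# Load the broken mask's array with nibabel, then rewrite it with the
# origin, spacing, and direction copied from the original image.
import nibabel as nib
import numpy as np
import SimpleITK as sitk

pred_file_path = "seg_results/case_0001.nii.gz"  # hypothetical broken mask
original_img_path = "imgs/case_0001.nii.gz"      # hypothetical matching image

data = np.asarray(nib.load(pred_file_path).get_fdata(), dtype=np.uint8)
fixed = sitk.GetImageFromArray(data.transpose(2, 1, 0))   # nibabel XYZ -> sitk ZYX
fixed.CopyInformation(sitk.ReadImage(original_img_path))  # copy valid geometry
sitk.WriteImage(fixed, pred_file_path)
```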
If the model is trained with the vt_unet preprocessing method, it is evaluated as follows (taking HRSTNet as an example).
- Modify the parameters `MODEL.WEIGHTS`, `OUTPUT_DIR`, and `MODE` in `configs/hr_trans_brats_2021_seg.yaml`.
- Set the parameters `space=brats_2021_vt_unet` and `config_file=configs/hr_trans_brats_2021_seg.yaml` in the `tools_inference.py` file.
- Run the `tools_inference.py` file; the segmentation results will be generated in the folder `OUTPUT_DIR/seg_results`.
- The segmentation masks in the folder `OUTPUT_DIR/seg_results` can be visualized in the 3D Slicer software.
- Run `evaluate/evaluate_brats_vt_unet.py`, setting the parameters `inferts_path` and `path` to `OUTPUT_DIR/seg_results` and the ground-truth path, respectively.
- The evaluation results will appear in `OUTPUT_DIR/seg_results/dice_pre.txt`.
Taking the Spleen dataset from MSD as an example.
- Modify the parameters `MODEL.WEIGHTS`, `OUTPUT_DIR`, and `MODE` in `configs/hr_trans_spleen_seg.yaml`.
- Set the parameters `space=original_msd` and `config_file=configs/hr_trans_spleen_seg.yaml` in the `tools_inference.py` file.
- Run the `tools_inference.py` file; the segmentation results will be generated in the folder `OUTPUT_DIR/seg_results/`.
- The segmentation masks in the folder `OUTPUT_DIR/seg_results` can be visualized in the 3D Slicer software.
- Run `evaluate/evaluate_msd.py`, setting the parameters `pred_path`, `MSD_TYPE`, `CATEGORIES`, `MODE`, and `gt_path` to `OUTPUT_DIR/seg_results`, `Spleen`, `2`, `VALIDATE`, and the ground-truth path, respectively.
- The evaluation results will appear in `OUTPUT_DIR/seg_results/dice_pred.txt`.
- Run `evaluate/evaluate_msd.py` again with the parameter `MODE` set to `VAL` to check whether all the generated files are correct.
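
For reference, the Dice score written to `dice_pred.txt` is presumably the standard overlap measure; a minimal stand-alone sketch (the file names and the foreground label below are hypothetical, and `evaluate/evaluate_msd.py` remains the authoritative implementation):

```python
# Per-class Dice between a predicted mask and its ground truth.
import numpy as np
import SimpleITK as sitk

def dice_score(pred: np.ndarray, gt: np.ndarray, label: int) -> float:
    p, g = pred == label, gt == label
    denom = p.sum() + g.sum()
    return 2.0 * np.logical_and(p, g).sum() / denom if denom else 1.0

pred = sitk.GetArrayFromImage(sitk.ReadImage("seg_results/spleen_10.nii.gz"))
gt = sitk.GetArrayFromImage(sitk.ReadImage("labelsTr/spleen_10.nii.gz"))
print(f"Dice (label 1): {dice_score(pred, gt, label=1):.4f}")
```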
- Modify the parameters `MODEL.WEIGHTS`, `OUTPUT_DIR`, `DATASETS.TEST_TYPE`, and `MODE` in `configs/hr_trans_abdomen_seg.yaml`.
- Set the parameters `space=original_abdomen` and `config_file=configs/hr_trans_abdomen_seg.yaml` in the `tools_inference.py` file.
- Run the `tools_inference.py` file; the segmentation results will be generated in the folder `OUTPUT_DIR/seg_results/`.
- Use the `evaluate/fix_aliginment_error.py` file to modify the origin and direction of the saved CT segmentation files; otherwise the segmentation masks cannot be displayed correctly. Change the `pred_img_folder` and `gt_img_folder` parameters, e.g.:

```python
pred_img_folder = "/home/ljm/Fdisk_A/train_outputs/train_output_medical_2022_8/hrstnet/abdomen_seg_hrstnet_stages_4/seg_results/"
gt_img_folder = "/home/ljm/Fdisk_A/train_datasets/train_datasets_medical/2015_Segmentation_Cranial Vault Challenge/Abdomen/RawData/Training/img/"
```

- The segmentation masks in the folder `OUTPUT_DIR/seg_results` can be visualized in the 3D Slicer software.
- Run `evaluate/evaluate_abdomen.py`, setting the parameters `pred_path` and `gt_path` to `OUTPUT_DIR/seg_results` and `Abdomen/RawData/Training/label/`, respectively.
- The evaluation results will appear in `OUTPUT_DIR/seg_results/dice_pred.txt`.
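
A minimal sketch of the kind of repair `evaluate/fix_aliginment_error.py` is described as performing, copying origin and direction from each original CT onto the matching mask; the folder paths and the assumption that predictions and images share file names are illustrative:

```python
# Copy origin and direction from each original CT to the matching predicted
# mask so 3D Slicer displays them aligned. Adjust the file-pairing rule to the
# repository's actual naming convention.
import os
import SimpleITK as sitk

pred_img_folder = "seg_results/"                 # hypothetical path
gt_img_folder = "Abdomen/RawData/Training/img/"  # hypothetical path

for fname in sorted(os.listdir(pred_img_folder)):
    pred = sitk.ReadImage(os.path.join(pred_img_folder, fname))
    ref = sitk.ReadImage(os.path.join(gt_img_folder, fname))
    pred.SetOrigin(ref.GetOrigin())
    pred.SetDirection(ref.GetDirection())
    sitk.WriteImage(pred, os.path.join(pred_img_folder, fname))
```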
- `tools_visualize.py`: used to visualize the segmentation results.
- `tools_flops.py`: used to count the FLOPs of the models.
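
`tools_flops.py` likely relies on fvcore (a listed dependency); below is a minimal sketch of FLOP counting with `fvcore.nn.FlopCountAnalysis`, with a toy 3D convolution standing in for HRSTNet and an illustrative input shape:

```python
# FLOP-counting sketch; replace the toy Conv3d with the actual model and the
# dummy input (batch, channels, D, H, W) with the real crop size.
import torch
from fvcore.nn import FlopCountAnalysis

model = torch.nn.Conv3d(4, 8, kernel_size=3, padding=1)
dummy_input = torch.rand(1, 4, 96, 96, 96)
flops = FlopCountAnalysis(model, dummy_input)
print(f"{flops.total() / 1e9:.2f} GFLOPs")
```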
If you find this project useful for your research, please cite our paper:
```bibtex
@Article{s23073420,
  AUTHOR = {Wei, Chen and Ren, Shenghan and Guo, Kaitai and Hu, Haihong and Liang, Jimin},
  TITLE = {High-Resolution Swin Transformer for Automatic Medical Image Segmentation},
  JOURNAL = {Sensors},
  VOLUME = {23},
  YEAR = {2023},
  NUMBER = {7},
  ARTICLE-NUMBER = {3420},
  URL = {https://www.mdpi.com/1424-8220/23/7/3420},
  ISSN = {1424-8220},
  DOI = {10.3390/s23073420}
}
```

Chen Wei
email: [email protected]