Transfer Survival Forest (TSF) is a method designed for small-sample survival analysis. It is based on Random Survival Forest (RSF) and utilizes transfer learning to improve survival prediction accuracy. This project implements TSF training, feature probability calculation, and target forest fine-tuning.
TSF/
│── rsf_models/ # Directory to store trained RSF models
│── data/ # Directory to store datasets (SEER as an example)
│ ├── SEER.csv # Pretrain data example
│── preprocessed_data/ # Directory for preprocessed CSV files
│ ├── SEER_X.csv # SEER features (generated by preprocess_data.py)
│ ├── SEER_y.csv # SEER targets (generated by preprocess_data.py)
│ ├── wch_X.csv # WCH features (generated by preprocess_data.py)
│ └── wch_y.csv # WCH targets (generated by preprocess_data.py)
│── preprocess_data.py # Preprocesses datasets and saves as CSV
│── train_source_forest.py # Trains the source forest and saves the model
│── target_forest_finetune.py # Transfers and fine-tunes the target forest
│── dp_based.py # DP-based target forest training
│── calculate_dp.py # Computes feature probabilities
│── model/
│ ├── TransferSurvivalForest.py # Core TSF implementation
│ ├── TransferTree.py # Core TSF implementation
│ └── methods.py # Preprocessing and utility functions
│── global_names.py # Global variables
│── environment.yml # Conda environment configuration
│── requirements.txt # Python dependencies
│── README.md # Project documentation
Create and activate the conda environment with all required dependencies:
conda env create -f environment.yml
conda activate tsf
Then install torch.
- The pretraining dataset and the target dataset must be aligned (i.e., they must have the same number of features and the same feature names).
- The repository provides a sample of the pretraining data (data/SEER.csv).
- Place your training cohort in the data/ folder.
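The alignment requirement above can be checked before any training is attempted. The following is a minimal sketch, not part of the repository: the helper names `read_header` and `alignment_problems` and the toy column names are hypothetical, and flagging a different column order is an extra assumption beyond what the requirement states.

```python
import csv
import io


def read_header(csv_file):
    """Return the column names from the first row of an open CSV file."""
    return next(csv.reader(csv_file))


def alignment_problems(source_cols, target_cols):
    """List human-readable mismatches between two feature headers."""
    problems = []
    missing = [c for c in source_cols if c not in target_cols]
    extra = [c for c in target_cols if c not in source_cols]
    if missing:
        problems.append(f"missing in target: {missing}")
    if extra:
        problems.append(f"not in source: {extra}")
    # Assumption: column order is also reported, although the README only
    # requires the same number of features and the same names.
    if not problems and list(source_cols) != list(target_cols):
        problems.append("same names but different column order")
    return problems


# Toy in-memory CSVs; the column names are made up for illustration.
source = io.StringIO("age,stage,grade\n61,2,1\n")
target = io.StringIO("age,grade,tumor_size\n58,2,31\n")
print(alignment_problems(read_header(source), read_header(target)))
```

An empty list means the two headers match; anything else names the columns that need to be added, dropped, or reordered.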
Before running any experiments, preprocess the datasets to generate CSV files:
python preprocess_data.py
This script:
- Loads and preprocesses both the SEER and WCH datasets
- Saves the preprocessed data as CSV files in the preprocessed_data/ directory
- Creates: SEER_X.csv, SEER_y.csv, wch_X.csv, wch_y.csv
- CSV format provides better universality and cross-platform compatibility
Make sure you have completed data preprocessing (or have prepared data in similar format) before running any of the following scripts.
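A quick way to confirm that preprocessing has been completed is to look for the four expected files. This is a small sketch under the assumption that the files live in preprocessed_data/ with exactly the names listed above; the helper `missing_preprocessed_files` is hypothetical and not part of the repository.

```python
from pathlib import Path
import tempfile

# File names taken from the list of outputs of preprocess_data.py.
EXPECTED_FILES = ["SEER_X.csv", "SEER_y.csv", "wch_X.csv", "wch_y.csv"]


def missing_preprocessed_files(directory="preprocessed_data"):
    """Return the expected CSV names that are absent from `directory`."""
    base = Path(directory)
    return [name for name in EXPECTED_FILES if not (base / name).is_file()]


# Demonstration in a temporary directory so no real data is touched.
with tempfile.TemporaryDirectory() as tmp:
    (Path(tmp) / "SEER_X.csv").write_text("age\n61\n")
    print(missing_preprocessed_files(tmp))
```

Calling `missing_preprocessed_files()` with no argument from the project root checks the default directory; an empty list means all four files are present.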
python train_source_forest.py
This script:
- Loads the preprocessed SEER dataset from CSV files.
- Trains the Random Survival Forest (RSF) (Source Forest).
- Saves the trained model in rsf_models/source_forest.pkl.
- Computes feature probability and saves it in dp.csv.
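The save step above can be pictured as a plain pickle round trip. The sketch below assumes the .pkl file is written with Python's standard `pickle` module, which the repository may or may not do; `save_model`, `load_model`, and the placeholder dictionary standing in for a trained forest are all hypothetical.

```python
import pickle
import tempfile
from pathlib import Path


def save_model(model, path):
    """Pickle `model` to `path`, creating parent directories if needed."""
    path = Path(path)
    path.parent.mkdir(parents=True, exist_ok=True)
    with path.open("wb") as fh:
        pickle.dump(model, fh)


def load_model(path):
    """Load a pickled object from `path`."""
    with Path(path).open("rb") as fh:
        return pickle.load(fh)


# Placeholder object standing in for a trained source forest.
placeholder = {"n_trees": 100, "feature_names": ["age", "stage"]}

with tempfile.TemporaryDirectory() as tmp:
    model_path = Path(tmp) / "rsf_models" / "source_forest.pkl"
    save_model(placeholder, model_path)
    restored = load_model(model_path)

print(restored == placeholder)
```

Pickle files can execute arbitrary code when loaded, so only load model files from a source you trust.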
python target_forest_finetune.py
This script:
- Loads the pretrained source_forest.pkl.
- Loads the preprocessed WCH target dataset from CSV files.
- Fine-tunes the model on the target dataset.
- Uses 10-fold cross-validation for evaluation.
- Outputs comprehensive metrics: CTD, C-index, and Integrated Brier Score.
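To make the C-index in the metric list concrete, here is a minimal pure-Python sketch of Harrell's concordance index. It is an illustration of the general metric, not the repository's evaluation code, which may compute its metrics differently (for example with a library or with a time-dependent variant). The function name and toy data are hypothetical, and tied event times are simply treated as not comparable.

```python
def concordance_index(times, events, risks):
    """Harrell's C: fraction of comparable pairs ranked correctly.

    A pair (i, j) is comparable when subject i had an observed event and
    a strictly shorter time than subject j. It is concordant when the
    model assigns i the higher risk; tied risks count as half.
    """
    concordant = 0.0
    comparable = 0
    n = len(times)
    for i in range(n):
        if not events[i]:
            continue  # a censored subject cannot be the earlier member
        for j in range(n):
            if times[i] < times[j]:
                comparable += 1
                if risks[i] > risks[j]:
                    concordant += 1.0
                elif risks[i] == risks[j]:
                    concordant += 0.5
    if comparable == 0:
        raise ValueError("no comparable pairs")
    return concordant / comparable


# Toy cohort: follow-up times, event indicators (1 = event, 0 = censored),
# and predicted risk scores (higher means earlier expected event).
times = [2, 4, 6, 8]
events = [1, 1, 0, 1]
risks = [0.9, 0.7, 0.4, 0.1]
print(concordance_index(times, events, risks))
```

A value of 1.0 means every comparable pair is ordered correctly, 0.5 corresponds to random ranking, and 0.0 means the ranking is fully reversed.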
python dp_based.py
This script:
- Loads preprocessed WCH dataset from CSV files.
- Computes and applies the Depth Probability (DP) method.
- Trains the RSF with transferred structures.
- Evaluates the model using 10-fold cross-validation.
- Outputs comprehensive metrics: CTD, C-index, and Integrated Brier Score.
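The 10-fold cross-validation used for evaluation can be sketched as an index split. This is a generic illustration with a hypothetical helper `k_fold_indices`; the repository's own splitting (for example whether it shuffles, seeds, or stratifies by event status) is not described here and may differ.

```python
import random


def k_fold_indices(n_samples, k=10, seed=0):
    """Yield (train_idx, test_idx) pairs for k-fold cross-validation."""
    if k < 2 or k > n_samples:
        raise ValueError("k must be between 2 and n_samples")
    indices = list(range(n_samples))
    random.Random(seed).shuffle(indices)  # fixed seed for repeatable folds
    # Deal the shuffled indices round-robin so fold sizes differ by at most 1.
    folds = [indices[i::k] for i in range(k)]
    for held_out in range(k):
        test_idx = folds[held_out]
        train_idx = [idx for f, fold in enumerate(folds)
                     if f != held_out for idx in fold]
        yield train_idx, test_idx


# Each sample is held out exactly once across the 10 folds.
for fold, (train_idx, test_idx) in enumerate(k_fold_indices(25, k=10), start=1):
    print(fold, len(train_idx), len(test_idx))
```

In each round the model would be fitted on `train_idx` and scored on `test_idx`, with the per-fold metrics averaged at the end.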
If you use this code in your research, please cite the following paper: [Zhao, Y., Li, C., Shu, C., Wu, Q., Li, H., Xu, C., Li, T., Wang, Z., Luo, Z., & He, Y. (2025). Tackling small sample survival analysis via transfer learning: A study of colorectal cancer prognosis. arXiv preprint arXiv:2501.12421. https://arxiv.org/abs/2501.12421]
For questions or suggestions, please reach out to: sc22yz3@leeds.ac.uk
