git clone https://github.com/ximosss/CM-SSL.git
cd CM-SSL
conda create -n cmssl python=3.11
conda activate cmssl
pip install -r requirements.txt
Pretraining dataset
We use PulseDB as our pretraining dataset. For more information, please refer to the PulseDB paper and the official repository.
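As a rough illustration of the data format, the sketch below builds a synthetic synchronous ECG/PPG pair shaped like a PulseDB segment (10-second windows sampled at 125 Hz) and applies per-segment z-score normalization. The normalization choice is our assumption for illustration, not necessarily the repository's exact preprocessing.

```python
import numpy as np

FS = 125          # PulseDB sampling rate (Hz)
SEG_SECONDS = 10  # PulseDB segment length (s)
SEG_LEN = FS * SEG_SECONDS  # 1250 samples per channel

def zscore(x: np.ndarray) -> np.ndarray:
    """Per-segment z-score normalization, a common choice for biosignals."""
    return (x - x.mean()) / (x.std() + 1e-8)

# Synthetic stand-ins for a synchronous ECG/PPG pair (~72 bpm sinusoids).
rng = np.random.default_rng(0)
t = np.arange(SEG_LEN) / FS
ecg = np.sin(2 * np.pi * 1.2 * t) + 0.05 * rng.standard_normal(SEG_LEN)
ppg = np.sin(2 * np.pi * 1.2 * t - 0.7) + 0.05 * rng.standard_normal(SEG_LEN)

pair = np.stack([zscore(ecg), zscore(ppg)])  # shape (2, 1250)
print(pair.shape)
```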
Finetuning dataset
The downstream datasets we use are as follows:
PPG part
- UCI: The Cuff-Less Blood Pressure Estimation Dataset from the UCI Machine Learning Repository
- BIDMC: BIDMC PPG and Respiration Dataset
- PPG_DALIA: PPG-DaLiA contains data from 15 subjects wearing physiological and motion sensors, providing a PPG dataset for motion compensation and heart rate estimation in Daily Life Activities.
ECG part
- CSN: A large scale 12-lead electrocardiogram database for arrhythmia study
- Physionet2017: AF Classification from a Short Single Lead ECG Recording: The PhysioNet/Computing in Cardiology Challenge 2017
- PTB-XL: a large publicly available electrocardiography dataset
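The downstream datasets ship at different sampling rates (e.g. PPG-DaLiA's wrist PPG is 64 Hz, BIDMC is 125 Hz), so a common preparation step is resampling each signal to the rate used at pre-training. The helper below is a hypothetical sketch using linear interpolation; the 125 Hz target is an assumption, not taken from the repository's code.

```python
import numpy as np

def resample_linear(x: np.ndarray, fs_in: int, fs_out: int = 125) -> np.ndarray:
    """Resample a 1-D signal to fs_out via linear interpolation."""
    n_out = int(round(len(x) * fs_out / fs_in))
    t_in = np.arange(len(x)) / fs_in
    t_out = np.arange(n_out) / fs_out
    return np.interp(t_out, t_in, x)

# 10 s of a 1 Hz sinusoid at 64 Hz -> resampled to 125 Hz.
sig_64hz = np.sin(2 * np.pi * 1.0 * np.arange(640) / 64)
sig_125hz = resample_linear(sig_64hz, fs_in=64)
print(sig_125hz.shape)  # (1250,)
```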
The workflow is divided into three main steps: pre-training the tokenizer, pre-training the encoder, and fine-tuning on downstream tasks.
Pre-train the Joint Tokenizer
Run the tokenizer pre-training script. This will learn the joint vocabulary from synchronous ECG-PPG pairs and save the trained tokenizer model.
python3 -m src.app.run_jtwbio_tokenizer_pretraining --config src/configs/jtwbio_tokenizer_config.yaml
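Conceptually, a joint tokenizer maps fused ECG-PPG feature vectors to discrete codes from a shared codebook. The sketch below shows only the generic VQ-style nearest-neighbour lookup that such a tokenizer performs at inference; the codebook size and feature dimension are made up, and the repository's actual tokenizer is certainly more involved.

```python
import numpy as np

rng = np.random.default_rng(0)
codebook = rng.standard_normal((512, 64))  # 512 joint tokens, 64-d codes (illustrative)

def tokenize(features: np.ndarray) -> np.ndarray:
    """Assign each feature vector the index of its nearest codebook entry."""
    # (N, 1, D) - (1, K, D) -> (N, K) squared distances
    d = ((features[:, None, :] - codebook[None, :, :]) ** 2).sum(-1)
    return d.argmin(axis=1)

# Fused ECG+PPG features for 8 time steps -> 8 discrete token ids.
fused = rng.standard_normal((8, 64))
tokens = tokenize(fused)
print(tokens.shape, tokens.dtype)
```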
Pre-train the Unimodal Encoder
Using the trained tokenizer from the previous step, pre-train the unimodal encoder. This encoder will learn to predict the joint tokens from a single modality.
python3 -m src.app.run_jtwbio_encoder_pretraining fit --config src/configs/jtwbio_encoder_config.yaml
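Predicting joint tokens from a single modality amounts to a per-time-step classification over the codebook. As a minimal sketch of that objective (shapes and codebook size are assumptions), the loss is a standard cross-entropy between encoder logits and the tokenizer's token ids:

```python
import numpy as np

rng = np.random.default_rng(0)
K, T = 512, 8                           # codebook size, time steps (illustrative)
logits = rng.standard_normal((T, K))    # encoder predictions from one modality
targets = rng.integers(0, K, size=T)    # joint token ids from the frozen tokenizer

def cross_entropy(logits: np.ndarray, targets: np.ndarray) -> float:
    """Mean token-prediction loss over the sequence."""
    z = logits - logits.max(axis=1, keepdims=True)  # numerical stability
    log_probs = z - np.log(np.exp(z).sum(axis=1, keepdims=True))
    return float(-log_probs[np.arange(len(targets)), targets].mean())

loss = cross_entropy(logits, targets)
print(round(loss, 3))  # near log(512) ~ 6.2 for random logits
```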
Fine-tune on a Downstream Task
Fine-tune the pre-trained unimodal encoder on a specific downstream task, such as HR estimation or ECG classification. Note that the input data for this stage is unimodal.
torchrun --nproc_per_node=4 --rdzv_endpoint=localhost:29500 -m src.modules.encoder_finetuning
or use the bash script instead:
./run_finetuning.sh -m "encoder" -e "ppg_dalia_rr ppg_dalia_hr bidmc_rr bidmc_hr uci_sbp uci_dbp"