Skip to content

Commit aa17a9f

Browse files
authored
Dfp pclr zoo (#594)
* ENH: Add C3PO PCLR models * ENH: edit get_model() and get_representations() to add the c3po pclr models * STYLE: Update README * FIX: Upload the correct c3po_pclr model
1 parent 440af5f commit aa17a9f

File tree

5 files changed

+37
-5
lines changed

5 files changed

+37
-5
lines changed

model_zoo/PCLR/.gitattributes

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
*.h5 filter=lfs diff=lfs merge=lfs -text

model_zoo/PCLR/README.md

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ python -i get_representations.py # test the setup worked
2222
You can get ECG representations using [get_representations.py](./get_representations.py).
2323
`get_representations.get_representations` builds `N x 320` ECG representations from `N` ECGs.
2424

25-
The model expects 10s 12-lead ECGs with a specific lead order and interpolated to be 4,096 samples long.
25+
The model expects 10s 12-lead ECGs meaured in milli-volts with a specific lead order and interpolated to be 4,096 samples long.
2626
[preprocess_ecg.py](./preprocess_ecg.py) shows how to do the pre-processing.
2727

2828
### Use git LFS to localize the model file
@@ -103,6 +103,24 @@ the model only takes lead I of the ECG as input.
103103
## Lead II PCLR
104104
[Lead II PCLR](./PCLR_lead_II.h5) is like lead I PCLR except it was trained with all ECGs sampled to 250Hz.
105105

106+
## C3PO PCLR and AUG C3PO PCLR
107+
We also provide PCLR models trained using subjects from the C3PO cohort, with and without augmentation.
108+
The model files are available via:
109+
110+
`git lfs pull --include model_zoo/PCLR/c3po_pclr.h5`
111+
112+
`git lfs pull --include model_zoo/PCLR/aug_c3po_pclr.h5`
113+
114+
You can get ECG representations using for example [get_representations.py(ecgs, model_name='c3po_pclr')](./get_representations.py).
115+
`get_representations.get_representations` builds `N x 320` ECG representations from `N` ECGs.
116+
117+
The model expects 10s 12-lead ECGs measured in milli-volts with a specific lead order and interpolated to be 2,500 samples long. Note that this interpolation is different from the standard PCLR model.
118+
[preprocess_ecg.py](./preprocess_ecg.py) shows how to do the pre-processing; when calling it remember to set `ecg_samples=2500`.
119+
120+
The code snippet above showing example inference with UKB ECGs is also appropriate for these models. Remember to:
121+
1. Load `c3po_pclr.h5` or `aug_c3po_pclr.h5` instead of `PCLR.h5`.
122+
2. Interpolate to 2500 instead of 4096.
123+
106124
## Alternative save format
107125
The newer keras saved model format is available for the 12-lead and single lead models at [PCLR](./PCLR)
108126
and [PCLR_lead_I](./PCLR_lead_I) and [PCLR_lead_II](./PCLR_lead_II).

model_zoo/PCLR/aug_c3po_pclr.h5

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
version https://git-lfs.github.com/spec/v1
2+
oid sha256:aea32327873194c38b000a1a0ba25e0b0d7ddbfcc4d68b18c34a423a8fff873d
3+
size 25688728

model_zoo/PCLR/c3po_pclr.h5

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
version https://git-lfs.github.com/spec/v1
2+
oid sha256:cef43254d129ea7741b670c868b9423cd140186ee46ba051bc1b9eea5cc7093e
3+
size 25688728

model_zoo/PCLR/get_representations.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,20 +6,27 @@
66
from preprocess_ecg import process_ecg, LEADS
77

88

9-
def get_model() -> Model:
9+
def get_model(model_name = 'pclr') -> Model:
1010
"""Get PCLR embedding model"""
11-
return load_model("./PCLR.h5")
11+
if model_name == 'pclr':
12+
return load_model("./PCLR.h5")
13+
elif model_name == 'c3po_pclr':
14+
return load_model("./c3po_pclr.h5")
15+
elif model_name == 'aug_c3po_pclr':
16+
return load_model("./aug_c3po_pclr.h5")
1217

1318

14-
def get_representations(ecgs: List[Dict[str, np.ndarray]]) -> np.ndarray:
19+
def get_representations(ecgs: List[Dict[str, np.ndarray]], model_name:str = 'pclr') -> np.ndarray:
1520
"""
1621
Uses PCLR trained model to build representations of ECGs
1722
:param ecgs: A list of dictionaries mapping lead name to lead values.
1823
The lead values should be measured in milli-volts.
1924
Each lead should represent 10s of samples.
25+
:param model_name: Specifies the model to use: either 'pclr', 'c3po_pclr' or 'aug_c3po_pclr'.
26+
Default is 'pclr'
2027
:return:
2128
"""
22-
model = get_model()
29+
model = get_model(model_name)
2330
ecgs = np.stack(list(map(process_ecg, ecgs)))
2431
return model.predict(ecgs)
2532

0 commit comments

Comments
 (0)