|
| 1 | +# ECG2Age — Age Prediction from 12‑Lead ECG (Keras) |
| 2 | + |
| 3 | +A lightweight Keras/TensorFlow model that predicts chronological age from a standard 12‑lead ECG. |
| 4 | + |
| 5 | +## Table of Contents |
| 6 | + |
| 7 | +* [Model card](#model-card) |
| 8 | + |
| 9 | + * [Inputs](#inputs) |
| 10 | + * [Outputs](#outputs) |
| 11 | + * [Performance](#performance) |
| 12 | +* [Getting started](#getting-started) |
| 13 | + |
| 14 | + * [Environment](#environment) |
| 15 | + * [Repository layout](#repository-layout) |
| 16 | +* [Model files](#model-files) |
| 17 | +* [Load & run inference](#load--run-inference) |
| 18 | + |
| 19 | +--- |
| 20 | + |
| 21 | +## Model card |
| 22 | + |
| 23 | +### Inputs |
| 24 | + |
| 25 | +* **Modality**: 12‑lead resting ECG |
| 26 | +* **Expected shape**: `(batch, time, leads)` = `(B, 5000, 12)` by default |
| 27 | + |
| 28 | + * Sampling rate: **500 Hz** (10 seconds ⇒ 5000 samples) |
| 29 | + * If your data is `(B, 12, 5000)`, set `--channels_first` or transpose before feeding the model. |
| 30 | +* **Dtype / range**: `float32` normalized per‑lead (e.g., z‑score or min‑max). |
| 31 | + Provide your normalization in `data/processing.py`. |
| 32 | + |
| 33 | +> Update the shape/sampling rate above if your dataset differs. |
| 34 | +
|
| 35 | +### Outputs |
| 36 | + |
| 37 | +* **Target**: chronological age in **years** (scalar per record) |
| 38 | +* **Prediction head**: single linear unit with optional activation clipping (e.g., `ReLU` at 0 years) |
| 39 | +* **Loss**: MAE or Huber (configurable) |
| 40 | + |
| 41 | +### Performance |
| 42 | + |
| 43 | +Please benchmark on a **held‑out test set** (subject‑wise split). Example table with placeholders: |
| 44 | + |
| 45 | +| Metric | Test | |
| 46 | +| ----------- | ---------------: | |
| 47 | +| R² | 0.3056011687 | |
| 48 | +| Pearson r | 0.5528120555 | |
| 49 | + |
| 50 | +Add a calibration plot and error vs age plot if possible (see `notebooks/metrics.ipynb`). |
| 51 | + |
| 52 | +--- |
| 53 | + |
| 54 | +## Getting started |
| 55 | + |
| 56 | +### Environment |
| 57 | + |
| 58 | +```bash |
| 59 | +# Python ≥3.10 recommended |
| 60 | +python -m venv .venv |
| 61 | +source .venv/bin/activate |
| 62 | +pip install -U pip |
| 63 | +pip install -r requirements.txt |
| 64 | +# or |
| 65 | +pip install tensorflow==2.16.* keras==3.* numpy scipy scikit-learn matplotlib h5py |
| 66 | +``` |
| 67 | + |
| 68 | +### Repository layout |
| 69 | + |
| 70 | +``` |
| 71 | +. |
| 72 | +├── model_zoo/ECG2Age |
| 73 | +│ ├── ECG2Age.keras # Keras SavedModel (via Keras 3) |
| 74 | + ├── encoder_median.keras # Encoder part of ECG2Age model |
| 75 | + ├── decoder_instance_age.keras # Decoder part of ECG2Age model |
| 76 | + ├── merger.keras #Encoder decoder merger intermediate model |
| 77 | +
|
| 78 | +``` |
| 79 | + |
| 80 | +--- |
| 81 | + |
| 82 | +## Load & run inference |
| 83 | + |
| 84 | +### Python API |
| 85 | + |
| 86 | +```python |
| 87 | +import numpy as np |
| 88 | +import tensorflow as tf |
| 89 | +from keras import ops |
| 90 | + |
| 91 | +# Load model (Keras 3 format) |
| 92 | +model = tf.keras.models.load_model("model_zoo/ECG2Age/ECG2Age.keras", compile=False) |
| 93 | + |
| 94 | +# Example input: batch of 8 ECGs, 5000 samples, 12 leads |
| 95 | +x = np.random.randn(8, 5000, 12).astype("float32") |
| 96 | +# Optional: apply same normalization used during training |
| 97 | +# x = normalize_batch(x) |
| 98 | + |
| 99 | +pred_age = model.predict(x, verbose=0) # shape (8, 1) |
| 100 | +print(pred_age.squeeze()) |
| 101 | +``` |
0 commit comments