Skip to content

ML-15 Add pre/post processing ECG2AF image #592

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Empty file.
Empty file.
7 changes: 7 additions & 0 deletions model_zoo/ECG2AF/deployment/v1/processing_image/Dockerfile
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
# Image holding the ECG2AF pre-processing (prepare.py) and post-processing
# (finalize.py) scripts.
FROM python:3.9-slim
WORKDIR /app

# Install dependencies before copying the scripts so that editing a script
# does not invalidate the cached dependency layer.
COPY requirements.txt /app/
# --no-cache-dir keeps pip's download cache out of the image layer.
RUN pip install --no-cache-dir -r /app/requirements.txt

COPY prepare.py finalize.py /app/

# The entrypoint is the interpreter itself; the caller supplies the script
# name and its arguments (e.g. `prepare.py --input ... --output ...`).
ENTRYPOINT ["python"]
44 changes: 44 additions & 0 deletions model_zoo/ECG2AF/deployment/v1/processing_image/finalize.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
import argparse
import json
import numpy as np
import pandas as pd


def convert_survival_curve_to_risk_score(curve, num_intervals=25):
    """Collapse a survival curve into a single risk score.

    The score is one minus the product of the first ``num_intervals``
    entries of ``curve``.

    NOTE(review): presumably each entry is a per-interval survival
    probability emitted by the model -- confirm against the model's output
    definition.

    Args:
        curve: Sequence of numbers (anything ``np.asarray`` accepts).
        num_intervals: Positive number of leading entries to multiply
            together. Defaults to 25, the value that used to be hard-coded.
            If ``curve`` is shorter, all of its entries are used.

    Returns:
        ``1 - prod(curve[:num_intervals])`` as a NumPy scalar.

    Raises:
        IndexError: If the selected slice is empty (e.g. ``curve`` is empty).
    """
    running_product = np.cumprod(np.asarray(curve)[:num_intervals])
    # The last element of the running product is the product of the slice.
    return 1 - running_product[-1]


def finalize(input_csv, predictions_json, output_csv):
    """Append model predictions to the input table and write it out as CSV.

    Args:
        input_csv: Path to the input CSV; its ``file_id`` column is read as
            strings.
        predictions_json: Path to a JSON file holding the four model output
            arrays, keyed by output name.
        output_csv: Destination path for the CSV with the prediction columns
            added.

    Raises:
        ValueError: If the number of age predictions differs from the number
            of rows in the input CSV.
    """
    with open(predictions_json, "r") as handle:
        predictions = json.load(handle)

    table = pd.read_csv(input_csv, dtype={"file_id": str})

    age_rows = predictions["output_age_from_wide_csv_continuous"]
    af_rows = predictions["output_af_in_read_categorical"]
    sex_rows = predictions["output_sex_from_wide_categorical"]
    curve_rows = predictions["output_survival_curve_af_survival_curve"]

    n_predictions, n_rows = len(age_rows), len(table)
    if n_predictions != n_rows:
        raise ValueError(f"Mismatch: {n_predictions} predictions but {n_rows} rows in input CSV!")

    # (output column, per-sample prediction rows, index within each row);
    # the tuple order fixes the column order of the output file.
    column_specs = (
        ("output_age", age_rows, 0),
        ("output_af_0", af_rows, 0),
        ("output_af_1", af_rows, 1),
        ("output_sex_male", sex_rows, 0),
        ("output_sex_female", sex_rows, 1),
    )
    for column, rows, position in column_specs:
        table[column] = [entry[position] for entry in rows]

    table["af_risk_score"] = list(map(convert_survival_curve_to_risk_score, curve_rows))

    table.to_csv(output_csv, index=False)
    print(f"✅ Predictions written to {output_csv} ({len(table)} rows).")


if __name__ == "__main__":
    # Command-line entry point: all three paths are mandatory.
    cli = argparse.ArgumentParser()
    for flag, description in (
        ("--input", "Path to input CSV"),
        ("--output", "Path to final CSV with predictions"),
        ("--predictions", "Path to predictions JSON"),
    ):
        cli.add_argument(flag, required=True, help=description)
    options = cli.parse_args()

    finalize(options.input, options.predictions, options.output)
53 changes: 53 additions & 0 deletions model_zoo/ECG2AF/deployment/v1/processing_image/prepare.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
import argparse

import h5py
import numpy as np
import pandas as pd
import smart_open

ECG_REST_LEADS = {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is OK for now. Since these constants are already defined in ml4h/defines.py, it would be better to depend on ml4h and import them from there, although that would make the Docker image much bigger.

'strip_I': 0, 'strip_II': 1, 'strip_III': 2, 'strip_V1': 6, 'strip_V2': 7, 'strip_V3': 8,
'strip_V4': 9, 'strip_V5': 10, 'strip_V6': 11, 'strip_aVF': 5, 'strip_aVL': 4, 'strip_aVR': 3,
}
# Shape of the array ecg_as_tensor builds for each ECG; the second axis is
# indexed by the values of ECG_REST_LEADS.
ECG_SHAPE = (5000, 12)
# Top-level HDF5 group under which ecg_as_tensor reads `<lead>/instance_0`.
ECG_HD5_PATH = 'ukb_ecg_rest'


def ecg_as_tensor(ecg_file):
    """Load one ECG HDF5 file and return it as a standardized array.

    Every lead named in ``ECG_REST_LEADS`` is read from
    ``<ECG_HD5_PATH>/<lead>/instance_0`` and written into the column given by
    that lead's mapped index. The filled array is then shifted and scaled
    with a single mean and standard deviation computed over all leads
    together.

    Args:
        ecg_file: Path or URI of the HDF5 file, opened through
            ``smart_open``.

    Returns:
        Array of shape ``ECG_SHAPE`` holding the standardized signal.
    """
    signal = np.zeros(ECG_SHAPE, dtype=np.float32)
    with smart_open.open(ecg_file, 'rb') as stream, h5py.File(stream, 'r') as hd5:
        for lead_name, column in ECG_REST_LEADS.items():
            signal[:, column] = np.array(hd5[f'{ECG_HD5_PATH}/{lead_name}/instance_0'])

    center = np.mean(signal)
    # The small constant keeps the division defined for a constant signal.
    spread = np.std(signal) + 1e-7
    return (signal - center) / spread


def prepare(input_csv, output_h5):
    """Convert the ECG files listed in a CSV into one HDF5 file of tensors.

    Each row with a non-missing ``file`` value is loaded with
    ``ecg_as_tensor`` and stored as dataset ``tensors/<file_id>`` in
    ``output_h5``. File paths are opened through ``smart_open``, so which
    remote schemes work depends on the ``smart_open`` extras installed.

    Args:
        input_csv: Path to a CSV with ``file_id`` and ``file`` columns.
        output_h5: Path of the HDF5 file to create; it is opened in ``"w"``
            mode, so an existing file is replaced.
    """
    # Read the IDs as strings (as finalize.py does) so they are used verbatim
    # as dataset names: leading zeros survive, and a missing ID elsewhere in
    # the column cannot turn the others into floats such as "1.0".
    df = pd.read_csv(input_csv, dtype={"file_id": str, "file": str})
    df = df.dropna(subset=["file"])

    # The context manager closes the output file even when loading an ECG
    # fails part-way through the loop.
    with h5py.File(output_h5, "w") as h5_file:
        tensors_group = h5_file.create_group("tensors")
        for sample_id, file_path in zip(df["file_id"], df["file"]):
            print(f"Processing: sample_id={sample_id}, file_path={file_path}, type={type(file_path)}")
            tensor = ecg_as_tensor(file_path)
            tensors_group.create_dataset(str(sample_id), data=tensor)

    print(f"Processed ECG tensors saved to {output_h5}")


if __name__ == "__main__":
    # Command-line entry point: both paths are mandatory.
    cli = argparse.ArgumentParser()
    for flag, description in (
        ("--input", "Path to input CSV"),
        ("--output", "Path to output HDF5 file"),
    ):
        cli.add_argument(flag, required=True, help=description)
    options = cli.parse_args()

    prepare(options.input, options.output)
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
pandas
numpy
h5py
smart-open[gcs]