Skip to content

Commit 5c3c5e5

Browse files
committed
add start dimension
1 parent cf5a856 commit 5c3c5e5

File tree

1 file changed

+62
-33
lines changed

1 file changed

+62
-33
lines changed

ml4h/tensor_generators.py

Lines changed: 62 additions & 33 deletions
Original file line number | Diff line number | Diff line change
@@ -1180,37 +1180,61 @@ def df_to_datasets_from_generator(df, INPUT_NUMERIC_COLS, input_categorical_colu
11801180
# Log number of groups found
11811181
logging.info(f"Found {len(group_ids)} groups (unique {AGGREGATE_COLUMN}s)")
11821182

1183-
# Read MRNs from CSV files
1184-
train_mrns = _sample_csv_to_set(train_csv) if train_csv else set()
1185-
valid_mrns = _sample_csv_to_set(valid_csv) if valid_csv else set()
1186-
test_mrns = _sample_csv_to_set(test_csv) if test_csv else set()
1187-
11881183
# Get unique MRNs from the dataframe
11891184
unique_mrns = df_sorted['mrn'].drop_duplicates().to_numpy()
11901185
logging.info(f"Found {len(unique_mrns)} unique MRNs in dataframe")
1191-
logging.info(f"CSV files contain: {len(train_mrns)} train MRNs, {len(valid_mrns)} valid MRNs, {len(test_mrns)} test MRNs")
1192-
1193-
# Log sample MRNs for debugging
1194-
if len(unique_mrns) > 0:
1195-
logging.info(f"Sample dataframe MRNs: {list(unique_mrns[:3])}")
1196-
if len(train_mrns) > 0:
1197-
logging.info(f"Sample train CSV MRNs: {list(list(train_mrns)[:3])}")
1198-
1199-
# Split MRNs into train/val/test based on CSV membership
1200-
train_mrn_set = set()
1201-
val_mrn_set = set()
1202-
test_mrn_set = set()
1203-
1204-
for mrn in unique_mrns:
1205-
mrn_str = str(mrn)
1206-
if train_mrns and mrn_str in train_mrns:
1207-
train_mrn_set.add(mrn)
1208-
elif valid_mrns and mrn_str in valid_mrns:
1209-
val_mrn_set.add(mrn)
1210-
elif test_mrns and mrn_str in test_mrns:
1211-
test_mrn_set.add(mrn)
1212-
1213-
logging.info(f"Matched MRNs: {len(train_mrn_set)} train, {len(val_mrn_set)} valid, {len(test_mrn_set)} test")
1186+
1187+
# Check if CSV files are provided
1188+
if train_csv or valid_csv or test_csv:
1189+
# Read MRNs from CSV files
1190+
train_mrns = _sample_csv_to_set(train_csv) if train_csv else set()
1191+
valid_mrns = _sample_csv_to_set(valid_csv) if valid_csv else set()
1192+
test_mrns = _sample_csv_to_set(test_csv) if test_csv else set()
1193+
1194+
logging.info(f"CSV files contain: {len(train_mrns)} train MRNs, {len(valid_mrns)} valid MRNs, {len(test_mrns)} test MRNs")
1195+
1196+
# Log sample MRNs for debugging
1197+
if len(unique_mrns) > 0:
1198+
logging.info(f"Sample dataframe MRNs: {list(unique_mrns[:3])}")
1199+
if len(train_mrns) > 0:
1200+
logging.info(f"Sample train CSV MRNs: {list(list(train_mrns)[:3])}")
1201+
1202+
# Split MRNs into train/val/test based on CSV membership
1203+
train_mrn_set = set()
1204+
val_mrn_set = set()
1205+
test_mrn_set = set()
1206+
1207+
for mrn in unique_mrns:
1208+
mrn_str = str(mrn)
1209+
if train_mrns and mrn_str in train_mrns:
1210+
train_mrn_set.add(mrn)
1211+
elif valid_mrns and mrn_str in valid_mrns:
1212+
val_mrn_set.add(mrn)
1213+
elif test_mrns and mrn_str in test_mrns:
1214+
test_mrn_set.add(mrn)
1215+
1216+
logging.info(f"Matched MRNs: {len(train_mrn_set)} train, {len(val_mrn_set)} valid, {len(test_mrn_set)} test")
1217+
else:
1218+
# No CSV files provided - randomly split MRNs: 80% train, 10% valid, 10% test
1219+
logging.info("No CSV files provided. Randomly splitting MRNs: 80% train, 10% valid, 10% test")
1220+
1221+
from sklearn.model_selection import train_test_split
1222+
1223+
# First split: 80% train, 20% temp (for valid+test)
1224+
train_mrns_arr, temp_mrns = train_test_split(
1225+
unique_mrns, test_size=0.2, random_state=42
1226+
)
1227+
1228+
# Second split: split temp into 50% valid, 50% test (each 10% of total)
1229+
val_mrns_arr, test_mrns_arr = train_test_split(
1230+
temp_mrns, test_size=0.5, random_state=42
1231+
)
1232+
1233+
train_mrn_set = set(train_mrns_arr)
1234+
val_mrn_set = set(val_mrns_arr)
1235+
test_mrn_set = set(test_mrns_arr)
1236+
1237+
logging.info(f"Random split MRNs: {len(train_mrn_set)} train, {len(val_mrn_set)} valid, {len(test_mrn_set)} test")
12141238

12151239
# Now map group_ids to train/val/test based on their MRN
12161240
# Build a mapping from group_id to mrn
@@ -1252,11 +1276,16 @@ def df_to_datasets_from_generator(df, INPUT_NUMERIC_COLS, input_categorical_colu
12521276

12531277
# Validate that we have at least some data in train set
12541278
if train_groups == 0:
1255-
raise ValueError(
1256-
f"Training set is empty! No MRNs from train_csv matched the dataframe. "
1257-
f"Check that MRN formats match between CSV and dataframe. "
1258-
f"Dataframe has {len(unique_mrns)} unique MRNs, train_csv has {len(train_mrns)} MRNs."
1259-
)
1279+
if train_csv or valid_csv or test_csv:
1280+
raise ValueError(
1281+
f"Training set is empty! No MRNs from CSV files matched the dataframe. "
1282+
f"Check that MRN formats match between CSV and dataframe. "
1283+
f"Dataframe has {len(unique_mrns)} unique MRNs."
1284+
)
1285+
else:
1286+
raise ValueError(
1287+
f"Training set is empty! Dataframe has {len(unique_mrns)} unique MRNs but none were assigned to training."
1288+
)
12601289

12611290
Feat = len(INPUT_NUMERIC_COLS)
12621291

0 commit comments

Comments
 (0)