@@ -1180,37 +1180,61 @@ def df_to_datasets_from_generator(df, INPUT_NUMERIC_COLS, input_categorical_colu
11801180 # Log number of groups found
11811181 logging .info (f"Found { len (group_ids )} groups (unique { AGGREGATE_COLUMN } s)" )
11821182
1183- # Read MRNs from CSV files
1184- train_mrns = _sample_csv_to_set (train_csv ) if train_csv else set ()
1185- valid_mrns = _sample_csv_to_set (valid_csv ) if valid_csv else set ()
1186- test_mrns = _sample_csv_to_set (test_csv ) if test_csv else set ()
1187-
11881183 # Get unique MRNs from the dataframe
11891184 unique_mrns = df_sorted ['mrn' ].drop_duplicates ().to_numpy ()
11901185 logging .info (f"Found { len (unique_mrns )} unique MRNs in dataframe" )
1191- logging .info (f"CSV files contain: { len (train_mrns )} train MRNs, { len (valid_mrns )} valid MRNs, { len (test_mrns )} test MRNs" )
1192-
1193- # Log sample MRNs for debugging
1194- if len (unique_mrns ) > 0 :
1195- logging .info (f"Sample dataframe MRNs: { list (unique_mrns [:3 ])} " )
1196- if len (train_mrns ) > 0 :
1197- logging .info (f"Sample train CSV MRNs: { list (list (train_mrns )[:3 ])} " )
1198-
1199- # Split MRNs into train/val/test based on CSV membership
1200- train_mrn_set = set ()
1201- val_mrn_set = set ()
1202- test_mrn_set = set ()
1203-
1204- for mrn in unique_mrns :
1205- mrn_str = str (mrn )
1206- if train_mrns and mrn_str in train_mrns :
1207- train_mrn_set .add (mrn )
1208- elif valid_mrns and mrn_str in valid_mrns :
1209- val_mrn_set .add (mrn )
1210- elif test_mrns and mrn_str in test_mrns :
1211- test_mrn_set .add (mrn )
1212-
1213- logging .info (f"Matched MRNs: { len (train_mrn_set )} train, { len (val_mrn_set )} valid, { len (test_mrn_set )} test" )
1186+
1187+ # Check if CSV files are provided
1188+ if train_csv or valid_csv or test_csv :
1189+ # Read MRNs from CSV files
1190+ train_mrns = _sample_csv_to_set (train_csv ) if train_csv else set ()
1191+ valid_mrns = _sample_csv_to_set (valid_csv ) if valid_csv else set ()
1192+ test_mrns = _sample_csv_to_set (test_csv ) if test_csv else set ()
1193+
1194+ logging .info (f"CSV files contain: { len (train_mrns )} train MRNs, { len (valid_mrns )} valid MRNs, { len (test_mrns )} test MRNs" )
1195+
1196+ # Log sample MRNs for debugging
1197+ if len (unique_mrns ) > 0 :
1198+ logging .info (f"Sample dataframe MRNs: { list (unique_mrns [:3 ])} " )
1199+ if len (train_mrns ) > 0 :
1200+ logging .info (f"Sample train CSV MRNs: { list (list (train_mrns )[:3 ])} " )
1201+
1202+ # Split MRNs into train/val/test based on CSV membership
1203+ train_mrn_set = set ()
1204+ val_mrn_set = set ()
1205+ test_mrn_set = set ()
1206+
1207+ for mrn in unique_mrns :
1208+ mrn_str = str (mrn )
1209+ if train_mrns and mrn_str in train_mrns :
1210+ train_mrn_set .add (mrn )
1211+ elif valid_mrns and mrn_str in valid_mrns :
1212+ val_mrn_set .add (mrn )
1213+ elif test_mrns and mrn_str in test_mrns :
1214+ test_mrn_set .add (mrn )
1215+
1216+ logging .info (f"Matched MRNs: { len (train_mrn_set )} train, { len (val_mrn_set )} valid, { len (test_mrn_set )} test" )
1217+ else :
1218+ # No CSV files provided - randomly split MRNs: 80% train, 10% valid, 10% test
1219+ logging .info ("No CSV files provided. Randomly splitting MRNs: 80% train, 10% valid, 10% test" )
1220+
1221+ from sklearn .model_selection import train_test_split
1222+
1223+ # First split: 80% train, 20% temp (for valid+test)
1224+ train_mrns_arr , temp_mrns = train_test_split (
1225+ unique_mrns , test_size = 0.2 , random_state = 42
1226+ )
1227+
1228+ # Second split: split temp into 50% valid, 50% test (each 10% of total)
1229+ val_mrns_arr , test_mrns_arr = train_test_split (
1230+ temp_mrns , test_size = 0.5 , random_state = 42
1231+ )
1232+
1233+ train_mrn_set = set (train_mrns_arr )
1234+ val_mrn_set = set (val_mrns_arr )
1235+ test_mrn_set = set (test_mrns_arr )
1236+
1237+ logging .info (f"Random split MRNs: { len (train_mrn_set )} train, { len (val_mrn_set )} valid, { len (test_mrn_set )} test" )
12141238
12151239 # Now map group_ids to train/val/test based on their MRN
12161240 # Build a mapping from group_id to mrn
@@ -1252,11 +1276,16 @@ def df_to_datasets_from_generator(df, INPUT_NUMERIC_COLS, input_categorical_colu
12521276
12531277 # Validate that we have at least some data in train set
12541278 if train_groups == 0 :
1255- raise ValueError (
1256- f"Training set is empty! No MRNs from train_csv matched the dataframe. "
1257- f"Check that MRN formats match between CSV and dataframe. "
1258- f"Dataframe has { len (unique_mrns )} unique MRNs, train_csv has { len (train_mrns )} MRNs."
1259- )
1279+ if train_csv or valid_csv or test_csv :
1280+ raise ValueError (
1281+ f"Training set is empty! No MRNs from CSV files matched the dataframe. "
1282+ f"Check that MRN formats match between CSV and dataframe. "
1283+ f"Dataframe has { len (unique_mrns )} unique MRNs."
1284+ )
1285+ else :
1286+ raise ValueError (
1287+ f"Training set is empty! Dataframe has { len (unique_mrns )} unique MRNs but none were assigned to training."
1288+ )
12601289
12611290 Feat = len (INPUT_NUMERIC_COLS )
12621291
0 commit comments