-
Notifications
You must be signed in to change notification settings - Fork 0
David/miRBind cnn optimization #7
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
7eb1cef
e92e058
63f0558
410cf03
2fdcc3d
094359b
cf62cce
3a018ad
c4bafe4
f9299c1
f7d0f92
05f36df
ffff74c
4f9bd9c
ebfc315
73fb66a
ad84af1
8665929
a475694
b6dddc4
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,3 @@ | ||
| .ipynb_checkpoints | ||
| .keras | ||
| *.keras |
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It is nice to add some example into README, that shown how to run this whole analysis. Either sequence of call to different scripts. Or if you have the whole pipeline, can to call the pipeline script. And what are expected inputs and outputs. |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,21 @@ | ||
| # mirBind Model optimisation pipeline | ||
|
|
||
| This repository contains scripts for training and evaluating a deep learning model based on variations and tuning of the miRBind architecture for miRNA-binding prediction using eCLIP data from Manakov | ||
|
|
||
| 1. ../encode_dataset.sh | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. there is no |
||
| Converts AGO2 eCLIP datasets from Manakov2022 into the 2D matrix format. | ||
| Before running the script, data needs to be place in this directory (same for test and leftout dataset): "miRBind_2.0/data/chimeric_datasets/Manakov2022_flat/AGO2_eCLIP_Manakov2022_train.tsv" | ||
|
|
||
| 2. hyperparam_optimization.sh | ||
| Performs hyperparameter optimization for the model. Saves the best model checkpoint, architecture description, training stats, and metrics. | ||
|
|
||
| 3. train_model.sh | ||
| Trains the model using the optimized hyperparameters until convergence. Saves model checkpoints and training results. | ||
| Requires setting the name (timestamp) for your model | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. What do you meant it requires setting the name for your model? It looks like the train_model.sh script does it itself. |
||
|
|
||
| 4. evaluate_model.sh | ||
| Evaluates the trained model on test and left-out datasets. Requires setting the name (timestamp) of your trained model. Generates performance metrics and plots. | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can we have this as input param of the script? And the previous script could print the name of the best model. |
||
| Saves results. | ||
|
|
||
| ../hyperparam_optimization_pipeline.sh | ||
| Orchestrates the (almost) entire workflow (except training until convergence with found hyperpara.) in a single execution, combining dataset encoding, hyperparameter optimization, and model evaluation. | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why is "except training until convergence with found hyperpara." step missing? |
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,28 @@ | ||
| #!/bin/bash | ||
|
|
||
| test_file_out="encoded_dataset/Manakov2022_flat/AGO2_eCLIP_Manakov2022_test" | ||
| leftout_file_out="encoded_dataset/Manakov2022_flat/AGO2_eCLIP_Manakov2022_leftout" | ||
|
Comment on lines
+3
to
+4
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. These could be input params |
||
|
|
||
| # set the model_name to how you named your run | ||
| timestamp=#TODO_SET_YOUR_TRAINED_MODEL'S_TIMESTAMP | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This could be input param as well. Or maybe even whole path to the trained model. In theory, you could use this script to evaluate any model |
||
| model_name="mirBind_${timestamp}" | ||
| best_model_path="evaluation_results/${model_name}/${model_name}_final.keras" | ||
| evaluation_out_dir="evaluation_results/${model_name}" | ||
|
|
||
| python ../../../code/machine_learning/evaluate/evaluate_model.py \ | ||
| --model-path $best_model_path \ | ||
| --dataset-test "../${test_file_out}_dataset.npy" \ | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why not set the whole name to the input data in input param? like this, you are hardcoding the naming structure here. |
||
| --labels-test "../${test_file_out}_labels.npy" \ | ||
| --batch-size 32 \ | ||
| --log-file "model_evaluation_test.log" \ | ||
| --save-plots \ | ||
| --output-dir $evaluation_out_dir | ||
|
|
||
| python ../../../code/machine_learning/evaluate/evaluate_model.py \ | ||
| --model-path $best_model_path \ | ||
| --dataset-test "../${leftout_file_out}_dataset.npy" \ | ||
| --labels-test "../${leftout_file_out}_labels.npy" \ | ||
| --batch-size 32 \ | ||
| --log-file "model_evaluation_leftout.log" \ | ||
| --save-plots \ | ||
| --output-dir $evaluation_out_dir | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,158 @@ | ||
| import argparse | ||
| import numpy as np | ||
| import logging | ||
| import optuna | ||
| import optuna.visualization as vis | ||
| from optuna.integration import TFKerasPruningCallback | ||
| from tensorflow import keras as K | ||
| from tensorflow.keras.optimizers import Adam | ||
| from tensorflow.keras.utils import Sequence | ||
| import tensorflow as tf | ||
| import random | ||
|
|
||
| from code.machine_learning.utils import set_seeds, compile_model | ||
| from code.machine_learning.data_generators import TrainDataGenerator | ||
| import sys | ||
| sys.path.append("../../../code/machine_learning/train/CNN_miRBind_2022/") | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. why is it possible to import |
||
| from miRBind_CNN_architecture import miRBind_CNN | ||
|
|
||
|
|
||
| def objective(trial, train_data_gen, val_data_gen, dataset_ratio, best_model_path, epochs): | ||
| global best_model, best_val_auc | ||
|
|
||
| K.backend.clear_session() | ||
|
|
||
| cnn_num = trial.suggest_int('cnn_layers_num', 2, 10) | ||
| kernel_size = trial.suggest_int('kernel_size', 3, 10) | ||
| pool_size = trial.suggest_int('pool_size', 1, 8) | ||
| dense_num = trial.suggest_int('dense_layers_num', 2, cnn_num) | ||
| model = miRBind_CNN(cnn_num=cnn_num, kernel_size=kernel_size, pool_size=pool_size, dense_num=dense_num).model | ||
| lr = trial.suggest_float('learning_rate', 0.00001, 0.0001) | ||
| model = compile_model(model, lr=lr) | ||
|
|
||
| model_history = model.fit( | ||
| train_data_gen, | ||
| validation_data=val_data_gen, | ||
| epochs=epochs, | ||
| class_weight={0: 1, 1: dataset_ratio}, | ||
| callbacks=[ | ||
| TFKerasPruningCallback(trial, "val_auc"), | ||
| K.callbacks.EarlyStopping(patience=3, restore_best_weights=True) | ||
| ], | ||
| ) | ||
|
|
||
| num_epochs_trained = np.argmax(model_history.history['val_auc']) | ||
| val_auc = model_history.history['val_auc'][num_epochs_trained] | ||
|
|
||
| if val_auc > best_val_auc: | ||
| best_val_auc = val_auc | ||
| best_model = model | ||
| model.save(best_model_path) | ||
| logger.info(f"New best model found and saved with Validation AUC: {val_auc}") | ||
|
|
||
| print(f"Validation AU PRC: {val_auc}") | ||
| return val_auc | ||
|
|
||
|
|
||
| def setup_logger(log_file): | ||
| logger = logging.getLogger('optuna') | ||
| logger.setLevel(logging.INFO) | ||
| file_handler = logging.FileHandler(log_file, 'w') | ||
| file_handler.setFormatter(logging.Formatter('%(asctime)s - %(message)s')) | ||
| logger.addHandler(file_handler) | ||
| return logger | ||
|
|
||
|
|
||
| def main(): | ||
| parser = argparse.ArgumentParser(description='Hyperparameter optimization for miRBind CNN model') | ||
| parser.add_argument('--dataset-train', type=str, | ||
| default='../encoded_dataset/Manakov2022_flat/AGO2_eCLIP_Manakov2022_train_dataset.npy', | ||
| help='Path to the train dataset') | ||
| parser.add_argument('--labels-train', type=str, | ||
| default='../encoded_dataset/Manakov2022_flat/AGO2_eCLIP_Manakov2022_train_labels.npy', | ||
| help='Path to the train labels') | ||
| parser.add_argument('--dataset-size', type=int, default=2516195, | ||
| help='Size of the dataset') | ||
| parser.add_argument('--dataset-ratio', type=float, default=1, | ||
| help='Dataset ratio for class weighting') | ||
| parser.add_argument('--batch-size', type=int, default=32, | ||
| help='Batch size for training') | ||
| parser.add_argument('--validation-split', type=float, default=0.1, | ||
| help='Validation split ratio') | ||
| parser.add_argument('--n-trials', type=int, default=20, | ||
| help='Number of optimization trials') | ||
| parser.add_argument('--epochs', type=int, default=5, | ||
| help='Number of max epochs per model') | ||
| parser.add_argument('--best-model', type=str, default='best_model.log', | ||
| help='Path to the model trained with optimised hyperparameters') | ||
| parser.add_argument('--log-file', type=str, default='hyperparam_optimization.log', | ||
| help='Path to the log file') | ||
| parser.add_argument('--save-plots', action='store_true', | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Should we also specify where plots will be saved? |
||
| help='Save optimization plots') | ||
| parser.add_argument('--seed', type=int, default=42, | ||
| help='Random seed for reproducibility') | ||
| args = parser.parse_args() | ||
|
|
||
| # Set seeds for reproducibility | ||
| set_seeds(args.seed) | ||
|
|
||
| global logger, best_model, best_val_auc | ||
| logger = setup_logger(args.log_file) | ||
| best_model = None | ||
| best_val_auc = 0 | ||
|
|
||
| logger.info(f"Starting optimization with seed: {args.seed}") | ||
|
|
||
| train_data_gen = TrainDataGenerator( | ||
| args.dataset_train, | ||
| args.labels_train, | ||
| dataset_size=args.dataset_size, | ||
| batch_size=args.batch_size, | ||
| validation_split=args.validation_split, | ||
| is_validation=False | ||
| ) | ||
|
|
||
| val_data_gen = TrainDataGenerator( | ||
| args.dataset_train, | ||
| args.labels_train, | ||
| dataset_size=args.dataset_size, | ||
| batch_size=args.batch_size, | ||
| validation_split=args.validation_split, | ||
| is_validation=True | ||
| ) | ||
|
|
||
| # Set seed for Optuna | ||
| optuna_sampler = optuna.samplers.TPESampler(seed=args.seed) | ||
| study = optuna.create_study( | ||
| direction='maximize', | ||
| study_name='miRBind_CNN', | ||
| sampler=optuna_sampler | ||
| ) | ||
|
|
||
| study.optimize( | ||
| lambda trial: objective(trial, train_data_gen, val_data_gen, args.dataset_ratio, args.best_model, args.epochs), | ||
| n_trials=args.n_trials | ||
| ) | ||
|
|
||
| logger.info("\n") | ||
| logger.info(f"Best hyperparameters: {study.best_params}") | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Should we save best params to some file maybe? |
||
| logger.info(f"Best value (validation AU PRC): {study.best_value}") | ||
|
|
||
| if args.save_plots: | ||
| plots = { | ||
| 'optimization_history': vis.plot_optimization_history, | ||
| 'contour': vis.plot_contour, | ||
| 'param_importances': vis.plot_param_importances, | ||
| 'slice': vis.plot_slice | ||
| } | ||
|
|
||
| for name, plot_func in plots.items(): | ||
| try: | ||
| fig = plot_func(study) | ||
| fig.write_image(f"{name}.png") | ||
| except Exception as e: | ||
| logger.error(f"Failed to save {name} plot: {str(e)}") | ||
|
|
||
|
|
||
| if __name__ == "__main__": | ||
| main() | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,24 @@ | ||
| #!/bin/bash | ||
|
|
||
|
|
||
| train_file="encoded_dataset/Manakov2022_flat/AGO2_eCLIP_Manakov2022_train" | ||
|
|
||
| train_file_size=2516195 | ||
| model_dir="mirBind_2002" | ||
| best_model_path="models/${model_dir}/best_model.keras" | ||
| evaluation_out_dir="evaluation_results/${model_dir}" | ||
|
|
||
|
|
||
| # run hyper parameter optimisation | ||
| python hyperparam_optimization.py \ | ||
| --dataset-train "../${train_file}_dataset.npy" \ | ||
| --labels-train "../${train_file}_labels.npy" \ | ||
| --dataset-size $train_file_size \ | ||
| --dataset-ratio 1 \ | ||
| --batch-size 32 \ | ||
| --validation-split 0.1 \ | ||
| --n-trials 30 \ | ||
| --best-model "$best_model_path" \ | ||
| --log-file "${evaluation_out_dir}/hyperparam_optimization.log" | ||
| --seed 42 | ||
| --epochs 8 |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,29 @@ | ||
| #!/bin/bash | ||
|
|
||
| # set the model_name to a unique name for your run | ||
| timestamp=$(date +"%Y%m%d_%H%M%S") | ||
| model_name="mirBind_${timestamp}" | ||
| out_dir="evaluation_results/${model_name}" | ||
|
|
||
| train_file="encoded_dataset/Manakov2022_flat/AGO2_eCLIP_Manakov2022_train" | ||
| train_file_size=2516195 | ||
|
|
||
| CODE="../../code/machine_learning" | ||
|
|
||
| python $CODE/train/train_model.py \ | ||
| --dataset-train "../${train_file}_dataset.npy" \ | ||
| --labels-train "../${train_file}_labels.npy" \ | ||
| --dataset-size $train_file_size \ | ||
| --cnn-num 2 \ | ||
| --kernel-size 6 \ | ||
| --pool-size 2 \ | ||
| --dropout-rate 0.3 \ | ||
| --dense-num 2 \ | ||
| --learning-rate 0.00008241877487855944 \ | ||
| --batch-size 32 \ | ||
| --epochs 100 \ | ||
| --patience 6 \ | ||
| --output-dir $out_dir \ | ||
| --model-name $model_name \ | ||
| --seed 42 | ||
|
Comment on lines
+13
to
+28
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It would be nice to load these values from some file created by hyperpar_tuning.sh. |
||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,58 @@ | ||
| #!/bin/bash | ||
|
|
||
| timestamp=$(date +"%Y%m%d_%H%M%S") | ||
| model_name="mirBind_${timestamp}" | ||
| best_model_path="models/${model_name}.keras" | ||
| evaluation_out_dir="evaluation_results/${model_name}_hyperopt" | ||
|
|
||
| train_file_in="../../data/chimeric_datasets/Manakov2022_flat/AGO2_eCLIP_Manakov2022_train.tsv" | ||
| test_file_in="../../data/chimeric_datasets/Manakov2022_flat/AGO2_eCLIP_Manakov2022_test.tsv" | ||
| leftout_file_in="../../data/chimeric_datasets/Manakov2022_flat/AGO2_eCLIP_Manakov2022_leftout.tsv" | ||
|
|
||
| train_file_out="encoded_dataset/Manakov2022_flat/AGO2_eCLIP_Manakov2022_train" | ||
| test_file_out="encoded_dataset/Manakov2022_flat/AGO2_eCLIP_Manakov2022_test" | ||
| leftout_file_out="encoded_dataset/Manakov2022_flat/AGO2_eCLIP_Manakov2022_leftout" | ||
|
Comment on lines
+8
to
+14
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It might be more flexible to have these parameters ads input arguments for this script |
||
|
|
||
| train_file_size=2516195 | ||
|
|
||
| CODE="../../code/machine_learning" | ||
|
|
||
| mkdir -p encoded_dataset/Manakov2022_flat | ||
|
|
||
| # encode datasets | ||
| python $CODE/encode/binding_2D_matrix_encoder.py --i_file $train_file_in --o_prefix $train_file_out | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is the script |
||
| python $CODE/encode/binding_2D_matrix_encoder.py --i_file $test_file_in --o_prefix $test_file_out | ||
| python $CODE/encode/binding_2D_matrix_encoder.py --i_file $leftout_file_in --o_prefix $leftout_file_out | ||
|
|
||
| # run hyper parameter optimisation | ||
| python hyperparam_optimization/hyperparam_optimization.py \ | ||
| --dataset-train "../${train_file_out}_dataset.npy" \ | ||
| --labels-train "../${train_file_out}_labels.npy" \ | ||
| --dataset-size $train_file_size \ | ||
| --dataset-ratio 1 \ | ||
| --batch-size 32 \ | ||
| --validation-split 0.1 \ | ||
| --n-trials 20 \ | ||
| --best-model $best_model_path \ | ||
| --log-file "hyperparam_optimization.log" | ||
| --seed 42 | ||
| --epochs 5 | ||
|
|
||
| # evaluate the best model | ||
| python $CODE/evaluate/evaluate_model.py \ | ||
| --model-path $best_model_path \ | ||
| --dataset-test "../${test_file_out}_dataset.npy" \ | ||
| --labels-test "../${test_file_out}_labels.npy" \ | ||
| --batch-size 32 \ | ||
| --log-file "model_evaluation_test.log" \ | ||
| --save-plots \ | ||
| --output-dir $evaluation_out_dir | ||
|
|
||
| python $CODE/evaluate/evaluate_model.py \ | ||
| --model-path $best_model_path \ | ||
| --dataset-test "../${leftout_file_out}_dataset.npy" \ | ||
| --labels-test "../${leftout_file_out}_labels.npy" \ | ||
| --batch-size 32 \ | ||
| --log-file "model_evaluation_leftout.log" \ | ||
| --save-plots \ | ||
| --output-dir $evaluation_out_dir | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This script seems not to be used. |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,21 @@ | ||
| #!/bin/bash | ||
|
|
||
|
|
||
| TEST_DATASET="../../data/chimeric_datasets/Manakov2022_flat/AGO2_eCLIP_Manakov2022_test.tsv" | ||
| LEFTOUT_DATASET="../../data/chimeric_datasets/Manakov2022_flat/AGO2_eCLIP_Manakov2022_leftout.tsv" | ||
| TRAIN_DATASET="../../data/chimeric_datasets/Manakov2022_flat/AGO2_eCLIP_Manakov2022_train.tsv" | ||
|
|
||
|
|
||
| TEST_DATASET_OUT="encoded_dataset/Manakov2022_flat/AGO2_eCLIP_Manakov2022_1_test" | ||
| LEFTOUT_DATASET_OUT="encoded_dataset/Manakov2022_flat/AGO2_eCLIP_Manakov2022_1_leftout" | ||
| TRAIN_DATASET_OUT="encoded_dataset/Manakov2022_flat/AGO2_eCLIP_Manakov2022_1_train" | ||
|
|
||
|
|
||
| CODE="../../code/machine_learning" | ||
|
|
||
| mkdir -p encoded_dataset/Manakov2022_flat | ||
|
|
||
| # encode datasets | ||
| python $CODE/encode/binding_2D_matrix_encoder.py --i_file $TEST_DATASET --o_prefix $TEST_DATASET_OUT | ||
| python $CODE/encode/binding_2D_matrix_encoder.py --i_file $LEFTOUT_DATASET --o_prefix $LEFTOUT_DATASET_OUT | ||
| python $CODE/encode/binding_2D_matrix_encoder.py --i_file $TRAIN_DATASET --o_prefix $TRAIN_DATASET_OUT |
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Better describe what are outputs, inputs and itermediate files. Also, please include example how to run this script |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,15 @@ | ||
| # miRBind CNN retraining with original parameters | ||
|
|
||
| Run | ||
| ```bash run_retraining.sh``` | ||
| to retrain the miRBind CNN as presented in the [miRBind paper](https://doi.org/10.3390/genes13122323) on Manakov 1:1 train dataset. | ||
| The training is done with the original hyperparameters used in the paper. | ||
|
|
||
| ### Dependencies | ||
|
|
||
| - python=3.8 | ||
| - tensorflow=2.13 | ||
| - matplotlib | ||
| - numpy | ||
| - pandas | ||
|
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What is the reason to have some scripts inside hyperparam_optimization and some not?