diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..7043de7 --- /dev/null +++ b/.gitignore @@ -0,0 +1,3 @@ +.ipynb_checkpoints +.keras +*.keras diff --git a/analysis/miRBind_CNN_retraining_optimized/hyperparam_optimization/README.md b/analysis/miRBind_CNN_retraining_optimized/hyperparam_optimization/README.md new file mode 100644 index 0000000..7e0c7a8 --- /dev/null +++ b/analysis/miRBind_CNN_retraining_optimized/hyperparam_optimization/README.md @@ -0,0 +1,21 @@ +# mirBind Model optimisation pipeline + +This repository contains scripts for training and evaluating a deep learning model based on variations and tuning of the miRBind architecture for miRNA-binding prediction using eCLIP data from Manakov + +1. ../encode_dataset.sh +Converts AGO2 eCLIP datasets from Manakov2022 into the 2D matrix format. +Before running the script, data needs to be place in this directory (same for test and leftout dataset): "miRBind_2.0/data/chimeric_datasets/Manakov2022_flat/AGO2_eCLIP_Manakov2022_train.tsv" + +2. hyperparam_optimization.sh +Performs hyperparameter optimization for the model. Saves the best model checkpoint, architecture description, training stats, and metrics. + +3. train_model.sh +Trains the model using the optimized hyperparameters until convergence. Saves model checkpoints and training results. +Requires setting the name (timestamp) for your model + +4. evaluate_model.sh +Evaluates the trained model on test and left-out datasets. Requires setting the name (timestamp) of your trained model. Generates performance metrics and plots. +Saves results. + +../hyperparam_optimization_pipeline.sh +Orchestrates the (almost) entire workflow (except training until convergence with found hyperpara.) in a single execution, combining dataset encoding, hyperparameter optimization, and model evaluation. \ No newline at end of file diff --git a/analysis/miRBind_CNN_retraining_optimized/hyperparam_optimization/evaluate_model.sh b/analysis/miRBind_CNN_retraining_optimized/hyperparam_optimization/evaluate_model.sh new file mode 100755 index 0000000..7d8438a --- /dev/null +++ b/analysis/miRBind_CNN_retraining_optimized/hyperparam_optimization/evaluate_model.sh @@ -0,0 +1,28 @@ +#!/bin/bash + +test_file_out="encoded_dataset/Manakov2022_flat/AGO2_eCLIP_Manakov2022_test" +leftout_file_out="encoded_dataset/Manakov2022_flat/AGO2_eCLIP_Manakov2022_leftout" + +# set the model_name to how you named your run +timestamp=#TODO_SET_YOUR_TRAINED_MODEL'S_TIMESTAMP +model_name="mirBind_${timestamp}" +best_model_path="evaluation_results/${model_name}/${model_name}_final.keras" +evaluation_out_dir="evaluation_results/${model_name}" + +python ../../../code/machine_learning/evaluate/evaluate_model.py \ + --model-path $best_model_path \ + --dataset-test "../${test_file_out}_dataset.npy" \ + --labels-test "../${test_file_out}_labels.npy" \ + --batch-size 32 \ + --log-file "model_evaluation_test.log" \ + --save-plots \ + --output-dir $evaluation_out_dir + +python ../../../code/machine_learning/evaluate/evaluate_model.py \ + --model-path $best_model_path \ + --dataset-test "../${leftout_file_out}_dataset.npy" \ + --labels-test "../${leftout_file_out}_labels.npy" \ + --batch-size 32 \ + --log-file "model_evaluation_leftout.log" \ + --save-plots \ + --output-dir $evaluation_out_dir \ No newline at end of file diff --git a/analysis/miRBind_CNN_retraining_optimized/hyperparam_optimization/hyperparam_optimization.py b/analysis/miRBind_CNN_retraining_optimized/hyperparam_optimization/hyperparam_optimization.py new file mode 100644 index 0000000..f6bd4bd --- /dev/null +++ b/analysis/miRBind_CNN_retraining_optimized/hyperparam_optimization/hyperparam_optimization.py @@ -0,0 +1,158 @@ +import argparse +import numpy as np +import logging +import optuna +import optuna.visualization as vis +from optuna.integration import TFKerasPruningCallback +from tensorflow import keras as K +from tensorflow.keras.optimizers import Adam +from tensorflow.keras.utils import Sequence +import tensorflow as tf +import random + +from code.machine_learning.utils import set_seeds, compile_model +from code.machine_learning.data_generators import TrainDataGenerator +import sys +sys.path.append("../../../code/machine_learning/train/CNN_miRBind_2022/") +from miRBind_CNN_architecture import miRBind_CNN + + +def objective(trial, train_data_gen, val_data_gen, dataset_ratio, best_model_path, epochs): + global best_model, best_val_auc + + K.backend.clear_session() + + cnn_num = trial.suggest_int('cnn_layers_num', 2, 10) + kernel_size = trial.suggest_int('kernel_size', 3, 10) + pool_size = trial.suggest_int('pool_size', 1, 8) + dense_num = trial.suggest_int('dense_layers_num', 2, cnn_num) + model = miRBind_CNN(cnn_num=cnn_num, kernel_size=kernel_size, pool_size=pool_size, dense_num=dense_num).model + lr = trial.suggest_float('learning_rate', 0.00001, 0.0001) + model = compile_model(model, lr=lr) + + model_history = model.fit( + train_data_gen, + validation_data=val_data_gen, + epochs=epochs, + class_weight={0: 1, 1: dataset_ratio}, + callbacks=[ + TFKerasPruningCallback(trial, "val_auc"), + K.callbacks.EarlyStopping(patience=3, restore_best_weights=True) + ], + ) + + num_epochs_trained = np.argmax(model_history.history['val_auc']) + val_auc = model_history.history['val_auc'][num_epochs_trained] + + if val_auc > best_val_auc: + best_val_auc = val_auc + best_model = model + model.save(best_model_path) + logger.info(f"New best model found and saved with Validation AUC: {val_auc}") + + print(f"Validation AU PRC: {val_auc}") + return val_auc + + +def setup_logger(log_file): + logger = logging.getLogger('optuna') + logger.setLevel(logging.INFO) + file_handler = logging.FileHandler(log_file, 'w') + file_handler.setFormatter(logging.Formatter('%(asctime)s - %(message)s')) + logger.addHandler(file_handler) + return logger + + +def main(): + parser = argparse.ArgumentParser(description='Hyperparameter optimization for miRBind CNN model') + parser.add_argument('--dataset-train', type=str, + default='../encoded_dataset/Manakov2022_flat/AGO2_eCLIP_Manakov2022_train_dataset.npy', + help='Path to the train dataset') + parser.add_argument('--labels-train', type=str, + default='../encoded_dataset/Manakov2022_flat/AGO2_eCLIP_Manakov2022_train_labels.npy', + help='Path to the train labels') + parser.add_argument('--dataset-size', type=int, default=2516195, + help='Size of the dataset') + parser.add_argument('--dataset-ratio', type=float, default=1, + help='Dataset ratio for class weighting') + parser.add_argument('--batch-size', type=int, default=32, + help='Batch size for training') + parser.add_argument('--validation-split', type=float, default=0.1, + help='Validation split ratio') + parser.add_argument('--n-trials', type=int, default=20, + help='Number of optimization trials') + parser.add_argument('--epochs', type=int, default=5, + help='Number of max epochs per model') + parser.add_argument('--best-model', type=str, default='best_model.log', + help='Path to the model trained with optimised hyperparameters') + parser.add_argument('--log-file', type=str, default='hyperparam_optimization.log', + help='Path to the log file') + parser.add_argument('--save-plots', action='store_true', + help='Save optimization plots') + parser.add_argument('--seed', type=int, default=42, + help='Random seed for reproducibility') + args = parser.parse_args() + + # Set seeds for reproducibility + set_seeds(args.seed) + + global logger, best_model, best_val_auc + logger = setup_logger(args.log_file) + best_model = None + best_val_auc = 0 + + logger.info(f"Starting optimization with seed: {args.seed}") + + train_data_gen = TrainDataGenerator( + args.dataset_train, + args.labels_train, + dataset_size=args.dataset_size, + batch_size=args.batch_size, + validation_split=args.validation_split, + is_validation=False + ) + + val_data_gen = TrainDataGenerator( + args.dataset_train, + args.labels_train, + dataset_size=args.dataset_size, + batch_size=args.batch_size, + validation_split=args.validation_split, + is_validation=True + ) + + # Set seed for Optuna + optuna_sampler = optuna.samplers.TPESampler(seed=args.seed) + study = optuna.create_study( + direction='maximize', + study_name='miRBind_CNN', + sampler=optuna_sampler + ) + + study.optimize( + lambda trial: objective(trial, train_data_gen, val_data_gen, args.dataset_ratio, args.best_model, args.epochs), + n_trials=args.n_trials + ) + + logger.info("\n") + logger.info(f"Best hyperparameters: {study.best_params}") + logger.info(f"Best value (validation AU PRC): {study.best_value}") + + if args.save_plots: + plots = { + 'optimization_history': vis.plot_optimization_history, + 'contour': vis.plot_contour, + 'param_importances': vis.plot_param_importances, + 'slice': vis.plot_slice + } + + for name, plot_func in plots.items(): + try: + fig = plot_func(study) + fig.write_image(f"{name}.png") + except Exception as e: + logger.error(f"Failed to save {name} plot: {str(e)}") + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/analysis/miRBind_CNN_retraining_optimized/hyperparam_optimization/hyperparam_optimization.sh b/analysis/miRBind_CNN_retraining_optimized/hyperparam_optimization/hyperparam_optimization.sh new file mode 100755 index 0000000..1948468 --- /dev/null +++ b/analysis/miRBind_CNN_retraining_optimized/hyperparam_optimization/hyperparam_optimization.sh @@ -0,0 +1,24 @@ +#!/bin/bash + + +train_file="encoded_dataset/Manakov2022_flat/AGO2_eCLIP_Manakov2022_train" + +train_file_size=2516195 +model_dir="mirBind_2002" +best_model_path="models/${model_dir}/best_model.keras" +evaluation_out_dir="evaluation_results/${model_dir}" + + +# run hyper parameter optimisation +python hyperparam_optimization.py \ + --dataset-train "../${train_file}_dataset.npy" \ + --labels-train "../${train_file}_labels.npy" \ + --dataset-size $train_file_size \ + --dataset-ratio 1 \ + --batch-size 32 \ + --validation-split 0.1 \ + --n-trials 30 \ + --best-model "$best_model_path" \ + --log-file "${evaluation_out_dir}/hyperparam_optimization.log" + --seed 42 + --epochs 8 \ No newline at end of file diff --git a/analysis/miRBind_CNN_retraining_optimized/hyperparam_optimization/train_model.sh b/analysis/miRBind_CNN_retraining_optimized/hyperparam_optimization/train_model.sh new file mode 100755 index 0000000..b8d7eac --- /dev/null +++ b/analysis/miRBind_CNN_retraining_optimized/hyperparam_optimization/train_model.sh @@ -0,0 +1,29 @@ +#!/bin/bash + +# set the model_name to a unique name for your run +timestamp=$(date +"%Y%m%d_%H%M%S") +model_name="mirBind_${timestamp}" +out_dir="evaluation_results/${model_name}" + +train_file="encoded_dataset/Manakov2022_flat/AGO2_eCLIP_Manakov2022_train" +train_file_size=2516195 + +CODE="../../code/machine_learning" + +python $CODE/train/train_model.py \ + --dataset-train "../${train_file}_dataset.npy" \ + --labels-train "../${train_file}_labels.npy" \ + --dataset-size $train_file_size \ + --cnn-num 2 \ + --kernel-size 6 \ + --pool-size 2 \ + --dropout-rate 0.3 \ + --dense-num 2 \ + --learning-rate 0.00008241877487855944 \ + --batch-size 32 \ + --epochs 100 \ + --patience 6 \ + --output-dir $out_dir \ + --model-name $model_name \ + --seed 42 + \ No newline at end of file diff --git a/analysis/miRBind_CNN_retraining_optimized/hyperparam_optimization_pipeline.sh b/analysis/miRBind_CNN_retraining_optimized/hyperparam_optimization_pipeline.sh new file mode 100644 index 0000000..0320f60 --- /dev/null +++ b/analysis/miRBind_CNN_retraining_optimized/hyperparam_optimization_pipeline.sh @@ -0,0 +1,58 @@ +#!/bin/bash + +timestamp=$(date +"%Y%m%d_%H%M%S") +model_name="mirBind_${timestamp}" +best_model_path="models/${model_name}.keras" +evaluation_out_dir="evaluation_results/${model_name}_hyperopt" + +train_file_in="../../data/chimeric_datasets/Manakov2022_flat/AGO2_eCLIP_Manakov2022_train.tsv" +test_file_in="../../data/chimeric_datasets/Manakov2022_flat/AGO2_eCLIP_Manakov2022_test.tsv" +leftout_file_in="../../data/chimeric_datasets/Manakov2022_flat/AGO2_eCLIP_Manakov2022_leftout.tsv" + +train_file_out="encoded_dataset/Manakov2022_flat/AGO2_eCLIP_Manakov2022_train" +test_file_out="encoded_dataset/Manakov2022_flat/AGO2_eCLIP_Manakov2022_test" +leftout_file_out="encoded_dataset/Manakov2022_flat/AGO2_eCLIP_Manakov2022_leftout" + +train_file_size=2516195 + +CODE="../../code/machine_learning" + +mkdir -p encoded_dataset/Manakov2022_flat + +# encode datasets +python $CODE/encode/binding_2D_matrix_encoder.py --i_file $train_file_in --o_prefix $train_file_out +python $CODE/encode/binding_2D_matrix_encoder.py --i_file $test_file_in --o_prefix $test_file_out +python $CODE/encode/binding_2D_matrix_encoder.py --i_file $leftout_file_in --o_prefix $leftout_file_out + +# run hyper parameter optimisation +python hyperparam_optimization/hyperparam_optimization.py \ + --dataset-train "../${train_file_out}_dataset.npy" \ + --labels-train "../${train_file_out}_labels.npy" \ + --dataset-size $train_file_size \ + --dataset-ratio 1 \ + --batch-size 32 \ + --validation-split 0.1 \ + --n-trials 20 \ + --best-model $best_model_path \ + --log-file "hyperparam_optimization.log" + --seed 42 + --epochs 5 + +# evaluate the best model +python $CODE/evaluate/evaluate_model.py \ + --model-path $best_model_path \ + --dataset-test "../${test_file_out}_dataset.npy" \ + --labels-test "../${test_file_out}_labels.npy" \ + --batch-size 32 \ + --log-file "model_evaluation_test.log" \ + --save-plots \ + --output-dir $evaluation_out_dir + +python $CODE/evaluate/evaluate_model.py \ + --model-path $best_model_path \ + --dataset-test "../${leftout_file_out}_dataset.npy" \ + --labels-test "../${leftout_file_out}_labels.npy" \ + --batch-size 32 \ + --log-file "model_evaluation_leftout.log" \ + --save-plots \ + --output-dir $evaluation_out_dir diff --git a/analysis/miRBind_CNN_retraining_optimized/run_data_encoding.sh b/analysis/miRBind_CNN_retraining_optimized/run_data_encoding.sh new file mode 100755 index 0000000..bf780ef --- /dev/null +++ b/analysis/miRBind_CNN_retraining_optimized/run_data_encoding.sh @@ -0,0 +1,21 @@ +#!/bin/bash + + +TEST_DATASET="../../data/chimeric_datasets/Manakov2022_flat/AGO2_eCLIP_Manakov2022_test.tsv" +LEFTOUT_DATASET="../../data/chimeric_datasets/Manakov2022_flat/AGO2_eCLIP_Manakov2022_leftout.tsv" +TRAIN_DATASET="../../data/chimeric_datasets/Manakov2022_flat/AGO2_eCLIP_Manakov2022_train.tsv" + + +TEST_DATASET_OUT="encoded_dataset/Manakov2022_flat/AGO2_eCLIP_Manakov2022_1_test" +LEFTOUT_DATASET_OUT="encoded_dataset/Manakov2022_flat/AGO2_eCLIP_Manakov2022_1_leftout" +TRAIN_DATASET_OUT="encoded_dataset/Manakov2022_flat/AGO2_eCLIP_Manakov2022_1_train" + + +CODE="../../code/machine_learning" + +mkdir -p encoded_dataset/Manakov2022_flat + +# encode datasets +python $CODE/encode/binding_2D_matrix_encoder.py --i_file $TEST_DATASET --o_prefix $TEST_DATASET_OUT +python $CODE/encode/binding_2D_matrix_encoder.py --i_file $LEFTOUT_DATASET --o_prefix $LEFTOUT_DATASET_OUT +python $CODE/encode/binding_2D_matrix_encoder.py --i_file $TRAIN_DATASET --o_prefix $TRAIN_DATASET_OUT diff --git a/analysis/miRBind_CNN_retraining_orig_parameters/README.md b/analysis/miRBind_CNN_retraining_orig_parameters/README.md new file mode 100644 index 0000000..de651fb --- /dev/null +++ b/analysis/miRBind_CNN_retraining_orig_parameters/README.md @@ -0,0 +1,15 @@ +# miRBind CNN retraining with original parameters + +Run +```bash run_retraining.sh``` +to retrain the miRBind CNN as presented in the [miRBind paper](https://doi.org/10.3390/genes13122323) on Manakov 1:1 train dataset. +The training is done with the original hyperparameters used in the paper. + +### Dependencies + +- python=3.8 +- tensorflow=2.13 +- matplotlib +- numpy +- pandas + diff --git a/analysis/miRBind_CNN_retraining_orig_parameters/run_retraining.sh b/analysis/miRBind_CNN_retraining_orig_parameters/run_retraining.sh new file mode 100755 index 0000000..48ec8ee --- /dev/null +++ b/analysis/miRBind_CNN_retraining_orig_parameters/run_retraining.sh @@ -0,0 +1,100 @@ +#!/bin/bash + +# Function to display usage instructions +usage() { + echo "Usage: $0 -t -l -r -m [-c ]" + echo " -t: Test dataset TSV file (required)" + echo " -l: Leftout dataset TSV file (required)" + echo " -r: Train dataset TSV file (required)" + echo " -m: Model path (required)" + echo " -c: Code path (optional, default: ../../code/machine_learning)" + exit 1 +} + +# Default code path +CODE="../../code/machine_learning" + +# Parse command-line arguments +while getopts ":t:l:r:m:c:" opt; do + case ${opt} in + t ) + TEST_DATASET=$OPTARG + ;; + l ) + LEFTOUT_DATASET=$OPTARG + ;; + r ) + TRAIN_DATASET=$OPTARG + ;; + m ) + MODEL=$OPTARG + ;; + c ) + CODE=$OPTARG + ;; + \? ) + echo "Invalid option: $OPTARG" 1>&2 + usage + ;; + : ) + echo "Invalid option: $OPTARG requires an argument" 1>&2 + usage + ;; + esac +done +shift $((OPTIND -1)) + +# Validate required arguments +if [ -z "$TEST_DATASET" ] || [ -z "$LEFTOUT_DATASET" ] || [ -z "$TRAIN_DATASET" ] || [ -z "$MODEL" ]; then + echo "Error: Missing required arguments" 1>&2 + usage +fi + +# Generate output prefixes based on input file names +TEST_DATASET_OUT="encoded_dataset/$(basename "$(dirname "$TEST_DATASET")")/$(basename "$TEST_DATASET" .tsv)" +LEFTOUT_DATASET_OUT="encoded_dataset/$(basename "$(dirname "$LEFTOUT_DATASET")")/$(basename "$LEFTOUT_DATASET" .tsv)" +TRAIN_DATASET_OUT="encoded_dataset/$(basename "$(dirname "$TRAIN_DATASET")")/$(basename "$TRAIN_DATASET" .tsv)" + +# Create output directory +mkdir -p "$(dirname "$TEST_DATASET_OUT")" + +# Function to check and encode dataset +encode_dataset() { + local input_file=$1 + local output_prefix=$2 + + # Check if the dataset and labels .npy files already exist + if [ ! -f "${output_prefix}_dataset.npy" ] || [ ! -f "${output_prefix}_labels.npy" ]; then + echo "Encoding dataset: $input_file" + python "$CODE/encode/binding_2D_matrix_encoder.py" --i_file "$input_file" --o_prefix "$output_prefix" + else + echo "Encoded files for $input_file already exist. Skipping encoding." + fi +} + +# Encode datasets +encode_dataset "$TEST_DATASET" "$TEST_DATASET_OUT" +encode_dataset "$LEFTOUT_DATASET" "$LEFTOUT_DATASET_OUT" +encode_dataset "$TRAIN_DATASET" "$TRAIN_DATASET_OUT" + +# Train model (check if training dataset files exist) +TRAIN_DATASET_NPY="${TRAIN_DATASET_OUT}_dataset.npy" +TRAIN_LABELS_NPY="${TRAIN_DATASET_OUT}_labels.npy" + +# Determine dataset size (can be modified if needed) +DATASET_SIZE=$(wc -l < "$TRAIN_DATASET") + +if [ -f "$TRAIN_DATASET_NPY" ] && [ -f "$TRAIN_LABELS_NPY" ]; then + echo "Training model..." + python "$CODE/train/CNN_miRBind_2022/miRBind_CNN_training_orig_parameters.py" \ + --data "$TRAIN_DATASET_NPY" \ + --labels "$TRAIN_LABELS_NPY" \ + --dataset_size "$DATASET_SIZE" \ + --ratio 1 \ + --model "$MODEL" +else + echo "Error: Training dataset or labels file not found. Cannot proceed with training." + exit 1 +fi + +echo "Process completed." \ No newline at end of file diff --git a/code/machine_learning/data_generators.py b/code/machine_learning/data_generators.py new file mode 100644 index 0000000..c94a531 --- /dev/null +++ b/code/machine_learning/data_generators.py @@ -0,0 +1,86 @@ +import os +import numpy as np +from tensorflow.keras.utils import Sequence + + +class TrainDataGenerator(Sequence): + def __init__(self, data_path, labels_path, dataset_size, batch_size=32, validation_split=0.1, is_validation=False, shuffle=True): + self.size = dataset_size + self.data = np.memmap(data_path, dtype='float32', mode='r', shape=(self.size, 50, 20, 1)) + self.labels = np.memmap(labels_path, dtype='float32', mode='r', shape=(self.size,)) + self.batch_size = batch_size + self.shuffle = shuffle + self.validation_split = validation_split + self.num_samples = len(self.data) + self.num_validation_samples = int(self.num_samples * validation_split) + self.num_train_samples = self.num_samples - self.num_validation_samples + + indices = np.arange(self.num_samples) + if shuffle: + np.random.shuffle(indices) + + if is_validation: + self.indices = indices[self.num_train_samples:] + else: + self.indices = indices[:self.num_train_samples] + + self.on_epoch_end() + + def __len__(self): + return int(np.ceil(len(self.indices) / float(self.batch_size))) + + def __getitem__(self, idx): + batch_indices = self.indices[idx * self.batch_size:(idx + 1) * self.batch_size] + batch_data = self.data[batch_indices] + batch_labels = self.labels[batch_indices] + return batch_data, batch_labels + + def on_epoch_end(self): + if self.shuffle: + np.random.shuffle(self.indices) + + +class TestDataGenerator: + def __init__(self, data_path, labels_path, batch_size=32, dataset_size=None): + if dataset_size is None: + # Try to determine the dataset size by checking file properties + try: + # try to load just the header to get shape and dtype + with open(data_path, 'rb') as f: + if f.read(6) == b'\x93NUMPY': + # This is a standard numpy file, we can get info from header + f.seek(0) + version = np.lib.format.read_magic(f) + shape_dict = np.lib.format.read_array_header_1_0(f) if version == (1, 0) else np.lib.format.read_array_header_2_0(f) + shape = shape_dict[0] + dataset_size = shape[0] + else: + # Not a standard numpy file, we'll try other methods + raise ValueError("Not a standard numpy file") + except: + # try to infer from file size + # This assumes the files are memory-mapped in a specific format + # For dataset: shape=(n, 50, 20, 1), dtype=float32 (4 bytes) + # For labels: shape=(n,), dtype=float32 (4 bytes) + data_size_bytes = os.path.getsize(data_path) + labels_size_bytes = os.path.getsize(labels_path) + + # Calculate dataset_size based on assumed structure + dataset_size_from_data = data_size_bytes // (4 * 50 * 20) + dataset_size_from_labels = labels_size_bytes // 4 + + # Verify sizes match approximately + if abs(dataset_size_from_data - dataset_size_from_labels) < 10: + dataset_size = dataset_size_from_data + else: + raise ValueError(f"Inconsistent file sizes: data suggests {dataset_size_from_data} samples, labels suggests {dataset_size_from_labels}") + + # Create memory-mapped arrays with the determined size + self.data = np.memmap(data_path, dtype='float32', mode='r', shape=(dataset_size, 50, 20, 1)) + self.labels = np.memmap(labels_path, dtype='float32', mode='r', shape=(dataset_size,)) + self.batch_size = batch_size + self.num_samples = dataset_size + + def get_data(self): + """Return all test data and labels""" + return self.data, self.labels \ No newline at end of file diff --git a/code/machine_learning/encode/README.md b/code/machine_learning/encode/README.md index 79fe152..697c03e 100644 --- a/code/machine_learning/encode/README.md +++ b/code/machine_learning/encode/README.md @@ -1 +1,17 @@ -# Encoding the dataset into inner representation \ No newline at end of file +# Encoding the dataset into inner representation + +### [Binding 2D matrix encoder](binding_2d_matrix_encoder.py) + The encoder is based on the "miRBind: A deep learning method for miRNA binding classification." (2022) https://doi.org/10.3390/genes13122323 + with original python implementation here: https://github.com/ML-Bioinfo-CEITEC/miRBind + +Encodes miRNA and gene sequences into 2D-binding matrix. +2D-binding matrix has shape (gene_max_len=50, miRNA_max_len=20, 1) and contains 1 for Watson-Crick interactions and 0 otherwise. + +Outputs npy file with encoded matrices and npy file with corresponding labels. + +#### Usage +Run the script from the command line with the following syntax: + + +```python binding_2d_matrix_encoder.py --i_file input_dataset_file.tsv --o_prefix output_prefix``` + diff --git a/code/machine_learning/encode/binding_2D_matrix_encoder.py b/code/machine_learning/encode/binding_2D_matrix_encoder.py index 70b6360..484e534 100644 --- a/code/machine_learning/encode/binding_2D_matrix_encoder.py +++ b/code/machine_learning/encode/binding_2D_matrix_encoder.py @@ -1,37 +1,102 @@ -class miRBindEncoder(): +import pandas as pd +import numpy as np +import argparse +import time + + +def binding_encoding(df, alphabet, tensor_dim=(50, 20, 1)): + """ + Transform input sequence pairs to a binding matrix with corresponding labels. + + Parameters: + - df: Pandas DataFrame with columns "noncodingRNA", "gene", "label" + - alphabet: dictionary with letter tuples as keys and 1s when they bind + - tensor_dim: 2D binding matrix shape + + Output: + 2D binding matrix, labels as np array + """ + labels = df["label"].to_numpy() + + # Initialize dot matrix with zeros + ohe_matrix_2d = np.zeros((len(df), *tensor_dim), dtype="float32") + + df = df.reset_index(drop=True) + + # Compile matrix with Watson-Crick interactions + for index, row in df.iterrows(): + for bind_index, bind_nt in enumerate(row['gene'].upper()): + for ncrna_index, ncrna_nt in enumerate(row['noncodingRNA'].upper()): + if ncrna_index >= tensor_dim[1]: + break + base_pairs = bind_nt + ncrna_nt + ohe_matrix_2d[index, bind_index, ncrna_index, 0] = alphabet.get(base_pairs, 0) + + return ohe_matrix_2d, labels + + +def encode_large_tsv_to_numpy(tsv_file_path, data_output_path, labels_output_path, chunk_size=10000): + """ + Encode a large TSV file into a NumPy matrix using chunk processing. + + Parameters: + - tsv_file_path: Path to the TSV file with dataset. + - data_output_path: Path to the output data .npy file. + - labels_output_path: Path to the output labels .npy file. + - chunk_size: Number of rows to process at a time. """ - Based on Klimentová, Eva, et al. "miRBind: A deep learning method for miRNA binding classification." Genes 13.12 (2022): 2323. https://doi.org/10.3390/genes13122323. - Python implementation: https://github.com/ML-Bioinfo-CEITEC/miRBind + # Alphabet for Watson-Crick interactions + alphabet = {"AT": 1., "TA": 1., "GC": 1., "CG": 1.} + tensor_dim = (50, 20, 1) + + # Get total number of rows in the dataset + num_rows = sum(len(df) for df in pd.read_csv(tsv_file_path, sep='\t', usecols=[0], chunksize=chunk_size)) + + # Determine the shape of the output arrays + labels_shape = (num_rows,) + data_shape = (num_rows, *tensor_dim) + + # Create memory-mapped files + ohe_matrix_2d = np.memmap(data_output_path, dtype='float32', mode='w+', shape=data_shape) + labels = np.memmap(labels_output_path, dtype='float32', mode='w+', shape=labels_shape) + + row_offset = 0 + + # Process each chunk + for chunk in pd.read_csv(tsv_file_path, sep='\t', chunksize=chunk_size): + encoded_data, encoded_labels = binding_encoding(chunk, alphabet, tensor_dim) + + # Write the chunk's data and labels to the memory-mapped files + ohe_matrix_2d[row_offset:row_offset + len(chunk)] = encoded_data + labels[row_offset:row_offset + len(chunk)] = encoded_labels + row_offset += len(chunk) + + # Flush changes to disk + ohe_matrix_2d.flush() + labels.flush() + + +def main(): + """ + Based on "miRBind: A deep learning method for miRNA binding classification." Genes 13.12 (2022): 2323. https://doi.org/10.3390/genes13122323. + Original implementation: https://github.com/ML-Bioinfo-CEITEC/miRBind Encodes miRNA and gene sequences into 2D-binding matrix. - 2D-binding matrix has shape (gene_max_len, miRNA_max_len, 1) and contains 1 for Watson-Crick interactions and 0 otherwise. - Returns array with shape (num_of_samples, gene_max_len, miRNA_max_len, 1). + 2D-binding matrix has shape (gene_max_len=50, miRNA_max_len=20, 1) and contains 1 for Watson-Crick interactions and 0 otherwise. """ - def __call__(self, df, miRNA_col="noncodingRNA", gene_col="gene", tensor_dim=(50, 20, 1)): - return self.binding_encoding(df, miRNA_col, gene_col, tensor_dim) - - def binding_encoding(self, df, miRNA_col, gene_col, tensor_dim): - """ - fun encodes miRNAs and mRNAs in df into binding matrices - :param df: dataframe containing gene_col and miRNA_col columns - :param tensor_dim: output shape of the matrix. If sequences are longer than tensor_dim, they will be truncated. - :return: 2D binding matrix with shape (N, *tensor_dim) - """ - - # alphabet for watson-crick interactions. - alphabet = {"AT": 1., "TA": 1., "GC": 1., "CG": 1., "AU": 1., "UA": 1.} - # create empty main 2d matrix array - N = df.shape[0] # number of samples in df - shape_matrix_2d = (N, *tensor_dim) # 2d matrix shape - # initialize dot matrix with zeros - ohe_matrix_2d = np.zeros(shape_matrix_2d, dtype="float32") - - # compile matrix with watson-crick interactions. - for index, row in df.iterrows(): - for bind_index, bind_nt in enumerate(row[gene_col][:tensor_dim[0]].upper()): - for mirna_index, mirna_nt in enumerate(row[miRNA_col][:tensor_dim[1]].upper()): - base_pairs = bind_nt + mirna_nt - ohe_matrix_2d[index, bind_index, mirna_index, 0] = alphabet.get(base_pairs, 0) - - return ohe_matrix_2d \ No newline at end of file + parser = argparse.ArgumentParser( + description="Encode dataset to miRNA x target binding matrix. Outputs numpy file with matrices and and numpy file with corresponding labels. Expected columns of the dataset are 'noncodingRNA', 'gene' and 'label'") + parser.add_argument('-i', '--i_file', type=str, required=True, help="Input dataset file name") + parser.add_argument('-o', '--o_prefix', type=str, required=True, help="Output file name prefix") + args = parser.parse_args() + + start = time.time() + encode_large_tsv_to_numpy(args.i_file, args.o_prefix + '_dataset.npy', args.o_prefix + '_labels.npy') + end = time.time() + + print("Elapsed time: ", end - start, " s.") + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/code/machine_learning/evaluate/evaluate_model.py b/code/machine_learning/evaluate/evaluate_model.py new file mode 100644 index 0000000..19799c4 --- /dev/null +++ b/code/machine_learning/evaluate/evaluate_model.py @@ -0,0 +1,122 @@ +import argparse +import os +import logging +import numpy as np +from tensorflow import keras as K +from sklearn.metrics import precision_recall_curve, auc, roc_curve, roc_auc_score, accuracy_score, average_precision_score + +from code.machine_learning.data_generators import TestDataGenerator +from plots import plot_roc_curve, plot_pr_curve +from utils import setup_logger + + +def evaluate_model(model, test_data, test_labels, logger, save_plots=True, output_dir='.', pred_threshold=0.5): + """Evaluate model performance""" + # Get predictions from prediction probabilities + y_pred_proba = model.predict(test_data) + y_pred = (y_pred_proba > pred_threshold).astype(int) + + # Calculate metrics + accuracy = accuracy_score(test_labels, y_pred) + + fpr, tpr, _ = roc_curve(test_labels, y_pred_proba) + roc_auc = roc_auc_score(test_labels, y_pred_proba) + + precision, recall, _ = precision_recall_curve(test_labels, y_pred_proba) + pr_auc = auc(recall, precision) + + avg_precision = average_precision_score(test_labels, y_pred_proba) + + logger.info(f"Model Evaluation Results:") + logger.info(f"Accuracy: {accuracy:.4f}") + logger.info(f"ROC AUC: {roc_auc:.4f}") + logger.info(f"PR AUC: {pr_auc:.4f}") + logger.info(f"Average Precision: {avg_precision:.4f}") + + if save_plots: + os.makedirs(output_dir, exist_ok=True) + + plot_roc_curve(fpr, tpr, roc_auc, output_dir, logger, fig_save_name='roc_curve.png') + plot_pr_curve(recall, precision, pr_auc, avg_precision, output_dir, logger, fig_save_name='pr_curve.png') + + return { + 'accuracy': accuracy, + 'roc_auc': roc_auc, + 'pr_auc': pr_auc, + 'avg_precision': avg_precision, + 'fpr': fpr, + 'tpr': tpr, + 'precision': precision, + 'recall': recall + } + + +def main(): + parser = argparse.ArgumentParser(description='Evaluate trained miRBind CNN model') + parser.add_argument('--model-path', type=str, default='best_model.keras', + help='Path to the trained model file') + parser.add_argument('--dataset-test', type=str, + default='../encoded_dataset/Manakov2022_flat/AGO2_eCLIP_Manakov2022_test_dataset.npy', + help='Path to the test dataset') + parser.add_argument('--labels-test', type=str, + default='../encoded_dataset/Manakov2022_flat/AGO2_eCLIP_Manakov2022_test_labels.npy', + help='Path to the test labels') + parser.add_argument('--dataset-size', type=int, default=None, + help='Size of the test dataset (number of samples). If not provided, will attempt to determine automatically.') + parser.add_argument('--batch-size', type=int, default=32, + help='Batch size for evaluation') + parser.add_argument('--log-file', type=str, default='model_evaluation.log', + help='Path to the log file') + parser.add_argument('--save-plots', action='store_true', default=True, + help='Save evaluation plots') + parser.add_argument('--output-dir', type=str, default='evaluation_results', + help='Directory to save evaluation results') + args = parser.parse_args() + + os.makedirs(args.output_dir, exist_ok=True) + if not os.path.exists(args.output_dir): + raise RuntimeError(f"Failed to create output directory: {args.output_dir}") + + # Set up logger + logger = setup_logger(os.path.join(args.output_dir, args.log_file), 'model_evaluation') + logger.info("Starting model evaluation") + + try: + logger.info(f"Loading model from {args.model_path}") + model = K.models.load_model(args.model_path) + logger.info(f"Model loaded successfully") + + logger.info(f"Loading test data from {args.dataset_test}") + + test_data_generator = TestDataGenerator( + args.dataset_test, + args.labels_test, + batch_size=args.batch_size, + dataset_size=args.dataset_size + ) + test_data, test_labels = test_data_generator.get_data() + logger.info(f"Dataset size: {len(test_data)} samples") + + logger.info("Evaluating model performance...") + results = evaluate_model( + model, + test_data, + test_labels, + logger, + save_plots=args.save_plots, + output_dir=args.output_dir + ) + + logger.info("Model evaluation completed successfully") + + # Save model summary + with open(os.path.join(args.output_dir, 'model_summary.txt'), 'w') as f: + model.summary(print_fn=lambda x: f.write(x + '\n')) + logger.info(f"Model summary saved to {os.path.join(args.output_dir, 'model_summary.txt')}") + + except Exception as e: + logger.error(f"An error occurred during model evaluation: {str(e)}") + raise + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/code/machine_learning/plots.py b/code/machine_learning/plots.py new file mode 100644 index 0000000..eab2d32 --- /dev/null +++ b/code/machine_learning/plots.py @@ -0,0 +1,71 @@ +import os +import matplotlib.pyplot as plt + + +def plot_roc_curve(fpr, tpr, roc_auc, output_dir, logger, fig_save_name='roc_curve.png'): + save_path = os.path.join(output_dir, fig_save_name) + + plt.figure(figsize=(10, 8)) + plt.plot(fpr, tpr, color='darkorange', lw=2, label=f'ROC curve (area = {roc_auc:.3f})') + plt.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--') + plt.xlim([0.0, 1.0]) + plt.ylim([0.0, 1.05]) + plt.xlabel('False Positive Rate') + plt.ylabel('True Positive Rate') + plt.title('Receiver Operating Characteristic (ROC) Curve') + plt.legend(loc="lower right") + plt.savefig(save_path) + logger.info(f"Saved ROC curve plot to {save_path}") + + +def plot_pr_curve(recall, precision, pr_auc, avg_precision, output_dir, logger, fig_save_name='pr_curve.png'): + save_path = os.path.join(output_dir, fig_save_name) + + plt.figure(figsize=(10, 8)) + plt.plot(recall, precision, color='green', lw=2, + label=f'PR curve (area = {pr_auc:.3f}, avg precision = {avg_precision:.3f})') + plt.xlabel('Recall') + plt.ylabel('Precision') + plt.ylim([0.0, 1.05]) + plt.xlim([0.0, 1.0]) + plt.title('Precision-Recall Curve') + plt.legend(loc="lower left") + plt.savefig(save_path) + logger.info(f"Saved PR curve plot to {save_path}") + + +def plot_training_history(history, output_dir): + """Plot and save training metrics.""" + # Create a figure with 3 subplots + plt.figure(figsize=(18, 5)) + + # Plot accuracy + plt.subplot(1, 3, 1) + plt.plot(history.history['accuracy']) + plt.plot(history.history['val_accuracy']) + plt.title('Model Accuracy') + plt.ylabel('Accuracy') + plt.xlabel('Epoch') + plt.legend(['Train', 'Validation'], loc='upper left') + + # Plot AUPRC + plt.subplot(1, 3, 2) + plt.plot(history.history['auc']) + plt.plot(history.history['val_auc']) + plt.title('Area Under PR Curve') + plt.ylabel('AUC') + plt.xlabel('Epoch') + plt.legend(['Train', 'Validation'], loc='upper left') + + # Plot loss + plt.subplot(1, 3, 3) + plt.plot(history.history['loss']) + plt.plot(history.history['val_loss']) + plt.title('Model Loss') + plt.ylabel('Loss') + plt.xlabel('Epoch') + plt.legend(['Train', 'Validation'], loc='upper left') + + plt.tight_layout() + plt.savefig(os.path.join(output_dir, 'training_history.png')) + plt.close() diff --git a/code/machine_learning/train/CNN_miRBind_2022/miRBind_CNN_architecture.py b/code/machine_learning/train/CNN_miRBind_2022/miRBind_CNN_architecture.py new file mode 100644 index 0000000..48d8bef --- /dev/null +++ b/code/machine_learning/train/CNN_miRBind_2022/miRBind_CNN_architecture.py @@ -0,0 +1,60 @@ +import tensorflow as tf +from tensorflow import keras as K +from tensorflow.keras.optimizers import Adam +from tensorflow.keras.layers import Input, Conv2D, LeakyReLU, BatchNormalization, MaxPooling2D, Dropout, Flatten, Dense + + +class miRBind_CNN(): + """ + Build model architecture based on the CNN model presented in miRBind paper (2022) https://doi.org/10.3390/genes13122323 + The default parameters are same as the ones used in the paper + """ + def __init__(self, cnn_num = 6, kernel_size = 5, pool_size = 2, dropout_rate = 0.3, dense_num = 2): + + x = Input(shape=(50,20,1), dtype='float32') + main_input = x + + for cnn_i in range(cnn_num): + x = Conv2D( + filters=32 * (cnn_i + 1), + kernel_size=(kernel_size, kernel_size), + padding="same", + data_format="channels_last")(x) + x = LeakyReLU()(x) + x = BatchNormalization()(x) + x = MaxPooling2D(pool_size=(pool_size, pool_size), padding='same')(x) + x = Dropout(rate=dropout_rate)(x) + + x = Flatten()(x) + + for dense_i in range(dense_num): + neurons = 32 * (cnn_num - dense_i) + x = Dense(neurons)(x) + x = LeakyReLU()(x) + x = BatchNormalization()(x) + x = Dropout(rate=dropout_rate)(x) + + main_output = Dense(1, activation='sigmoid')(x) + + model = K.Model(inputs=[main_input], outputs=[main_output], name='miRBind_CNN') + + self.model = model + + def compile_model(self, lr=0.00152): + K.backend.clear_session() + model = self.model + + opt = Adam( + learning_rate=lr, + beta_1=0.9, + beta_2=0.999, + epsilon=1e-07, + amsgrad=False, + name="Adam") + + model.compile( + optimizer=opt, + loss='binary_crossentropy', + metrics=['accuracy'] + ) + return model \ No newline at end of file diff --git a/code/machine_learning/train/CNN_miRBind_2022/miRBind_CNN_training_orig_parameters.py b/code/machine_learning/train/CNN_miRBind_2022/miRBind_CNN_training_orig_parameters.py new file mode 100644 index 0000000..6d8e1cc --- /dev/null +++ b/code/machine_learning/train/CNN_miRBind_2022/miRBind_CNN_training_orig_parameters.py @@ -0,0 +1,139 @@ +import numpy as np +import argparse +import time +import tensorflow as tf +from tensorflow import keras as K +import matplotlib.pyplot as plt +from tensorflow.keras.utils import Sequence + +from miRBind_CNN_architecture import miRBind_CNN + + +class DataGenerator(Sequence): + # preload the encoded numpy data + def __init__(self, data_path, labels_path, dataset_size, batch_size, validation_split=0.1, + is_validation=False, shuffle=True): + # the dataset size is needed to properly load the numpy files + self.size = dataset_size + + self.data = np.memmap(data_path, dtype='float32', mode='r', shape=(self.size, 50, 20, 1)) + self.labels = np.memmap(labels_path, dtype='float32', mode='r', shape=(self.size,)) + self.batch_size = batch_size + self.shuffle = shuffle + + # Determine number of train and validation samples + self.validation_split = validation_split + self.num_samples = len(self.data) + self.num_validation_samples = int(self.num_samples * validation_split) + self.num_train_samples = self.num_samples - self.num_validation_samples + + # Determine indices for validation and training + indices = np.arange(self.num_samples) + if shuffle: + np.random.shuffle(indices) + + if is_validation: + self.indices = indices[self.num_train_samples:] + else: + self.indices = indices[:self.num_train_samples] + + # Shuffle the data initially + self.on_epoch_end() + + def __len__(self): + # Denotes the number of batches per epoch + return int(np.ceil(len(self.indices) / float(self.batch_size))) + + def __getitem__(self, idx): + # Generate one batch of data + batch_indices = self.indices[idx * self.batch_size:(idx + 1) * self.batch_size] + batch_data = self.data[batch_indices] + batch_labels = self.labels[batch_indices] + return batch_data, batch_labels + + def on_epoch_end(self): + # Updates indices after each epoch for shuffling + if self.shuffle: + np.random.shuffle(self.indices) + + +def plot_history(history, ratio): + """ + Plot history of the model training, + accuracy and loss of the training and validation set + """ + + acc = history.history['accuracy'] + val_acc = history.history['val_accuracy'] + loss = history.history['loss'] + val_loss = history.history['val_loss'] + + epochs = range(1, len(acc) + 1) + + plt.figure(figsize=(8, 6), dpi=80) + + plt.plot(epochs, acc, 'bo', label='Training acc') + plt.plot(epochs, val_acc, 'b', label='Validation acc') + plt.title('Accuracy') + plt.legend() + plt.savefig(f"training_acc_1_{ratio}.jpg") + + plt.figure() + + plt.plot(epochs, loss, 'bo', label='Training loss') + plt.plot(epochs, val_loss, 'b', label='Validation loss') + plt.title('Loss') + plt.legend() + plt.savefig(f"training_loss_1_{ratio}.jpg") + + +def train_model(data, labels, dataset_size, ratio, model_file, debug=False): + # set random state for reproducibility + np.random.seed(42) + tf.random.set_seed(42) + K.utils.set_random_seed(42) + # TODO still not fully reproducible? why? + + train_data_gen = DataGenerator(data, labels, dataset_size, batch_size=32, validation_split=0.1, + is_validation=False) + val_data_gen = DataGenerator(data, labels, dataset_size, batch_size=32, validation_split=0.1, + is_validation=True) + + model = miRBind_CNN().compile_model() + model_history = model.fit( + train_data_gen, + validation_data=val_data_gen, + epochs=10, + class_weight={0: 1, 1: ratio} + ) + + if debug: + plot_history(model_history, ratio) + + model.save(model_file) + + +def main(): + parser = argparse.ArgumentParser(description="Train CNN model on encoded miRNA x target binding matrix dataset") + parser.add_argument('--ratio', type=int, required=True, help="Ratio of pos:neg in the training dataset") + parser.add_argument('--data', type=str, required=True, help="File with the encoded dataset") + parser.add_argument('--labels', type=str, required=True, help="File with the dataset labels") + parser.add_argument('--dataset_size', type=int, required=True, + help="Number of samples in the dataset. Needed to properly load the numpy files.") + parser.add_argument('--model', type=str, required=False, help="Filename to save the trained model") + parser.add_argument('--debug', type=bool, default=False, help="Set to True to output some plots about training") + args = parser.parse_args() + + if args.model is None: + args.model = f"model_1_{args.ratio}.keras" + + start = time.time() + train_model(data=args.data, labels=args.labels, dataset_size=args.dataset_size, ratio=args.ratio, + model_file=args.model, debug=args.debug) + end = time.time() + + print("Elapsed time: ", end - start, " s.") + + +if __name__ == "__main__": + main() diff --git a/code/machine_learning/train/CNN_miRBind_2022/train_model.py b/code/machine_learning/train/CNN_miRBind_2022/train_model.py new file mode 100755 index 0000000..a1160ad --- /dev/null +++ b/code/machine_learning/train/CNN_miRBind_2022/train_model.py @@ -0,0 +1,172 @@ +import argparse +import numpy as np +import logging +import tensorflow as tf +from tensorflow import keras as K +from tensorflow.keras.optimizers import Adam +from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping, CSVLogger +import random +import os +import sys +import matplotlib.pyplot as plt + +from utils import set_seeds, setup_logger, compile_model +from plots import plot_training_history +from code.machine_learning.data_generators import TrainDataGenerator +sys.path.append("../../../code/machine_learning/train/CNN_miRBind_2022/") +from miRBind_CNN_architecture import miRBind_CNN + + +def main(): + parser = argparse.ArgumentParser(description='Train miRBind CNN model with specified hyperparameters') + + # Data parameters + parser.add_argument('--dataset-train', type=str, required=True, + help='Path to the training dataset (numpy array)') + parser.add_argument('--labels-train', type=str, required=True, + help='Path to the training labels (numpy array)') + parser.add_argument('--dataset-size', type=int, required=True, + help='Size of the dataset (number of samples)') + parser.add_argument('--validation-split', type=float, default=0.1, + help='Validation split ratio (default: 0.1)') + + # Model architecture parameters + parser.add_argument('--cnn-num', type=int, default=6, + help='Number of CNN layers (default: 6)') + parser.add_argument('--kernel-size', type=int, default=5, + help='Kernel size for CNN layers (default: 5)') + parser.add_argument('--pool-size', type=int, default=2, + help='Pool size for MaxPooling layers (default: 2)') + parser.add_argument('--dropout-rate', type=float, default=0.3, + help='Dropout rate (default: 0.3)') + parser.add_argument('--dense-num', type=int, default=2, + help='Number of dense layers (default: 2)') + + # Training parameters + parser.add_argument('--learning-rate', type=float, default=0.00001, + help='Learning rate (default: 0.00152)') + parser.add_argument('--batch-size', type=int, default=32, + help='Batch size for training (default: 32)') + parser.add_argument('--epochs', type=int, default=30, + help='Number of epochs to train (default: 30)') + parser.add_argument('--patience', type=int, default=5, + help='Patience for early stopping (default: 5)') + parser.add_argument('--class-weight', type=float, default=1.0, + help='Weight for positive class (default: 1.0)') + + # Output parameters + parser.add_argument('--output-dir', type=str, default='./model_output', + help='Directory to save model and logs (default: ./model_output)') + parser.add_argument('--model-name', type=str, default='mirbind_cnn_model', + help='Name for the saved model (default: mirbind_cnn_model)') + parser.add_argument('--log-file', type=str, default='training.log', + help='Path to the log file (default: training.log)') + + # Misc parameters + parser.add_argument('--seed', type=int, default=42, + help='Random seed for reproducibility (default: 42)') + + args = parser.parse_args() + + os.makedirs(args.output_dir, exist_ok=True) + if not os.path.exists(args.output_dir): + raise RuntimeError(f"Failed to create output directory: {args.output_dir}") + + log_path = os.path.join(args.output_dir, args.log_file) + logger = setup_logger(log_path, 'mirbind_train') + + set_seeds(args.seed) + logger.info(f"Starting training with seed: {args.seed}") + + logger.info("Training with the following parameters:") + for arg in vars(args): + logger.info(f" {arg}: {getattr(args, arg)}") + + logger.info("Preparing data generators...") + train_data_gen = TrainDataGenerator( + args.dataset_train, + args.labels_train, + dataset_size=args.dataset_size, + batch_size=args.batch_size, + validation_split=args.validation_split, + is_validation=False + ) + + val_data_gen = TrainDataGenerator( + args.dataset_train, + args.labels_train, + dataset_size=args.dataset_size, + batch_size=args.batch_size, + validation_split=args.validation_split, + is_validation=True + ) + + logger.info("Building model...") + + model_instance = miRBind_CNN( + cnn_num=args.cnn_num, + kernel_size=args.kernel_size, + pool_size=args.pool_size, + dropout_rate=args.dropout_rate, + dense_num=args.dense_num + ).model + + model = compile_model(model_instance, lr=args.learning_rate) + + model.summary(print_fn=logger.info) + + logger.info("Setting up training callbacks...") + callbacks = [ + ModelCheckpoint( + filepath=os.path.join(args.output_dir, f"{args.model_name}_best.keras"), + monitor='val_auc', + save_best_only=True, + mode='max', + verbose=1 + ), + EarlyStopping( + monitor='val_auc', + patience=args.patience, + restore_best_weights=True, + verbose=1 + ), + CSVLogger( + os.path.join(args.output_dir, 'training_log.csv') + ) + ] + + class_weights = {0: 1, 1: args.class_weight} + + logger.info("Starting model training...") + history = model.fit( + train_data_gen, + validation_data=val_data_gen, + epochs=args.epochs, + class_weight=class_weights, + callbacks=callbacks, + verbose=1 + ) + + final_model_path = os.path.join(args.output_dir, f"{args.model_name}_final.keras") + try: + model.save(final_model_path) + except Exception as e: + logger.error(f"Failed to save model: {str(e)}") + raise + logger.info(f"Final model saved to {final_model_path}") + + logger.info("Plotting training history...") + plot_training_history(history, args.output_dir) + + logger.info("Evaluating model on validation set...") + val_metrics = model.evaluate(val_data_gen, verbose=1) + metric_names = model.metrics_names + + for name, value in zip(metric_names, val_metrics): + logger.info(f"Validation {name}: {value:.4f}") + + logger.info("Training completed successfully!") + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/code/machine_learning/train/CNN_miRBind_2022/training.ipynb b/code/machine_learning/train/CNN_miRBind_2022/training.ipynb deleted file mode 100644 index cf24ebf..0000000 --- a/code/machine_learning/train/CNN_miRBind_2022/training.ipynb +++ /dev/null @@ -1,442 +0,0 @@ -{ - "nbformat": 4, - "nbformat_minor": 0, - "metadata": { - "colab": { - "name": "Training.ipynb", - "provenance": [], - "collapsed_sections": [] - }, - "kernelspec": { - "name": "python3", - "display_name": "Python 3" - }, - "language_info": { - "name": "python" - } - }, - "cells": [ - { - "cell_type": "code", - "metadata": { - "id": "T6BIHgU38o2f" - }, - "source": [ - "import pandas as pd\n", - "import numpy as np\n", - "import tensorflow as tf\n", - "from tensorflow import keras as K\n", - "import matplotlib.pyplot as plt\n", - "from matplotlib.pyplot import figure\n", - "from tensorflow.keras.layers import (\n", - " BatchNormalization, LeakyReLU,\n", - " Input, Dense, Conv2D,\n", - " MaxPooling2D, Flatten, Dropout)\n", - "from tensorflow.keras.optimizers import Adam" - ], - "execution_count": 1, - "outputs": [] - }, - { - "cell_type": "code", - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "EfIQCQj2r3fV", - "outputId": "26867200-18e8-4b4e-d34d-129c7203f694" - }, - "source": [ - "!wget https://raw.githubusercontent.com/ML-Bioinfo-CEITEC/miRBind/main/Datasets/train_set_1_10_CLASH2013_paper.tsv" - ], - "execution_count": 2, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "--2022-04-24 17:22:50-- https://raw.githubusercontent.com/ML-Bioinfo-CEITEC/miRBind/main/Datasets/train_set_1_10_CLASH2013_paper.tsv\n", - "Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.109.133, 185.199.110.133, ...\n", - "Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected.\n", - "HTTP request sent, awaiting response... 200 OK\n", - "Length: 12518906 (12M) [text/plain]\n", - "Saving to: ‘train_set_1_10_CLASH2013_paper.tsv’\n", - "\n", - "train_set_1_10_CLAS 100%[===================>] 11.94M --.-KB/s in 0.09s \n", - "\n", - "2022-04-24 17:22:51 (130 MB/s) - ‘train_set_1_10_CLASH2013_paper.tsv’ saved [12518906/12518906]\n", - "\n" - ] - } - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "ALmjW7vd9u60" - }, - "source": [ - "def binding_encoding(df, tensor_dim=(50,20,1)):\n", - " \"\"\"\n", - " fun transform input database to numpy array.\n", - " \n", - " parameters:\n", - " df = Pandas df with col names \"noncodingRNA\", \"gene\", \"label\"\n", - " tensor_dim = 2d matrix shape\n", - " \n", - " output:\n", - " 2d dot matrix, labels as np array\n", - " \"\"\"\n", - " df.reset_index(inplace=True, drop=True)\n", - "\n", - " # alphabet for watson-crick interactions.\n", - " alphabet = {\"AT\": 1., \"TA\": 1., \"GC\": 1., \"CG\": 1.} \n", - "\n", - " # labels to one hot encoding\n", - " labels = df[\"label\"].to_numpy()\n", - "\n", - " # create empty main 2d matrix array\n", - " N = df.shape[0] # number of samples in df\n", - " shape_matrix_2d = (N, *tensor_dim) # 2d matrix shape \n", - " # initialize dot matrix with zeros\n", - " ohe_matrix_2d = np.zeros(shape_matrix_2d, dtype=\"float32\")\n", - "\n", - " # compile matrix with watson-crick interactions.\n", - " for index, row in df.iterrows(): \n", - " for bind_index, bind_nt in enumerate(row.gene.upper()):\n", - " \n", - " for ncrna_index, ncrna_nt in enumerate(row.noncodingRNA.upper()):\n", - " if ncrna_index >= tensor_dim[1]:\n", - " break\n", - " base_pairs = bind_nt + ncrna_nt\n", - " ohe_matrix_2d[index, bind_index, ncrna_index, 0] = alphabet.get(base_pairs, 0)\n", - " \n", - "\n", - " return ohe_matrix_2d, labels" - ], - "execution_count": 3, - "outputs": [] - }, - { - "cell_type": "code", - "metadata": { - "id": "CpoytQwIElkg" - }, - "source": [ - "def make_architecture():\n", - " \"\"\"\n", - " build model architecture\n", - "\n", - " return a model object\n", - " \"\"\"\n", - " cnn_num = 6\n", - " kernel_size = 5\n", - " pool_size = 2\n", - " dropout_rate = 0.3\n", - " dense_num = 2\n", - "\n", - " x = Input(shape=(50,20,1),\n", - " dtype='float32', name='main_input'\n", - " )\n", - " main_input = x\n", - "\n", - " for cnn_i in range(cnn_num):\n", - " x = Conv2D(\n", - " filters=32 * (cnn_i + 1),\n", - " kernel_size=(kernel_size, kernel_size),\n", - " padding=\"same\",\n", - " data_format=\"channels_last\",\n", - " name=\"conv_\" + str(cnn_i + 1))(x)\n", - " x = LeakyReLU()(x)\n", - " x = BatchNormalization()(x)\n", - " x = MaxPooling2D(pool_size=(pool_size, pool_size), padding='same', name='Max_' + str(cnn_i + 1))(x)\n", - " x = Dropout(rate=dropout_rate)(x)\n", - "\n", - " x = Flatten(name='2d_matrix')(x)\n", - "\n", - " for dense_i in range(dense_num):\n", - " neurons = 32 * (cnn_num - dense_i)\n", - " x = Dense(neurons)(x)\n", - " x = LeakyReLU()(x)\n", - " x = BatchNormalization()(x)\n", - " x = Dropout(rate=dropout_rate)(x)\n", - "\n", - " main_output = Dense(1, activation='sigmoid', name='main_output')(x)\n", - "\n", - " model = K.Model(inputs=[main_input], outputs=[main_output], name='arch_00')\n", - " \n", - " return model" - ], - "execution_count": 4, - "outputs": [] - }, - { - "cell_type": "code", - "metadata": { - "id": "oZ591qC0Femi" - }, - "source": [ - "def compile_model():\n", - " K.backend.clear_session()\n", - " model = make_architecture()\n", - " \n", - " opt = Adam(\n", - " learning_rate=0.00152,\n", - " beta_1=0.9,\n", - " beta_2=0.999,\n", - " epsilon=1e-07,\n", - " amsgrad=False,\n", - " name=\"Adam\")\n", - "\n", - " model.compile(\n", - " optimizer=opt,\n", - " loss='binary_crossentropy',\n", - " metrics=['accuracy']\n", - " )\n", - " return model" - ], - "execution_count": 5, - "outputs": [] - }, - { - "cell_type": "code", - "metadata": { - "id": "wTB6K0lxyzcx" - }, - "source": [ - "def plot_history(history):\n", - " \"\"\"\n", - " plot history of the training of the model,\n", - " accuracy and loss of the training and validation set\n", - " \"\"\"\n", - " \n", - " acc = history.history['accuracy']\n", - " val_acc = history.history['val_accuracy']\n", - " loss = history.history['loss']\n", - " val_loss = history.history['val_loss']\n", - "\n", - " epochs = range(1, len(acc) + 1)\n", - "\n", - " plt.figure(figsize=(8, 6), dpi=80)\n", - "\n", - " plt.plot(epochs, acc, 'bo', label='Training acc')\n", - " plt.plot(epochs, val_acc, 'b', label='Validation acc')\n", - " plt.title('Accuracy')\n", - " plt.legend()\n", - " plt.figure()\n", - "\n", - " plt.plot(epochs, loss, 'bo', label='Training loss')\n", - " plt.plot(epochs, val_loss, 'b', label='Validation loss')\n", - " plt.title('Loss')\n", - " plt.legend()\n", - " plt.show()" - ], - "execution_count": 6, - "outputs": [] - }, - { - "cell_type": "code", - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "3Bh2-XEPxmZf", - "outputId": "3dffc4d8-d135-46fd-dcf4-94edb75db0cb" - }, - "source": [ - "train_df = pd.read_csv('train_set_1_10_CLASH2013_paper.tsv', sep='\\t', names=['noncodingRNA', 'gene', 'label'], header=0)\n", - "# set random state for reproducibility\n", - "RANDOM_STATE = 42\n", - "np.random.seed(RANDOM_STATE)\n", - "train_df = train_df.sample(frac=1, random_state=RANDOM_STATE)\n", - "print(train_df.head())\n", - "ohe_data = binding_encoding(train_df)\n", - "train_ohe, labels = ohe_data\n", - "print(\"Number of training samples: \", train_df.shape[0])" - ], - "execution_count": 8, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - " noncodingRNA \\\n", - "45236 ACTGCATTATGAGCACTTAA \n", - "168824 TATTGCACTTGTCCCGGCCT \n", - "2591 AAAAGCTGGGTTGAGAGGGC \n", - "76746 TCTCACACAGAAATCGCACC \n", - "63277 TGAGGTAGTAGTTTGTGCTG \n", - "\n", - " gene label \n", - "45236 GAGAAGAAATCTGGCTGGTTTGAGGGTTTCCTTTAGTTCACCCTCA... 0 \n", - "168824 GTAAATGTCTGTTTTTCATAATTGCTCTTTATATTGTGTGTTATCT... 0 \n", - "2591 GTACCCAGTAAAAACCAGAATGACCCATTGCCAGGACGCATCAAAG... 1 \n", - "76746 ACGTCGGCGCCATGCTCCAGGTACAGAGCCACATGTTGCTCCAGGC... 0 \n", - "63277 ACCAATGCCAGAGGAGCAACAGCGGCAACCTTTGGCACTGCATCCA... 0 \n", - "Number of training samples: 169312\n" - ] - } - ] - }, - { - "cell_type": "code", - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "Lcc6OuabyVsK", - "outputId": "bb4e3fae-f08d-4d37-acfe-ff41c66e83eb" - }, - "source": [ - "model = compile_model()\n", - "model.summary()" - ], - "execution_count": 9, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "Model: \"arch_00\"\n", - "_________________________________________________________________\n", - " Layer (type) Output Shape Param # \n", - "=================================================================\n", - " main_input (InputLayer) [(None, 50, 20, 1)] 0 \n", - " \n", - " conv_1 (Conv2D) (None, 50, 20, 32) 832 \n", - " \n", - " leaky_re_lu (LeakyReLU) (None, 50, 20, 32) 0 \n", - " \n", - " batch_normalization (BatchN (None, 50, 20, 32) 128 \n", - " ormalization) \n", - " \n", - " Max_1 (MaxPooling2D) (None, 25, 10, 32) 0 \n", - " \n", - " dropout (Dropout) (None, 25, 10, 32) 0 \n", - " \n", - " conv_2 (Conv2D) (None, 25, 10, 64) 51264 \n", - " \n", - " leaky_re_lu_1 (LeakyReLU) (None, 25, 10, 64) 0 \n", - " \n", - " batch_normalization_1 (Batc (None, 25, 10, 64) 256 \n", - " hNormalization) \n", - " \n", - " Max_2 (MaxPooling2D) (None, 13, 5, 64) 0 \n", - " \n", - " dropout_1 (Dropout) (None, 13, 5, 64) 0 \n", - " \n", - " conv_3 (Conv2D) (None, 13, 5, 96) 153696 \n", - " \n", - " leaky_re_lu_2 (LeakyReLU) (None, 13, 5, 96) 0 \n", - " \n", - " batch_normalization_2 (Batc (None, 13, 5, 96) 384 \n", - " hNormalization) \n", - " \n", - " Max_3 (MaxPooling2D) (None, 7, 3, 96) 0 \n", - " \n", - " dropout_2 (Dropout) (None, 7, 3, 96) 0 \n", - " \n", - " conv_4 (Conv2D) (None, 7, 3, 128) 307328 \n", - " \n", - " leaky_re_lu_3 (LeakyReLU) (None, 7, 3, 128) 0 \n", - " \n", - " batch_normalization_3 (Batc (None, 7, 3, 128) 512 \n", - " hNormalization) \n", - " \n", - " Max_4 (MaxPooling2D) (None, 4, 2, 128) 0 \n", - " \n", - " dropout_3 (Dropout) (None, 4, 2, 128) 0 \n", - " \n", - " conv_5 (Conv2D) (None, 4, 2, 160) 512160 \n", - " \n", - " leaky_re_lu_4 (LeakyReLU) (None, 4, 2, 160) 0 \n", - " \n", - " batch_normalization_4 (Batc (None, 4, 2, 160) 640 \n", - " hNormalization) \n", - " \n", - " Max_5 (MaxPooling2D) (None, 2, 1, 160) 0 \n", - " \n", - " dropout_4 (Dropout) (None, 2, 1, 160) 0 \n", - " \n", - " conv_6 (Conv2D) (None, 2, 1, 192) 768192 \n", - " \n", - " leaky_re_lu_5 (LeakyReLU) (None, 2, 1, 192) 0 \n", - " \n", - " batch_normalization_5 (Batc (None, 2, 1, 192) 768 \n", - " hNormalization) \n", - " \n", - " Max_6 (MaxPooling2D) (None, 1, 1, 192) 0 \n", - " \n", - " dropout_5 (Dropout) (None, 1, 1, 192) 0 \n", - " \n", - " 2d_matrix (Flatten) (None, 192) 0 \n", - " \n", - " dense (Dense) (None, 192) 37056 \n", - " \n", - " leaky_re_lu_6 (LeakyReLU) (None, 192) 0 \n", - " \n", - " batch_normalization_6 (Batc (None, 192) 768 \n", - " hNormalization) \n", - " \n", - " dropout_6 (Dropout) (None, 192) 0 \n", - " \n", - " dense_1 (Dense) (None, 160) 30880 \n", - " \n", - " leaky_re_lu_7 (LeakyReLU) (None, 160) 0 \n", - " \n", - " batch_normalization_7 (Batc (None, 160) 640 \n", - " hNormalization) \n", - " \n", - " dropout_7 (Dropout) (None, 160) 0 \n", - " \n", - " main_output (Dense) (None, 1) 161 \n", - " \n", - "=================================================================\n", - "Total params: 1,865,665\n", - "Trainable params: 1,863,617\n", - "Non-trainable params: 2,048\n", - "_________________________________________________________________\n" - ] - } - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "vIEbdJxqydNm" - }, - "source": [ - "model_history = model.fit(\n", - " train_ohe, labels,\n", - " validation_split=0.05, epochs=10,\n", - " batch_size=32,\n", - " class_weight={0 : 1, 1 : 10}\n", - " )" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "code", - "metadata": { - "id": "e4uUTu-k0Y2S" - }, - "source": [ - "plot_history(model_history)" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "code", - "metadata": { - "id": "WDSXwnZmymfK" - }, - "source": [ - "model.save(\"model.h5\")" - ], - "execution_count": null, - "outputs": [] - } - ] -} \ No newline at end of file diff --git a/code/machine_learning/train/README.md b/code/machine_learning/train/README.md index 2eb5163..ef527dd 100644 --- a/code/machine_learning/train/README.md +++ b/code/machine_learning/train/README.md @@ -1 +1,8 @@ -# Training the models \ No newline at end of file +# Training the models + +### CNN miRBind 2022 +This directory aggregates models based on the miRBind CNN architecture. It was presented in this miRBind paper (2022) https://doi.org/10.3390/genes13122323 + +[miRBind CNN architecture](CNN_miRBind_2022/miRBind_CNN_architecture.py) - containing definition of the CNN model architecture + +[miRBind CNN training with original parameters](CNN_miRBind_2022/miRBind_CNN_training_orig_parameters.py) - containing training of the CNN model with the original parameters described in the paper \ No newline at end of file diff --git a/code/machine_learning/utils.py b/code/machine_learning/utils.py new file mode 100644 index 0000000..9e63709 --- /dev/null +++ b/code/machine_learning/utils.py @@ -0,0 +1,51 @@ +import numpy as np +import tensorflow as tf +import random +import logging +from tensorflow.keras.optimizers import Adam +from tensorflow import keras as K + + +def set_seeds(seed): + """Set seeds for reproducibility.""" + random.seed(seed) + np.random.seed(seed) + tf.random.set_seed(seed) + + +def compile_model(model, lr): + opt = Adam( + learning_rate=lr, + beta_1=0.9, + beta_2=0.999, + epsilon=1e-07, + amsgrad=False, + name="Adam") + + model.compile( + optimizer=opt, + loss='binary_crossentropy', + metrics=['accuracy', K.metrics.AUC(curve='PR')] + ) + return model + + +def setup_logger(log_file, logger_name): + """Set up a logger to file and console""" + logger = logging.getLogger(logger_name) + logger.setLevel(logging.INFO) + + # Create handlers + file_handler = logging.FileHandler(log_file, 'w') + console_handler = logging.StreamHandler() + + # Create formatters and add to handlers + formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s') + file_handler.setFormatter(formatter) + console_handler.setFormatter(formatter) + + # Add handlers to logger + logger.addHandler(file_handler) + logger.addHandler(console_handler) + + return logger \ No newline at end of file