-
Notifications
You must be signed in to change notification settings - Fork 32
Description
I have used this command for splice site prediction
python train_splice_site_prediction.py --data_dir splicesite_data --test_data_dir splicesite_test --output_dir ./outputs --ss_type donor --benchmark Danio --dataset_id db_1 --batch_size 8 --num_workers 4 --pin_memory --max_epochs 2
But after training it gives the results below
──────────────────────────────────────────────────────────────────────────────────────────────────────────────────────── Test metric DataLoader 0 ──────────────────────────────────────────────────────────────────────────────────────────────────────────────────────── test/acc 50.0 test/f1_score 0.0 test/loss 0.6931419444322586 test/precision 0.0 test/recall 0.0 test/specificity 100.0 ─────────────────────────────────────────────────
What goes wrong here? In the paper, a score of almost 99 is reported for this task.
I also tried multimolecule release rinalmo HF model, still getting very low scores. Please help here.
Code for multimolecule splice-site prediction is as follows (it includes hyperparameter tuning as well):
`
# Experiment constants: HF dataset config name, Optuna trial budget, wall-clock cap.
TASK_NAME = "splice_site_donor"
NUM_TRIALS = 10
TIMEOUT = 3600 # 1 hour timeout

print("Step 1: Loading and preprocessing data...")
# Download all splits of the splice-site donor task from the Hugging Face hub.
full_dataset = load_dataset("genbio-ai/rna-downstream-tasks", TASK_NAME)
filtered_dataset = full_dataset #filtered_dataset.remove_columns(["task"])
print(filtered_dataset)

tokenizer = RnaTokenizer.from_pretrained("multimolecule/rinalmo-mega")
# NOTE(review): from_pretrained raises on failure rather than returning a falsy
# value, so this guard is effectively dead code — confirm and consider removing.
if not tokenizer:
    raise ValueError("Tokenizer failed to load. Please check the model name and internet connection.")
def tokenize_function(examples):
    """Tokenize a batch of RNA sequences to fixed-length (128-token) encodings.

    Args:
        examples: A batch dict from ``datasets.Dataset.map(batched=True)``;
            must contain a ``'sequences'`` key holding a list of strings.

    Returns:
        The tokenizer's encoding dict (input_ids, attention_mask, ...).

    Raises:
        ValueError: If the 'sequences' column is missing or is not a list.
    """
    # BUG FIX: the message previously said 'sequence' while the code checks
    # for 'sequences' — keep the message consistent with the actual check.
    if 'sequences' not in examples or not isinstance(examples['sequences'], list):
        raise ValueError("Dataset does not contain a 'sequences' column or it's not a list.")
    return tokenizer(examples['sequences'], padding="max_length", truncation=True, max_length=128)
print("Step 2: Tokenizing the dataset...")
# Tokenize every split in parallel and drop the raw text column afterwards.
tokenized_dataset = filtered_dataset.map(tokenize_function, batched=True, num_proc=10, remove_columns=["sequences"])
tokenized_dataset.set_format("torch")

train_dataset = tokenized_dataset["train"]
print('Train dataset size:', len(train_dataset))
valid_dataset = tokenized_dataset["validation"]
# Per-species held-out test splits.
danio_eval = tokenized_dataset["test_danio"]
fly_eval = tokenized_dataset["test_fly"]
thaliana_eval = tokenized_dataset["test_thaliana"]
worm_eval = tokenized_dataset["test_worm"]
# BUG FIX: these are test splits, not validation — log labels corrected.
print('Test dataset size danio_eval:', len(danio_eval))
print('Test dataset size fly_eval:', len(fly_eval))
print('Test dataset size thaliana_eval:', len(thaliana_eval))
print('Test dataset size worm_eval:', len(worm_eval))

# Infer the label count from the training labels (binary -> 2).
num_labels = len(np.unique(train_dataset['labels']))
print(f"Number of labels detected: {num_labels}")
def compute_metrics(eval_pred):
    """Compute macro-averaged classification metrics for a Trainer eval pass.

    Args:
        eval_pred: EvalPrediction with ``.predictions`` (logits, possibly the
            first element of a tuple) and ``.label_ids``.

    Returns:
        dict with macro F1, Matthews correlation, macro precision and recall.
    """
    logits, gold = eval_pred.predictions, eval_pred.label_ids
    # Some models return (logits, hidden_states, ...) — keep the logits only.
    if isinstance(logits, tuple):
        logits = logits[0]
    preds = np.argmax(logits, axis=1)
    return {
        "f1": f1_score(gold, preds, average='macro'),
        "mcc": matthews_corrcoef(gold, preds),
        "precision": precision_score(gold, preds, average='macro', zero_division=0),
        "recall": recall_score(gold, preds, average='macro', zero_division=0),
    }
def objective(trial):
    """Optuna objective: fine-tune RiNALMo once on the splice-site task.

    Samples a learning rate and epoch count, trains with early stopping on
    validation loss, and returns the macro F1 on the validation set.

    Args:
        trial: ``optuna.Trial`` used to sample hyperparameters.

    Returns:
        float: Validation macro F1 of the best (lowest val-loss) checkpoint.
    """
    learning_rate = trial.suggest_float("learning_rate", 5e-6, 5e-5, log=True)
    num_train_epochs = trial.suggest_int("num_train_epochs", 1, 20)

    config = AutoConfig.from_pretrained(MODEL_NAME, trust_remote_code=True)
    config.num_labels = num_labels
    config._attn_implementation = "eager"
    # BUG FIX: the config was built but never handed to the model, so
    # num_labels / attention settings were silently ignored. Pass it in.
    # NOTE(review): assumes MODEL_NAME refers to the same checkpoint as
    # "multimolecule/rinalmo-mega" — confirm.
    model = RiNALMoForSequencePrediction.from_pretrained(
        "multimolecule/rinalmo-mega", config=config
    )

    early_stop_callback = EarlyStoppingCallback(early_stopping_patience=3)
    training_args = TrainingArguments(
        output_dir=f"./results/{TASK_NAME}/trial_{trial.number}",
        num_train_epochs=num_train_epochs,
        learning_rate=learning_rate,
        warmup_ratio=0.1,
        weight_decay=0.01,
        logging_dir="./logs",
        logging_steps=100,
        eval_strategy="epoch",
        save_strategy="epoch",
        # Restore the checkpoint with the lowest validation loss at the end.
        load_best_model_at_end=True,
        metric_for_best_model="eval_loss",
        greater_is_better=False,
        report_to="none",
        torch_compile=False,
    )
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=valid_dataset,
        compute_metrics=compute_metrics,
        callbacks=[early_stop_callback],
    )
    trainer.train()
    eval_results = trainer.evaluate()
    # Optuna maximizes this value (study direction is "maximize").
    return eval_results["eval_f1"]
# BUG FIX: `if name == "main":` raises NameError and never runs (the
# underscores were lost when the code was pasted into markdown).
if __name__ == "__main__":
    study_name = f"{TASK_NAME}_optimization"
    storage_name = f"sqlite:///{study_name}.db"
    # Persist trials to SQLite so the study can be resumed across runs.
    study = optuna.create_study(direction="maximize", study_name=study_name,
                                storage=storage_name, load_if_exists=True)
    study.optimize(objective, n_trials=NUM_TRIALS, timeout=TIMEOUT)

    print("\n=======================================================")
    print("Optimization finished.")
    print("Best hyperparameters found: ", study.best_params)
    print("Best F1 score: ", study.best_value)
    print("=======================================================")

    # Retrain once with the best hyperparameters for the final evaluation.
    best_params = study.best_params
    config = AutoConfig.from_pretrained(MODEL_NAME, trust_remote_code=True)
    config.num_labels = num_labels
    config._attn_implementation = "eager"

    import time
    start_time = time.time()

    # BUG FIX: pass the configured `config` so num_labels / attention
    # settings actually take effect (it was previously discarded).
    final_model = RiNALMoForSequencePrediction.from_pretrained(
        "multimolecule/rinalmo-mega", config=config
    )
    early_stop_callback = EarlyStoppingCallback(early_stopping_patience=3)
    final_training_args = TrainingArguments(
        output_dir=f"./final_model/{TASK_NAME}",
        num_train_epochs=best_params["num_train_epochs"],
        per_device_train_batch_size=8,
        per_device_eval_batch_size=8,
        learning_rate=best_params["learning_rate"],
        warmup_ratio=0.1,
        weight_decay=0.01,
        logging_dir="./logs",
        logging_steps=100,
        eval_strategy="epoch",
        save_strategy="epoch",
        load_best_model_at_end=True,
        metric_for_best_model="eval_loss",
        report_to="none",
        greater_is_better=False,
        torch_compile=False,
    )
    final_trainer = Trainer(
        model=final_model,
        args=final_training_args,
        train_dataset=train_dataset,
        eval_dataset=valid_dataset,
        compute_metrics=compute_metrics,
        callbacks=[early_stop_callback],
    )
    final_trainer.train()

    # Evaluate on validation plus each per-species test split.
    final_evaluation_results = final_trainer.evaluate()
    metrics_danio = final_trainer.evaluate(eval_dataset=danio_eval)
    metrics_fly = final_trainer.evaluate(eval_dataset=fly_eval)
    metrics_thaliana = final_trainer.evaluate(eval_dataset=thaliana_eval)
    metrics_worm = final_trainer.evaluate(eval_dataset=worm_eval)
    print('Danio metrics', metrics_danio)
    print('Fly metrics', metrics_fly)
    print('Thaliana metrics', metrics_thaliana)
    print('Worm metrics', metrics_worm)