Splice-site prediction output mismatch #40

@naimavahab

Description

I used this command for splice-site prediction:

python train_splice_site_prediction.py --data_dir splicesite_data --test_data_dir splicesite_test --output_dir ./outputs --ss_type donor --benchmark Danio --dataset_id db_1 --batch_size 8 --num_workers 4 --pin_memory --max_epochs 2

But after training it gives the results below:
```
────────────────────────────────────────────
       Test metric          DataLoader 0
────────────────────────────────────────────
        test/acc                50.0
      test/f1_score             0.0
        test/loss        0.6931419444322586
      test/precision            0.0
       test/recall              0.0
     test/specificity          100.0
────────────────────────────────────────────
```

What is going wrong here? The paper reports a score of almost 99 for this task. (Note that the test loss of 0.6931 is ln 2, the chance-level binary cross-entropy, and a recall of 0 with a specificity of 100 means the model is predicting only the negative class.)
I also tried the RiNALMo model from the multimolecule release on Hugging Face and still get very low scores. Please help.

The code for multimolecule splice-site prediction is as follows (it includes hyperparameter tuning as well):

```python
import time

import numpy as np
import optuna
from datasets import load_dataset
from multimolecule import RnaTokenizer, RiNALMoForSequencePrediction
from sklearn.metrics import f1_score, matthews_corrcoef, precision_score, recall_score
from transformers import AutoConfig, EarlyStoppingCallback, EvalPrediction, Trainer, TrainingArguments

MODEL_NAME = "multimolecule/rinalmo-mega"
TASK_NAME = "splice_site_donor"
NUM_TRIALS = 10
TIMEOUT = 3600  # 1 hour timeout

print("Step 1: Loading and preprocessing data...")
full_dataset = load_dataset("genbio-ai/rna-downstream-tasks", TASK_NAME)
filtered_dataset = full_dataset  # optionally: full_dataset.remove_columns(["task"])
print(filtered_dataset)

tokenizer = RnaTokenizer.from_pretrained("multimolecule/rinalmo-mega")
if not tokenizer:
    raise ValueError("Tokenizer failed to load. Please check the model name and internet connection.")

def tokenize_function(examples):
    if 'sequences' not in examples or not isinstance(examples['sequences'], list):
        raise ValueError("Dataset does not contain a 'sequences' column or it is not a list.")
    return tokenizer(examples['sequences'], padding="max_length", truncation=True, max_length=128)
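# Spot-check (added): tokenize one short RNA string to verify special tokens,
# padding, and truncation behave as expected before mapping the full dataset.
print(tokenizer("ACGUACGU", padding="max_length", truncation=True, max_length=16))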

print("Step 2: Tokenizing the dataset...")

tokenized_dataset = filtered_dataset.map(tokenize_function, batched=True, num_proc=10, remove_columns=["sequences"])
tokenized_dataset.set_format("torch")
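# Sanity check (added): confirm max_length=128 actually covers the raw sequences.
# If they are longer, truncation cuts away splice-site context and the model can
# only learn the majority class, which would match the chance-level metrics above.
sample = filtered_dataset["train"].select(range(min(1000, len(filtered_dataset["train"]))))
raw_lengths = [len(s) for s in sample["sequences"]]
print(f"Raw sequence lengths (first {len(raw_lengths)}): min={min(raw_lengths)}, max={max(raw_lengths)}")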

train_dataset = tokenized_dataset["train"]  # .select(range(100)) to subsample for debugging
print('Train dataset size:', len(train_dataset))
valid_dataset = tokenized_dataset["validation"]  # .select(range(200)) likewise

danio_eval = tokenized_dataset["test_danio"]
fly_eval = tokenized_dataset["test_fly"]
thaliana_eval = tokenized_dataset["test_thaliana"]
worm_eval = tokenized_dataset["test_worm"]
print('Test dataset size danio_eval:', len(danio_eval))
print('Test dataset size fly_eval:', len(fly_eval))
print('Test dataset size thaliana_eval:', len(thaliana_eval))
print('Test dataset size worm_eval:', len(worm_eval))

num_labels = len(np.unique(train_dataset['labels']))
print(f"Number of labels detected: {num_labels}")

def compute_metrics(eval_pred):
    predictions, labels = eval_pred.predictions, eval_pred.label_ids
    if isinstance(predictions, tuple):
        predictions = predictions[0]
    predictions = np.argmax(predictions, axis=1)
    f1 = f1_score(labels, predictions, average='macro')
    mcc = matthews_corrcoef(labels, predictions)
    precision = precision_score(labels, predictions, average='macro', zero_division=0)
    recall = recall_score(labels, predictions, average='macro', zero_division=0)
    return {
        "f1": f1,
        "mcc": mcc,
        "precision": precision,
        "recall": recall,
    }
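# Quick self-test (added): run compute_metrics on tiny synthetic logits to
# confirm the wiring before spending GPU time. EvalPrediction is the same
# container the Trainer passes to compute_metrics.
_dummy = EvalPrediction(predictions=np.array([[2.0, -1.0], [-1.0, 2.0]]), label_ids=np.array([0, 1]))
print("compute_metrics self-test:", compute_metrics(_dummy))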

def objective(trial):
    learning_rate = trial.suggest_float("learning_rate", 5e-6, 5e-5, log=True)
    num_train_epochs = trial.suggest_int("num_train_epochs", 1, 20)

    config = AutoConfig.from_pretrained(MODEL_NAME, trust_remote_code=True)
    config.num_labels = num_labels
    config._attn_implementation = "eager"

    # Pass the config explicitly; otherwise num_labels set above has no effect.
    model = RiNALMoForSequencePrediction.from_pretrained(MODEL_NAME, config=config)

    early_stop_callback = EarlyStoppingCallback(early_stopping_patience=3)

    training_args = TrainingArguments(
        output_dir=f"./results/{TASK_NAME}/trial_{trial.number}",
        num_train_epochs=num_train_epochs,
        # per_device_train_batch_size=per_device_train_batch_size,
        # per_device_eval_batch_size=per_device_train_batch_size * 2,
        learning_rate=learning_rate,
        warmup_ratio=0.1,
        weight_decay=0.01,
        logging_dir="./logs",
        logging_steps=100,
        eval_strategy="epoch",
        save_strategy="epoch",
        load_best_model_at_end=True,
        metric_for_best_model="eval_loss",
        greater_is_better=False,
        report_to="none",
        torch_compile=False,
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=valid_dataset,
        compute_metrics=compute_metrics,
        callbacks=[early_stop_callback],
    )

    trainer.train()
    eval_results = trainer.evaluate()

    return eval_results["eval_f1"]

if __name__ == "__main__":
    study_name = f"{TASK_NAME}_optimization"
    storage_name = f"sqlite:///{study_name}.db"

    study = optuna.create_study(direction="maximize", study_name=study_name, storage=storage_name, load_if_exists=True)
    study.optimize(objective, n_trials=NUM_TRIALS, timeout=TIMEOUT)
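    # Optional (added): dump a per-trial summary so failed or pruned trials are
    # easy to spot; trials_dataframe() is a standard Optuna helper (needs pandas).
    print(study.trials_dataframe()[["number", "value", "params_learning_rate", "params_num_train_epochs"]])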

print("\n=======================================================")
print("Optimization finished.")
print("Best hyperparameters found: ", study.best_params)
print("Best F1 score: ", study.best_value)
print("=======================================================")

best_params = study.best_params       


    config = AutoConfig.from_pretrained(MODEL_NAME, trust_remote_code=True)
    config.num_labels = num_labels
    config._attn_implementation = "eager"

    start_time = time.time()
    # Pass the config explicitly here as well (see objective above).
    final_model = RiNALMoForSequencePrediction.from_pretrained(MODEL_NAME, config=config)

    early_stop_callback = EarlyStoppingCallback(early_stopping_patience=3)

    final_training_args = TrainingArguments(
        output_dir=f"./final_model/{TASK_NAME}",
        num_train_epochs=best_params["num_train_epochs"],
        per_device_train_batch_size=8,
        per_device_eval_batch_size=8,
        learning_rate=best_params["learning_rate"],
        warmup_ratio=0.1,
        weight_decay=0.01,
        logging_dir="./logs",
        logging_steps=100,
        eval_strategy="epoch",
        save_strategy="epoch",
        load_best_model_at_end=True,
        metric_for_best_model="eval_loss",
        greater_is_better=False,
        report_to="none",
        torch_compile=False,
    )
    final_trainer = Trainer(
        model=final_model,
        args=final_training_args,
        train_dataset=train_dataset,
        eval_dataset=valid_dataset,
        compute_metrics=compute_metrics,
        callbacks=[early_stop_callback],
    )

    final_trainer.train()
    print(f"Fine-tuning took {time.time() - start_time:.1f}s")

    final_evaluation_results = final_trainer.evaluate()
    metrics_danio = final_trainer.evaluate(eval_dataset=danio_eval)
    metrics_fly = final_trainer.evaluate(eval_dataset=fly_eval)
    metrics_thaliana = final_trainer.evaluate(eval_dataset=thaliana_eval)
    metrics_worm = final_trainer.evaluate(eval_dataset=worm_eval)
    print('Danio metrics', metrics_danio)
    print('Fly metrics', metrics_fly)
    print('Thaliana metrics', metrics_thaliana)
    print('Worm metrics', metrics_worm)
```
