Hi @zatchwu, I want to use your trained model to sample and score signal peptides.
The following is what I came up with after going through the provided notebooks, aiming for a more straightforward sequence -> model -> prediction workflow that is independent of the datasets you were using. It would be great to get some feedback on whether what I'm doing here is correct.
- Encoding amino acid data for the transformer
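```python
# Token vocabulary: ' ' has ID 0, '$' marks the start and '.' the end of a sequence
SPGEN_AA_TO_ID = {
    ' ': 0,
    '$': 1,
    '.': 2,
    'A': 3,
    'C': 4,
    'D': 5,
    'E': 6,
    'F': 7,
    'G': 8,
    'H': 9,
    'I': 10,
    'K': 11,
    'L': 12,
    'M': 13,
    'N': 14,
    'P': 15,
    'Q': 16,
    'R': 17,
    'S': 18,
    'T': 19,
    'U': 20,
    'V': 21,
    'W': 22,
    'X': 23,
    'Y': 24,
    'Z': 25,
}

# sp (signal peptide) and prot (mature protein) start out as amino-acid strings;
# each is wrapped in start ('$') and stop ('.') tokens and mapped to token IDs
sp = [SPGEN_AA_TO_ID['$']] + [SPGEN_AA_TO_ID[x] for x in sp] + [SPGEN_AA_TO_ID['.']]
prot = [SPGEN_AA_TO_ID['$']] + [SPGEN_AA_TO_ID[x] for x in prot] + [SPGEN_AA_TO_ID['.']]
```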
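When several encoded sequences are batched, they have to be padded to a common length and the position tensors built to match. The sketch below assumes SPGen follows the convention of the attention-is-all-you-need-pytorch code it appears to be based on: ID 0 (`' '`) is padding, and a position tensor holds 1-based position indices with 0 at padded slots, rather than a 0/1 mask. `pad_batch` and `PAD_ID` are hypothetical names for illustration, not part of SPGen.

```python
from typing import List, Tuple

PAD_ID = 0  # assumption: ' ' (ID 0) in SPGEN_AA_TO_ID is the padding token


def pad_batch(encoded: List[List[int]]) -> Tuple[List[List[int]], List[List[int]]]:
    """Pad token-ID lists to a common length and build matching position lists.

    Hypothetical helper: positions are 1-based indices for real tokens and 0 for
    padding, following the convention assumed above.
    """
    max_len = max(len(seq) for seq in encoded)
    padded = [seq + [PAD_ID] * (max_len - len(seq)) for seq in encoded]
    positions = [
        [i + 1 if tok != PAD_ID else 0 for i, tok in enumerate(seq)]
        for seq in padded
    ]
    return padded, positions


# Two already-encoded sequences: '$' M K '.' and '$' M '.'
seqs, pos = pad_batch([[1, 13, 11, 2], [1, 13, 2]])
print(seqs)  # [[1, 13, 11, 2], [1, 13, 2, 0]]
print(pos)   # [[1, 2, 3, 4], [1, 2, 3, 0]]
```

The resulting lists can be turned into tensors with `torch.LongTensor(seqs)` and `torch.LongTensor(pos)` before being passed to the model as `(seq, positions)` tuples.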
- Loading the model
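```python
import torch
from transformer import Models  # assumed import path for SPGen's Transformer module; adjust to your checkout


def load_spgen_model(device='cpu'):
    # the weights were extracted from the .chkpt file with the same name
    state_dict = torch.load(
        '../../SPGen/remote_generation/signal_peptide/outputs/SIM99_550_12500_64_6_5_0.1_64_100_0.0001_-0.03_99_weightsonly.pt',
        map_location=device)
    model = Models.Transformer(
        27,   # source vocabulary size
        27,   # target vocabulary size
        107,  # maximum sequence length
        proj_share_weight=True,
        embs_share_weight=True,
        d_k=64,
        d_v=64,
        d_model=550,
        d_word_vec=550,
        d_inner_hid=1100,
        n_layers=6,
        n_head=5,
        dropout=0.1)
    model.load_state_dict(state_dict)
    model.to(device)
    model.eval()  # disable dropout for inference
    return model
```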
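Since the goal is also to sample signal peptides, here is a minimal sketch of plain autoregressive sampling: start from `'$'`, repeatedly draw the next token from the logits for the last position, and stop at `'.'` or a length limit. `next_token_logits` is a hypothetical callable that would wrap a forward pass of the loaded model on the protein and the prefix decoded so far; how to obtain last-position logits from SPGen's `Transformer` is not shown and is an assumption. `max_len=107` assumes the third constructor argument is the maximum sequence length. This is simple temperature sampling, not necessarily the decoding strategy SPGen's own generation code uses.

```python
import math
import random
from typing import Callable, List, Optional

# Inverse of SPGEN_AA_TO_ID: character at index i is the token with ID i
ID_TO_AA = " $.ACDEFGHIKLMNPQRSTUVWXYZ"
START_ID, STOP_ID = 1, 2


def sample_sequence(
    next_token_logits: Callable[[List[int]], List[float]],
    max_len: int = 107,
    temperature: float = 1.0,
    rng: Optional[random.Random] = None,
) -> str:
    """Sample token IDs one at a time and decode them to an amino-acid string.

    next_token_logits(prefix) is a hypothetical stand-in for a model call that
    returns one logit per vocabulary entry for the token following `prefix`.
    temperature=0 means greedy (argmax) decoding.
    """
    rng = rng or random.Random()
    prefix = [START_ID]
    while len(prefix) < max_len:
        logits = next_token_logits(prefix)
        if temperature == 0:
            tok = max(range(len(logits)), key=logits.__getitem__)
        else:
            # softmax with temperature, shifted by the max logit for numerical stability
            m = max(logits)
            weights = [math.exp((x - m) / temperature) for x in logits]
            tok = rng.choices(range(len(logits)), weights=weights, k=1)[0]
        if tok == STOP_ID:
            break
        prefix.append(tok)
    # drop the start token; IDs outside the 26-entry dictionary
    # (the model was built with 27 outputs) decode to '?'
    return "".join(ID_TO_AA[t] if t < len(ID_TO_AA) else "?" for t in prefix[1:])
```

With `temperature=0` the loop is greedy; higher temperatures give more diverse samples.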
- Making predictions (logits) and scoring the perplexity. I encode the data as shown in the first step and build `prot_positions` and `sp_positions` masks that are `0` at true positions and `1` at masked positions.
```python
import numpy as np
import torch
from tqdm import tqdm

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


def get_perplexity_batch(transformer, src_seq, src_positions, tgt_seq, tgt_positions):
    '''Adapted from Translator()._epoch(). Returns one perplexity per sequence.'''
    ppls = []
    # ignore_index=0 skips padded target positions (assumes ' ' / ID 0 is padding)
    loss_fn = torch.nn.CrossEntropyLoss(ignore_index=0)
    pred = transformer((src_seq, src_positions), (tgt_seq, tgt_positions))
    # reshape to (batch, steps, vocab); a no-op for 3-D logits, unflattens 2-D ones
    pred = pred.reshape(len(src_seq), -1, pred.size(-1))
    # process each seq in batch: logits at step t are scored against target token t + 1
    for idx in range(len(src_seq)):
        loss = loss_fn(pred[idx], tgt_seq[idx, 1:])
        ppls.append(torch.exp(loss).item())
    return ppls


def predict_spgen(model, loader):
    ppl = []
    with torch.no_grad():
        for batch in tqdm(loader):
            proteins, prot_positions, sps, sp_positions = (t.to(device) for t in batch)
            # the single forward pass per batch happens inside get_perplexity_batch
            ppls = get_perplexity_batch(model, proteins, prot_positions, sps, sp_positions)
            ppl.extend(ppls)
    return np.array(ppl)
```
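One way to check whether the scoring is wired up correctly is to compare against known reference points. Perplexity is the exponential of the mean per-token negative log-likelihood, so a model that predicts a uniform distribution over the 27 output classes scores exactly 27, and a trained model should score well below that on real signal peptides. The pure-Python sketch below recomputes what `get_perplexity_batch` does for a single sequence, skipping padded target positions on the assumption that ID 0 is padding; `sequence_perplexity` is a hypothetical cross-checking helper, not part of SPGen.

```python
import math
from typing import Sequence


def sequence_perplexity(
    logits: Sequence[Sequence[float]], targets: Sequence[int], pad_id: int = 0
) -> float:
    """exp(mean cross-entropy) over non-padding target positions.

    logits[t] holds the unnormalised scores for predicting targets[t], i.e. the
    decoder output already aligned with tgt_seq[1:].
    """
    nll, count = 0.0, 0
    for row, tgt in zip(logits, targets):
        if tgt == pad_id:
            continue  # padded positions do not contribute to the score
        m = max(row)
        log_z = m + math.log(sum(math.exp(x - m) for x in row))  # log-sum-exp
        nll += log_z - row[tgt]
        count += 1
    return math.exp(nll / count)


uniform = [[0.0] * 27 for _ in range(5)]
print(sequence_perplexity(uniform, [13, 11, 12, 2, 0]))  # approximately 27
```

To cross-check, `pred[idx].tolist()` and `tgt_seq[idx, 1:].tolist()` from the scoring code can be passed in; the result should agree with `get_perplexity_batch` for that sequence as long as both skip (or both include) the same padded positions.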
My code runs, but it is hard to tell whether everything is in place or whether there is an error somewhere. Feedback would be great, and I'm also open to any other way of running the model on new data.
Thanks!