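The downstream classifier, defined below, ends in a plain linear layer with no output activation, so its outputs are unbounded logits that can take any value from negative infinity to positive infinity. I therefore switched the loss to `nn.BCEWithLogitsLoss()` (the functional form is `F.binary_cross_entropy_with_logits`), which applies the sigmoid internally and so accepts raw logits directly.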
'''
class target_classifier(nn.Module):
    """Downstream classification head that maps an embedding to raw logits.

    The embedding is flattened per sample, passed through a sigmoid-activated
    hidden layer, then through a final linear layer with no output activation.
    ``forward`` therefore returns unbounded logits, which should be paired with
    a loss that applies the sigmoid itself (e.g. ``nn.BCEWithLogitsLoss``).

    Args:
        configs: Config object; only ``configs.num_classes_target`` (the number
            of output units) is read here.
        in_features: Size of each flattened embedding. Defaults to ``2 * 128``,
            the value previously hard-coded.
        hidden_dim: Width of the sigmoid hidden layer. Defaults to 64, the
            value previously hard-coded.
    """

    def __init__(self, configs, *, in_features=2 * 128, hidden_dim=64):
        # This must be ``__init__`` (not ``init``). Otherwise nn.Module's
        # default constructor runs instead, the layers below are never
        # created, and the extra ``configs`` argument is rejected.
        super().__init__()
        # Attribute names are kept as-is so existing state_dict keys still load.
        self.logits = nn.Linear(in_features, hidden_dim)
        self.logits_simple = nn.Linear(hidden_dim, configs.num_classes_target)

    def forward(self, emb):
        """Return raw class logits of shape ``(batch, num_classes_target)``.

        Args:
            emb: Tensor whose first dimension is the batch. The remaining
                dimensions are flattened and must multiply to ``in_features``
                (2 * 128 by default).
        """
        emb_flat = emb.reshape(emb.shape[0], -1)
        hidden = torch.sigmoid(self.logits(emb_flat))
        # No activation on the output: callers get logits, not probabilities.
        return self.logits_simple(hidden)
'''