Question about logits

#4
by chenyuhe - opened

Hi, @lhallee
Thanks for your wonderful work!
I tried to pull logits out of ESMplusplus_large for my own sequence, but I found that ESMplusplus_large gives different logits than ESMC-600m.
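
Roughly, this is how I got the two sets of logits (a minimal sketch, not my exact script; the ESM++ loading follows the Synthyra model card, and the ESMC calls follow the esm SDK examples):

```python
import torch
from transformers import AutoModelForMaskedLM

sequence = "MSEQWENCE"  # placeholder for my actual sequence

# ESM++ (ESMplusplus_large) via transformers
model = AutoModelForMaskedLM.from_pretrained(
    "Synthyra/ESMplusplus_large", trust_remote_code=True
).to("cuda").eval()
tokenizer = model.tokenizer
inputs = tokenizer(sequence, return_tensors="pt").to("cuda")
with torch.no_grad():
    espp_output = model(**inputs)  # ESMplusplusOutput with .logits and .last_hidden_state
print(espp_output)

# ESMC-600m via the esm SDK
from esm.models.esmc import ESMC
from esm.sdk.api import ESMProtein, LogitsConfig

client = ESMC.from_pretrained("esmc_600m").to("cuda")
protein_tensor = client.encode(ESMProtein(sequence=sequence))
logits_output = client.logits(
    protein_tensor, LogitsConfig(sequence=True, return_embeddings=True)
)
print(logits_output.logits, logits_output.embeddings)
```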

Output from ESMplusplus_large:

```
ESMplusplusOutput(loss=None, logits=tensor([[[-13.5566, -17.7309, -16.1873, ..., -17.8392, -17.7335, -17.8338],
[-16.5202, -17.8747, -14.8896, ..., -17.9624, -17.7081, -17.7090],
[-16.2595, -17.8914, -15.3532, ..., -17.8488, -17.8043, -17.7690],
...,
[-17.0312, -18.1261, -15.6399, ..., -17.9292, -17.7847, -17.6585],
[-17.8950, -17.9702, -20.9157, ..., -17.8508, -17.8309, -17.6778],
[-18.3728, -17.8212, -16.3994, ..., -17.8193, -17.8356, -17.6724]]],
device='cuda:0', grad_fn=), last_hidden_state=tensor([[[-0.0021, -0.0044, 0.0021, ..., -0.0028, 0.0099, 0.0187],
[ 0.0449, -0.0159, -0.0385, ..., -0.0165, 0.0114, -0.0569],
[ 0.0229, -0.0218, 0.0014, ..., 0.0271, -0.0006, -0.0371],
...,
[-0.0048, 0.0762, -0.0214, ..., -0.0545, -0.0156, 0.0157],
[-0.0399, 0.0311, -0.0200, ..., -0.0199, 0.0026, -0.0204],
[-0.0087, 0.0201, -0.0009, ..., -0.0581, 0.0363, -0.0274]]],
device='cuda:0', grad_fn=), hidden_states=None, attentions=None)
```

Output from ESMC-600m:

```
print(logits_output.logits, logits_output.embeddings)
ForwardTrackData(sequence=tensor([[[-23.2500, -23.0000, -23.2500, ..., -23.1250, -23.2500, -23.2500],
[-24.6250, -24.6250, -24.6250, ..., -24.6250, -24.6250, -24.6250],
[-29.2500, -29.1250, -29.2500, ..., -29.2500, -29.2500, -29.2500],
...,
[-28.6250, -28.5000, -28.5000, ..., -28.5000, -28.5000, -28.5000],
[-23.5000, -23.3750, -23.5000, ..., -23.5000, -23.6250, -23.6250],
[-23.7500, -23.6250, -23.7500, ..., -23.7500, -23.7500, -23.7500]]],
device='cuda:0', dtype=torch.bfloat16), structure=None, secondary_structure=None, sasa=None, function=None) tensor([[[-0.0022, -0.0044, 0.0021, ..., -0.0028, 0.0099, 0.0189],
[ 0.0449, -0.0155, -0.0377, ..., -0.0159, 0.0112, -0.0568],
[ 0.0227, -0.0227, 0.0016, ..., 0.0269, -0.0008, -0.0371],
...,
[-0.0055, 0.0757, -0.0223, ..., -0.0548, -0.0154, 0.0158],
[-0.0393, 0.0309, -0.0196, ..., -0.0197, 0.0027, -0.0205],
[-0.0088, 0.0202, -0.0014, ..., -0.0566, 0.0364, -0.0266]]],
device='cuda:0')
```

Hi @chenyuhe,

Could you paste the output of torch.allclose so we can take a closer look? We expect a very small difference, since the different attention implementations can produce (slightly) different outputs.
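
Something along these lines would do it (a rough sketch; the variable names assume the two logits objects from your prints above, and the tolerance is just a starting point):

```python
import torch

# ESM++ logits vs. ESMC sequence logits, detached, on CPU, in float32
a = espp_output.logits.detach().float().cpu()
b = logits_output.logits.sequence.float().cpu()

print(a.shape, b.shape)                         # shapes should match if tokenization lines up
print(torch.allclose(a, b, atol=1e-2, rtol=0))  # loose absolute tolerance to start with
print((a - b).abs().max())                      # largest elementwise difference
```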
Best,
Logan
