Question about logits

#4
by chenyuhe - opened

Hi @lhallee,
Thanks for your wonderful work!
I tried to pull out logits from ESMplusplus_large for my own sequence, but I found that ESMplusplus_large produces different logits than ESMC-600m.
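Roughly, this is the kind of code that produced the two outputs below (a sketch; the loading and tokenization details are my reconstruction of the two APIs, not an exact copy of my script):

```python
import torch
from transformers import AutoModelForMaskedLM
from esm.models.esmc import ESMC
from esm.sdk.api import ESMProtein, LogitsConfig

sequence = "MPRTEIN"  # placeholder, not my real sequence

# ESM++ (ESMplusplus_large) through transformers
model = AutoModelForMaskedLM.from_pretrained(
    "Synthyra/ESMplusplus_large", trust_remote_code=True
).to("cuda")
tokenizer = model.tokenizer
inputs = tokenizer(sequence, return_tensors="pt").to("cuda")
output = model(**inputs)
print(output)  # -> ESMplusplusOutput(...) below

# ESMC-600m through the esm SDK
client = ESMC.from_pretrained("esmc_600m").to("cuda")
protein_tensor = client.encode(ESMProtein(sequence=sequence))
logits_output = client.logits(
    protein_tensor, LogitsConfig(sequence=True, return_embeddings=True)
)
print(logits_output.logits, logits_output.embeddings)  # -> ForwardTrackData(...) below
```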

ESMplusplusOutput(loss=None, logits=tensor([[[-13.5566, -17.7309, -16.1873, ..., -17.8392, -17.7335, -17.8338],
[-16.5202, -17.8747, -14.8896, ..., -17.9624, -17.7081, -17.7090],
[-16.2595, -17.8914, -15.3532, ..., -17.8488, -17.8043, -17.7690],
...,
[-17.0312, -18.1261, -15.6399, ..., -17.9292, -17.7847, -17.6585],
[-17.8950, -17.9702, -20.9157, ..., -17.8508, -17.8309, -17.6778],
[-18.3728, -17.8212, -16.3994, ..., -17.8193, -17.8356, -17.6724]]],
device='cuda:0', grad_fn=), last_hidden_state=tensor([[[-0.0021, -0.0044, 0.0021, ..., -0.0028, 0.0099, 0.0187],
[ 0.0449, -0.0159, -0.0385, ..., -0.0165, 0.0114, -0.0569],
[ 0.0229, -0.0218, 0.0014, ..., 0.0271, -0.0006, -0.0371],
...,
[-0.0048, 0.0762, -0.0214, ..., -0.0545, -0.0156, 0.0157],
[-0.0399, 0.0311, -0.0200, ..., -0.0199, 0.0026, -0.0204],
[-0.0087, 0.0201, -0.0009, ..., -0.0581, 0.0363, -0.0274]]],
device='cuda:0', grad_fn=), hidden_states=None, attentions=None)

print(logits_output.logits, logits_output.embeddings)
ForwardTrackData(sequence=tensor([[[-23.2500, -23.0000, -23.2500, ..., -23.1250, -23.2500, -23.2500],
[-24.6250, -24.6250, -24.6250, ..., -24.6250, -24.6250, -24.6250],
[-29.2500, -29.1250, -29.2500, ..., -29.2500, -29.2500, -29.2500],
...,
[-28.6250, -28.5000, -28.5000, ..., -28.5000, -28.5000, -28.5000],
[-23.5000, -23.3750, -23.5000, ..., -23.5000, -23.6250, -23.6250],
[-23.7500, -23.6250, -23.7500, ..., -23.7500, -23.7500, -23.7500]]],
device='cuda:0', dtype=torch.bfloat16), structure=None, secondary_structure=None, sasa=None, function=None) tensor([[[-0.0022, -0.0044, 0.0021, ..., -0.0028, 0.0099, 0.0189],
[ 0.0449, -0.0155, -0.0377, ..., -0.0159, 0.0112, -0.0568],
[ 0.0227, -0.0227, 0.0016, ..., 0.0269, -0.0008, -0.0371],
...,
[-0.0055, 0.0757, -0.0223, ..., -0.0548, -0.0154, 0.0158],
[-0.0393, 0.0309, -0.0196, ..., -0.0197, 0.0027, -0.0205],
[-0.0088, 0.0202, -0.0014, ..., -0.0566, 0.0364, -0.0266]]],
device='cuda:0')

Synthyra org
edited Mar 18

Hi @chenyuhe ,

Could you paste the output of a torch.allclose comparison so we can take a closer look? We expect a very small difference, since the different attention implementations can produce (slightly) different outputs.
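For example, something along these lines (variable names here are just illustrative):

```python
import torch

def compare_logits(esmpp_logits: torch.Tensor, esmc_logits: torch.Tensor, atol: float = 1e-2) -> None:
    # Cast both tensors to float32 on CPU so bfloat16 vs float32 outputs are comparable
    a = esmpp_logits.detach().float().cpu()
    b = esmc_logits.detach().float().cpu()
    print("allclose:", torch.allclose(a, b, atol=atol))
    print("max abs diff:", (a - b).abs().max().item())

# e.g. compare_logits(output.logits, logits_output.logits.sequence)
```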
Best,
Logan

Synthyra org

Hi @chenyuhe ,

There was recently a bug introduced by Hugging Face changing how their weight tying works. This probably caused the logits issue above. It should be addressed now; please reach out if you have any other questions.
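If you want to double-check on your side, one way to inspect how the input embedding and output (LM head) weights ended up after updating is something like the following (whether this remote-code model exposes the standard HF embedding hooks is an assumption on my part):

```python
import torch
from transformers import AutoModelForMaskedLM

model = AutoModelForMaskedLM.from_pretrained(
    "Synthyra/ESMplusplus_large", trust_remote_code=True
)
in_emb = model.get_input_embeddings()
out_emb = model.get_output_embeddings()
if in_emb is not None and out_emb is not None:
    # If tying is in effect, both point at the same parameter storage
    print("same storage:", in_emb.weight.data_ptr() == out_emb.weight.data_ptr())
    print("equal values:", torch.equal(in_emb.weight, out_emb.weight))
else:
    print("input/output embeddings not exposed via the standard HF hooks")
```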
Best,
Logan

lhallee changed discussion status to closed
