Upload modeling_esm_plusplus.py with huggingface_hub
modeling_esm_plusplus.py  CHANGED  (+97 -4)
@@ -931,6 +931,7 @@ class ESMplusplusForSequenceClassification(ESMplusplusForMaskedLM, EmbeddingMixin):
         self.mse = nn.MSELoss()
         self.ce = nn.CrossEntropyLoss()
         self.bce = nn.BCEWithLogitsLoss()
+        self.pooler = Pooler(['cls','mean'])
         self.init_weights()
 
     def _embed(self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
@@ -969,10 +970,7 @@ class ESMplusplusForSequenceClassification(ESMplusplusForMaskedLM, EmbeddingMixin):
             output_hidden_states=output_hidden_states
         )
         x = output.last_hidden_state
-        cls_features = x[:, 0, :]
-        mean_features = self.mean_pooling(x, attention_mask)
-        # we include mean pooling features to help with early convergence, the cost of this is basically zero
-        features = torch.cat([cls_features, mean_features], dim=-1)
+        features = self.pooler(x, attention_mask)
         logits = self.classifier(features)
         loss = None
         if labels is not None:
@@ -994,6 +992,7 @@ class ESMplusplusForSequenceClassification(ESMplusplusForMaskedLM, EmbeddingMixin):
                 loss = self.ce(logits.view(-1, self.num_labels), labels.view(-1))
             elif self.config.problem_type == "multi_label_classification":
                 loss = self.bce(logits, labels)
+
         return ESMplusplusOutput(
             loss=loss,
             logits=logits,
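The Pooler(['cls','mean']) helper that these hunks wire in is defined elsewhere in the file and is not shown in this diff. As a rough sketch only, assuming it concatenates CLS pooling with attention-masked mean pooling (mirroring the hand-written code it replaces), it could look like this; the class body below is an illustration, not the file's actual implementation:

import torch
import torch.nn as nn
from typing import List, Optional

class Pooler(nn.Module):
    # Hypothetical sketch: pool hidden states with each requested strategy
    # and concatenate the results along the feature dimension.
    def __init__(self, pooling_types: List[str]):
        super().__init__()
        self.pooling_types = pooling_types

    def forward(self, x: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        # x: (batch, seq_len, hidden); attention_mask: (batch, seq_len)
        pooled = []
        for ptype in self.pooling_types:
            if ptype == 'cls':
                pooled.append(x[:, 0, :])  # embedding of the first (CLS/BOS) token
            elif ptype == 'mean':
                if attention_mask is None:
                    pooled.append(x.mean(dim=1))
                else:
                    mask = attention_mask.unsqueeze(-1).to(x.dtype)
                    pooled.append((x * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1e-8))
        return torch.cat(pooled, dim=-1)  # (batch, hidden * number of pooling types)

With ['cls','mean'] this reproduces the concatenation the old forward pass built by hand, so the classifier input width stays at twice the hidden size.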
@@ -1197,3 +1196,97 @@ class EsmSequenceTokenizer(PreTrainedTokenizerFast):
     @property
     def special_token_ids(self):
         return self.all_special_ids
+
+
+if __name__ == "__main__":
+    # Pick the GPU if one is available, otherwise fall back to CPU
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    print(f"Using device: {device}")
+
+    # Test tokenizer
+    tokenizer = EsmSequenceTokenizer()
+    sample_sequence = "MQIFVKTLTGKTITLEVEPSDTIENVKAKIQDKEGIPPDQQRLIFAGKQLEDGRTLSDYNIQKESTLHLVLRLRGG"
+    encoding = tokenizer(sample_sequence, return_tensors="pt")
+    print(f"Input sequence length: {len(sample_sequence)}")
+    print(f"Tokenized sequence: {encoding['input_ids'].shape}")
+
+    # Prepare inputs
+    input_ids = encoding['input_ids'].to(device)
+    attention_mask = encoding['attention_mask'].to(device)
+
+    # Test base model with a smaller config for quick testing
+    print("\n=== Testing ESMplusplus Base Model ===")
+    base_config = ESMplusplusConfig(
+        hidden_size=384,
+        num_attention_heads=6,
+        num_hidden_layers=4
+    )
+    base_model = ESMplusplusModel(base_config).to(device)
+
+    with torch.no_grad():
+        outputs = base_model(input_ids=input_ids, attention_mask=attention_mask)
+
+    print(f"Last hidden state shape: {outputs.last_hidden_state.shape}")
+
+    # Test embedding functionality
+    print("\nTesting embedding functionality:")
+    with torch.no_grad():
+        embeddings = base_model._embed(input_ids, attention_mask)
+        print(f"Embedding shape: {embeddings.shape}")
+
+    # Test masked language modeling
+    print("\n=== Testing ESMplusplus For Masked LM ===")
+    mlm_model = ESMplusplusForMaskedLM(base_config).to(device)
+
+    with torch.no_grad():
+        outputs = mlm_model(input_ids=input_ids, attention_mask=attention_mask)
+
+    print(f"Last hidden state shape: {outputs.last_hidden_state.shape}")
+    print(f"Logits shape: {outputs.logits.shape}")
+
+    # Test sequence classification model
+    print("\n=== Testing Sequence Classification Model ===")
+    classification_model = ESMplusplusForSequenceClassification(base_config).to(device)
+
+    with torch.no_grad():
+        outputs = classification_model(input_ids=input_ids, attention_mask=attention_mask)
+
+    print(f"Last hidden state shape: {outputs.last_hidden_state.shape}")
+    print(f"Logits shape: {outputs.logits.shape}")
+
+    # Test token classification model
+    print("\n=== Testing Token Classification Model ===")
+    token_model = ESMplusplusForTokenClassification(base_config).to(device)
+
+    with torch.no_grad():
+        outputs = token_model(input_ids=input_ids, attention_mask=attention_mask)
+
+    print(f"Last hidden state shape: {outputs.last_hidden_state.shape}")
+    print(f"Logits shape: {outputs.logits.shape}")
+
+    # Test embedding dataset functionality with a mini dataset
+    print("\n=== Testing Embed Dataset Functionality ===")
+    mini_dataset = [sample_sequence, sample_sequence[:50], sample_sequence[:30]]
+    print(f"Creating embeddings for {len(mini_dataset)} sequences")
+
+    # Only run this if the save path doesn't exist, to avoid overwriting
+    if not os.path.exists("test_embeddings.pth"):
+        embeddings = mlm_model.embed_dataset(
+            sequences=mini_dataset,
+            tokenizer=tokenizer,
+            batch_size=2,
+            max_len=100,
+            full_embeddings=False,
+            pooling_types=['mean'],
+            save_path="test_embeddings.pth"
+        )
+        if embeddings:
+            print(f"Embedding dictionary size: {len(embeddings)}")
+            for seq, emb in embeddings.items():
+                print(f"Sequence length: {len(seq)}, Embedding shape: {emb.shape}")
+                break
+    else:
+        print("Skipping embedding test as test_embeddings.pth already exists")
+
+    print("\nAll tests completed successfully!")
+
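Because this file is uploaded as custom modeling code on the Hub, the usual way to exercise it outside the new __main__ smoke test is through the transformers Auto classes with trust_remote_code=True. A minimal sketch, with a hypothetical repo id standing in for the actual ESM++ checkpoint:

from transformers import AutoModelForSequenceClassification, AutoTokenizer

repo_id = "your-org/ESMplusplus_small"  # hypothetical id; substitute the real checkpoint name
tokenizer = AutoTokenizer.from_pretrained(repo_id, trust_remote_code=True)
model = AutoModelForSequenceClassification.from_pretrained(repo_id, trust_remote_code=True, num_labels=2)

seq = "MQIFVKTLTGKTITLEVEPSDTIENVKAKIQDKEGIPPDQQRLIFAGKQLEDGRTLSDYNIQKESTLHLVLRLRGG"
enc = tokenizer(seq, return_tensors="pt")
logits = model(**enc).logits  # shape: (1, num_labels)

The bundled smoke test can also be run directly with python modeling_esm_plusplus.py, which builds the small test config and walks through the tokenizer, base, masked-LM, sequence-classification, token-classification, and embed_dataset paths shown above.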