lhallee committed
Commit 9c37cb4 · verified · Parent(s): 93af4ec

Upload modeling_esm_plusplus.py with huggingface_hub

Files changed (1)
  1. modeling_esm_plusplus.py +97 -4
modeling_esm_plusplus.py CHANGED
@@ -931,6 +931,7 @@ class ESMplusplusForSequenceClassification(ESMplusplusForMaskedLM, EmbeddingMixin):
         self.mse = nn.MSELoss()
         self.ce = nn.CrossEntropyLoss()
         self.bce = nn.BCEWithLogitsLoss()
+        self.pooler = Pooler(['cls','mean'])
         self.init_weights()
 
     def _embed(self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
@@ -969,10 +970,7 @@ class ESMplusplusForSequenceClassification(ESMplusplusForMaskedLM, EmbeddingMixin):
             output_hidden_states=output_hidden_states
         )
         x = output.last_hidden_state
-        cls_features = x[:, 0, :]
-        mean_features = self.mean_pooling(x, attention_mask)
-        # we include mean pooling features to help with early convergence, the cost of this is basically zero
-        features = torch.cat([cls_features, mean_features], dim=-1)
+        features = self.pooler(x, attention_mask)
         logits = self.classifier(features)
         loss = None
         if labels is not None:
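The `Pooler` class used above is defined elsewhere in modeling_esm_plusplus.py and is not part of this diff. As a rough guide to the intent of the change, here is a minimal sketch of what a pooler configured with `['cls','mean']` plausibly does, mirroring the removed inline code (CLS-token embedding concatenated with an attention-masked mean over tokens). The class name matches the diff, but the signature and behavior below are assumptions, not the actual implementation:

```python
import torch
import torch.nn as nn
from typing import List, Optional


class Pooler(nn.Module):
    """Sketch only: one pooled vector per requested pooling type, concatenated."""

    def __init__(self, pooling_types: List[str]):
        super().__init__()
        self.pooling_types = pooling_types

    def forward(self, x: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        # x: (batch, seq_len, hidden); attention_mask: (batch, seq_len) with 1 for real tokens
        pooled = []
        for pooling_type in self.pooling_types:
            if pooling_type == 'cls':
                pooled.append(x[:, 0, :])  # embedding of the first (CLS/BOS) token
            elif pooling_type == 'mean':
                if attention_mask is None:
                    pooled.append(x.mean(dim=1))
                else:
                    mask = attention_mask.unsqueeze(-1).to(x.dtype)   # (batch, seq_len, 1)
                    summed = (x * mask).sum(dim=1)                    # masked sum over tokens
                    counts = mask.sum(dim=1).clamp(min=1e-6)          # number of real tokens
                    pooled.append(summed / counts)                    # masked mean
        return torch.cat(pooled, dim=-1)  # (batch, hidden * len(pooling_types))
```

Under this reading the classifier input width is still twice the hidden size, exactly as with the removed `torch.cat([cls_features, mean_features], dim=-1)`, so the head dimensions should be unchanged.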
@@ -994,6 +992,7 @@ class ESMplusplusForSequenceClassification(ESMplusplusForMaskedLM, EmbeddingMixin):
                 loss = self.ce(logits.view(-1, self.num_labels), labels.view(-1))
             elif self.config.problem_type == "multi_label_classification":
                 loss = self.bce(logits, labels)
+
         return ESMplusplusOutput(
             loss=loss,
             logits=logits,
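The elided lines between the previous hunk and this one (old lines 979–993) contain the start of the loss selection, including how `config.problem_type` is resolved when it is unset. They are not shown in the diff; the standalone sketch below follows the usual Hugging Face convention for that resolution and is offered only as context (an assumption, not a quote of the hidden code):

```python
import torch


def resolve_problem_type(num_labels: int, labels: torch.Tensor) -> str:
    """Sketch of the standard transformers heuristic for an unset config.problem_type."""
    if num_labels == 1:
        return "regression"                    # -> MSELoss on a continuous target
    if labels.dtype in (torch.long, torch.int):
        return "single_label_classification"   # -> CrossEntropyLoss on class indices
    return "multi_label_classification"        # -> BCEWithLogitsLoss on multi-hot float labels


# Example: integer class labels with num_labels > 1 select the cross-entropy branch
print(resolve_problem_type(3, torch.tensor([0, 2, 1])))  # single_label_classification
```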
@@ -1197,3 +1196,97 @@ class EsmSequenceTokenizer(PreTrainedTokenizerFast):
     @property
     def special_token_ids(self):
         return self.all_special_ids
+
+
+if __name__ == "__main__":
+    # Select a device for testing (CUDA if available, otherwise CPU)
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    print(f"Using device: {device}")
+
+    # Test tokenizer
+    tokenizer = EsmSequenceTokenizer()
+    sample_sequence = "MQIFVKTLTGKTITLEVEPSDTIENVKAKIQDKEGIPPDQQRLIFAGKQLEDGRTLSDYNIQKESTLHLVLRLRGG"
+    encoding = tokenizer(sample_sequence, return_tensors="pt")
+    print(f"Input sequence length: {len(sample_sequence)}")
+    print(f"Tokenized sequence: {encoding['input_ids'].shape}")
+
+    # Prepare inputs
+    input_ids = encoding['input_ids'].to(device)
+    attention_mask = encoding['attention_mask'].to(device)
+
+    # Test base model with a smaller config for quick testing
+    print("\n=== Testing ESMplusplus Base Model ===")
+    base_config = ESMplusplusConfig(
+        hidden_size=384,
+        num_attention_heads=6,
+        num_hidden_layers=4
+    )
+    base_model = ESMplusplusModel(base_config).to(device)
+
+    with torch.no_grad():
+        outputs = base_model(input_ids=input_ids, attention_mask=attention_mask)
+
+    print(f"Last hidden state shape: {outputs.last_hidden_state.shape}")
+
+    # Test embedding functionality
+    print("\nTesting embedding functionality:")
+    with torch.no_grad():
+        embeddings = base_model._embed(input_ids, attention_mask)
+    print(f"Embedding shape: {embeddings.shape}")
+
+    # Test masked language modeling
+    print("\n=== Testing ESMplusplus For Masked LM ===")
+    mlm_model = ESMplusplusForMaskedLM(base_config).to(device)
+
+    with torch.no_grad():
+        outputs = mlm_model(input_ids=input_ids, attention_mask=attention_mask)
+
+    print(f"Last hidden state shape: {outputs.last_hidden_state.shape}")
+    print(f"Logits shape: {outputs.logits.shape}")
+
+    # Test sequence classification model
+    print("\n=== Testing Sequence Classification Model ===")
+    classification_model = ESMplusplusForSequenceClassification(base_config).to(device)
+
+    with torch.no_grad():
+        outputs = classification_model(input_ids=input_ids, attention_mask=attention_mask)
+
+    print(f"Last hidden state shape: {outputs.last_hidden_state.shape}")
+    print(f"Logits shape: {outputs.logits.shape}")
+
+    # Test token classification model
+    print("\n=== Testing Token Classification Model ===")
+    token_model = ESMplusplusForTokenClassification(base_config).to(device)
+
+    with torch.no_grad():
+        outputs = token_model(input_ids=input_ids, attention_mask=attention_mask)
+
+    print(f"Last hidden state shape: {outputs.last_hidden_state.shape}")
+    print(f"Logits shape: {outputs.logits.shape}")
+
+    # Test embedding dataset functionality with a mini dataset
+    print("\n=== Testing Embed Dataset Functionality ===")
+    mini_dataset = [sample_sequence, sample_sequence[:50], sample_sequence[:30]]
+    print(f"Creating embeddings for {len(mini_dataset)} sequences")
+
+    # Only run this if the save path doesn't exist, to avoid overwriting
+    if not os.path.exists("test_embeddings.pth"):
+        embeddings = mlm_model.embed_dataset(
+            sequences=mini_dataset,
+            tokenizer=tokenizer,
+            batch_size=2,
+            max_len=100,
+            full_embeddings=False,
+            pooling_types=['mean'],
+            save_path="test_embeddings.pth"
+        )
+        if embeddings:
+            print(f"Embedding dictionary size: {len(embeddings)}")
+            for seq, emb in embeddings.items():
+                print(f"Sequence length: {len(seq)}, Embedding shape: {emb.shape}")
+                break
+    else:
+        print("Skipping embedding test as test_embeddings.pth already exists")
+
+    print("\nAll tests completed successfully!")
+
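A small follow-up to the test block above: it assumes `os` is imported at the top of the module (not visible in this diff), and a run of `embed_dataset` writes `test_embeddings.pth`. Assuming that file holds a plain sequence-to-tensor dictionary saved with `torch.save` (an assumption about `embed_dataset`'s save format), it could be reloaded for inspection like this:

```python
import os
import torch

# Reload the embeddings written by the test block, if the run produced them.
if os.path.exists("test_embeddings.pth"):
    embeddings = torch.load("test_embeddings.pth")  # assumed: dict of {sequence: tensor}
    for sequence, embedding in embeddings.items():
        print(f"{len(sequence)} residues -> embedding of shape {tuple(embedding.shape)}")
```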