sgoel30 committed
Commit d6c63a1 · verified · 1 Parent(s): 60ee22e

Upload 2 files

Files changed (2):
  1. utils/data_loader.py +5 -5
  2. utils/esm_utils.py +6 -5
utils/data_loader.py CHANGED

@@ -7,7 +7,7 @@ import config
 
 class ProteinDataset(Dataset):
     def __init__(self, csv_file, tokenizer, model):
-        self.data = pd.read_csv(csv_file)
+        self.data = pd.read_csv(csv_file).head(4)
         self.tokenizer = tokenizer
         self.model = model
 
@@ -30,11 +30,11 @@ def collate_fn(batch):
     return latents_padded, attention_mask_padded
 
 def get_dataloaders(config):
-    tokenizer, model = load_esm2_model(config.MODEL_NAME)
+    tokenizer, masked_model, embedding_model = load_esm2_model(config.MODEL_NAME)
 
-    train_dataset = ProteinDataset(config.Loader.DATA_PATH + "/train.csv", tokenizer, model)
-    val_dataset = ProteinDataset(config.Loader.DATA_PATH + "/val.csv", tokenizer, model)
-    test_dataset = ProteinDataset(config.Loader.DATA_PATH + "/test.csv", tokenizer, model)
+    train_dataset = ProteinDataset(config.Loader.DATA_PATH + "/train.csv", tokenizer, embedding_model)
+    val_dataset = ProteinDataset(config.Loader.DATA_PATH + "/val.csv", tokenizer, embedding_model)
+    test_dataset = ProteinDataset(config.Loader.DATA_PATH + "/test.csv", tokenizer, embedding_model)
 
     train_loader = DataLoader(train_dataset, batch_size=config.Loader.BATCH_SIZE, num_workers=0, shuffle=True, collate_fn=collate_fn)
     val_loader = DataLoader(val_dataset, batch_size=config.Loader.BATCH_SIZE, num_workers=0, shuffle=False, collate_fn=collate_fn)
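
For context, a minimal usage sketch (not part of this commit): it assumes config.py defines MODEL_NAME plus a Loader class carrying DATA_PATH and BATCH_SIZE, and that get_dataloaders returns the three loaders built above. Note that the new .head(4) keeps only the first four rows of each CSV, presumably a temporary cap for quick debugging.

    # Hypothetical sketch, assuming config.py exposes the attributes the
    # diff references and that get_dataloaders returns all three loaders.
    import config
    from utils.data_loader import get_dataloaders

    train_loader, val_loader, test_loader = get_dataloaders(config)

    # collate_fn pads each batch: latents (B, L_max, D), mask (B, L_max)
    latents, attention_mask = next(iter(train_loader))
    print(latents.shape, attention_mask.shape)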
utils/esm_utils.py CHANGED

@@ -1,14 +1,15 @@
 import torch
-from transformers import AutoTokenizer, AutoModel
+import config
+from transformers import AutoTokenizer, AutoModel, AutoModelForMaskedLM
 
 def load_esm2_model(model_name):
     tokenizer = AutoTokenizer.from_pretrained(model_name)
-    model = AutoModel.from_pretrained(model_name)
-    return tokenizer, model
+    masked_model = AutoModelForMaskedLM.from_pretrained(model_name)
+    embedding_model = AutoModel.from_pretrained(model_name)
+    return tokenizer, masked_model, embedding_model
 
 def get_latents(model, tokenizer, sequence):
     inputs = tokenizer(sequence, return_tensors="pt")
     with torch.no_grad():
         outputs = model(**inputs)
-    return outputs.last_hidden_state.squeeze(0)
-
+    return outputs.last_hidden_state.squeeze(0)
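
For reference, a minimal sketch of the revised esm_utils API: load_esm2_model now returns a masked-LM head alongside the plain encoder, and get_latents embeds a sequence with whichever model is passed in. The checkpoint name and toy sequence below are illustrative assumptions, not taken from this commit.

    # Hypothetical sketch; "facebook/esm2_t6_8M_UR50D" is an assumed ESM-2
    # checkpoint, and the amino-acid sequence is a toy example.
    from utils.esm_utils import load_esm2_model, get_latents

    tokenizer, masked_model, embedding_model = load_esm2_model("facebook/esm2_t6_8M_UR50D")

    # Per-residue embeddings from the plain encoder (includes the special
    # BOS/EOS tokens the tokenizer adds around the sequence).
    latents = get_latents(embedding_model, tokenizer, "MKTAYIAKQR")
    print(latents.shape)  # (sequence length + 2, hidden_dim)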