Commit 6673cf0 (verified) by sgoel30 · 1 parent: d8ed92a

Delete utils

Files changed (2):
  1. utils/data_loader.py +0 -43
  2. utils/esm_utils.py +0 -15
utils/data_loader.py DELETED
@@ -1,43 +0,0 @@
- import pandas as pd
- import torch
- from torch.utils.data import Dataset, DataLoader
- from torch.nn.utils.rnn import pad_sequence
- from esm_utils import get_latents, load_esm2_model
- import config
-
- class ProteinDataset(Dataset):
-     def __init__(self, csv_file, tokenizer, model):
-         self.data = pd.read_csv(csv_file).head(4)
-         self.tokenizer = tokenizer
-         self.model = model
-
-     def __len__(self):
-         return len(self.data)
-
-     def __getitem__(self, idx):
-         sequence = self.data.iloc[idx]['Sequence']
-         latents = get_latents(self.model, self.tokenizer, sequence)
-
-         attention_mask = torch.ones_like(latents)
-         attention_mask = torch.mean(attention_mask, dim=-1)
-
-         return latents, attention_mask
-
- def collate_fn(batch):
-     latents, attention_mask = zip(*batch)
-     latents_padded = pad_sequence([torch.tensor(latent) for latent in latents], batch_first=True, padding_value=0)
-     attention_mask_padded = pad_sequence([torch.tensor(mask) for mask in attention_mask], batch_first=True, padding_value=0)
-     return latents_padded, attention_mask_padded
-
- def get_dataloaders(config):
-     tokenizer, masked_model, embedding_model = load_esm2_model(config.MODEL_NAME)
-
-     train_dataset = ProteinDataset(config.Loader.DATA_PATH + "/train.csv", tokenizer, embedding_model)
-     val_dataset = ProteinDataset(config.Loader.DATA_PATH + "/val.csv", tokenizer, embedding_model)
-     test_dataset = ProteinDataset(config.Loader.DATA_PATH + "/test.csv", tokenizer, embedding_model)
-
-     train_loader = DataLoader(train_dataset, batch_size=config.Loader.BATCH_SIZE, num_workers=0, shuffle=True, collate_fn=collate_fn)
-     val_loader = DataLoader(val_dataset, batch_size=config.Loader.BATCH_SIZE, num_workers=0, shuffle=False, collate_fn=collate_fn)
-     test_loader = DataLoader(test_dataset, batch_size=config.Loader.BATCH_SIZE, num_workers=0, shuffle=False, collate_fn=collate_fn)
-
-     return train_loader, val_loader, test_loader
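
Note on the deleted collate_fn: get_latents (defined in utils/esm_utils.py below) already returns torch.Tensor objects, so the list comprehensions re-wrap tensors with torch.tensor(), which copies the data and raises a UserWarning on recent PyTorch; the .head(4) in ProteinDataset.__init__ also caps every split at four rows, which reads like a debugging leftover. A minimal sketch of how the collate function could be written instead, assuming the same (latents, attention_mask) items per batch element:

import torch
from torch.nn.utils.rnn import pad_sequence

def collate_fn(batch):
    # Each item is (latents, attention_mask), both already torch.Tensor,
    # so they can be padded directly without torch.tensor() round-trips.
    latents, attention_masks = zip(*batch)
    latents_padded = pad_sequence(list(latents), batch_first=True, padding_value=0.0)
    masks_padded = pad_sequence(list(attention_masks), batch_first=True, padding_value=0.0)
    return latents_padded, masks_padded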
utils/esm_utils.py DELETED
@@ -1,15 +0,0 @@
- import torch
- import config
- from transformers import AutoTokenizer, AutoModel, AutoModelForMaskedLM
-
- def load_esm2_model(model_name):
-     tokenizer = AutoTokenizer.from_pretrained(model_name)
-     masked_model = AutoModelForMaskedLM.from_pretrained(model_name)
-     embedding_model = AutoModel.from_pretrained(model_name)
-     return tokenizer, masked_model, embedding_model
-
- def get_latents(model, tokenizer, sequence):
-     inputs = tokenizer(sequence, return_tensors="pt")
-     with torch.no_grad():
-         outputs = model(**inputs)
-     return outputs.last_hidden_state.squeeze(0)
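
For reference, a hypothetical usage sketch of the two deleted helpers; the checkpoint name below is an assumption standing in for config.MODEL_NAME, whose value is not shown in this diff (note also that esm_utils.py imports config without using it):

import torch
from transformers import AutoTokenizer, AutoModel

# Assumed ESM-2 checkpoint; the repo reads the real name from config.MODEL_NAME.
model_name = "facebook/esm2_t6_8M_UR50D"
tokenizer = AutoTokenizer.from_pretrained(model_name)
embedding_model = AutoModel.from_pretrained(model_name)

sequence = "MKTAYIAKQRQISFVK"  # toy protein sequence
inputs = tokenizer(sequence, return_tensors="pt")
with torch.no_grad():
    outputs = embedding_model(**inputs)
# Per-residue latents, shape (sequence length + special tokens, hidden size)
latents = outputs.last_hidden_state.squeeze(0)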