import torch
import config
import math
import sys
import os
from tqdm import tqdm
from torch.optim import Adam
from torch.optim.lr_scheduler import CosineAnnealingLR
from transformers import AutoModelForMaskedLM, AutoModel, AutoTokenizer, AutoConfig
from pretrained_models import load_esm2_model
from model import MembraneMLM, MembraneTokenizer
from data_loader import get_dataloaders


def save_hyperparams(ckpt_dir):
    hyperparams_txt_file = os.path.join(ckpt_dir, "hyperparameters.txt")
    with open(hyperparams_txt_file, 'w') as f:
        for k, v in vars(config).items():
            if k.isupper():
                f.write(f"{k}: {v}\n")


def train_and_validate(model, optimizer, device, train_loader, val_loader, num_epochs, ckpt_dir):
    best_val_loss = float('inf')
    avg_val_loss, val_perplexity = None, None  # stay defined when no val_loader is provided

    for epoch in range(num_epochs):
        print(f"EPOCH {epoch+1}/{num_epochs}")
        sys.stderr.flush()

        total_train_loss = 0.0
        weighted_total_train_loss = 0.0
        total_masked_train_tokens = 0

        model.train()
        # Refresh the progress bar only ~4 times per epoch; advance it manually so
        # the manual update() calls do not double-count the iterations.
        train_update_interval = max(1, len(train_loader) // 4)
        with tqdm(desc="Training batch", total=len(train_loader), leave=True, position=0, ncols=100) as trainbar:
            for step, inputs in enumerate(train_loader):
                inputs = {k: v.to(device) for k, v in inputs.items()}

                optimizer.zero_grad()
                outputs = model(**inputs)
                train_loss = outputs.loss
                train_loss.backward()
                optimizer.step()

                # The MLM loss is averaged over masked positions, so weight each batch
                # by its masked-token count to recover a per-token negative log-likelihood.
                # `tokenizer` is the module-level tokenizer set in the __main__ block below.
                num_mask_tokens = (inputs["input_ids"] == tokenizer.mask_token_id).sum().item()
                total_masked_train_tokens += num_mask_tokens
                total_train_loss += train_loss.item()
                weighted_total_train_loss += train_loss.item() * num_mask_tokens

                if (step + 1) % train_update_interval == 0:
                    trainbar.update(train_update_interval)

        avg_train_loss = total_train_loss / len(train_loader)
        avg_train_neg_log_likelihood = weighted_total_train_loss / total_masked_train_tokens
        train_perplexity = math.exp(avg_train_neg_log_likelihood)

        # Save model every epoch
        train_ckpt_path = os.path.join(ckpt_dir, f'epoch{epoch+1}')
        model.save_model(train_ckpt_path)
        save_hyperparams(train_ckpt_path)

        # Validate model
        if val_loader:
            model.eval()
            total_val_loss = 0.0
            weighted_total_val_loss = 0.0
            total_masked_val_tokens = 0

            with torch.no_grad():
                val_update_interval = max(1, len(val_loader) // 4)
                with tqdm(desc="Validation batch", total=len(val_loader), leave=True, position=0) as valbar:
                    for step, inputs in enumerate(val_loader):
                        inputs = {k: v.to(device) for k, v in inputs.items()}
                        val_loss = model(**inputs).loss.item()

                        num_mask_tokens = (inputs['input_ids'] == tokenizer.mask_token_id).sum().item()
                        total_masked_val_tokens += num_mask_tokens
                        total_val_loss += val_loss
                        weighted_total_val_loss += val_loss * num_mask_tokens

                        if (step + 1) % val_update_interval == 0:
                            valbar.update(val_update_interval)

            avg_val_loss = total_val_loss / len(val_loader)
            avg_val_neg_log_likelihood = weighted_total_val_loss / total_masked_val_tokens
            val_perplexity = math.exp(avg_val_neg_log_likelihood)

            # Save the best model based on validation loss
            if avg_val_loss < best_val_loss:
                best_val_loss = avg_val_loss
                val_ckpt_path = os.path.join(ckpt_dir, "best_model_epoch")
                model.save_model(val_ckpt_path)
                save_hyperparams(val_ckpt_path)

        print(f"Average train loss: {avg_train_loss}")
        print(f"Average train perplexity: {train_perplexity}\n")
        sys.stdout.flush()
        if val_loader:
            print(f"Average validation loss: {avg_val_loss}")
            print(f"Average validation perplexity: {val_perplexity}\n")
            sys.stdout.flush()

    return avg_train_loss, train_perplexity, avg_val_loss, val_perplexity


def test(model, test_loader, device):
    model.to(device).eval()
    total_test_loss = 0.0
    weighted_total_test_loss = 0.0
    total_masked_test_tokens = 0

    with torch.no_grad():
        for step, inputs in enumerate(test_loader):
            inputs = {k: v.to(device) for k, v in inputs.items()}
            outputs = model(**inputs)
            test_loss = outputs.loss.item()

            num_mask_tokens = (inputs["input_ids"] == tokenizer.mask_token_id).sum().item()
            total_masked_test_tokens += num_mask_tokens
            total_test_loss += test_loss
            weighted_total_test_loss += test_loss * num_mask_tokens

    avg_test_loss = total_test_loss / len(test_loader)
    avg_test_neg_log_likelihood = weighted_total_test_loss / total_masked_test_tokens
    test_perplexity = math.exp(avg_test_neg_log_likelihood)

    return avg_test_loss, test_perplexity


if __name__ == "__main__":
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(device)

    model = MembraneMLM()
    model.to(device)
    model.freeze_model()
    model.unfreeze_n_layers()
    tokenizer = model.tokenizer

    train_loader, val_loader, test_loader = get_dataloaders(config)
    optimizer = Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=config.LEARNING_RATE)

    # Train and test the model
    avg_train_loss, train_ppl, avg_val_loss, val_ppl = train_and_validate(
        model, optimizer, device, train_loader, val_loader, config.NUM_EPOCHS, config.CKPT_DIR)
    avg_test_loss, test_ppl = test(model, test_loader, device)

    results_dict = {
        "Average train loss": avg_train_loss,
        "Average train perplexity": train_ppl,
        "Average val loss": avg_val_loss,
        "Average val perplexity": val_ppl,
        "Average test loss": avg_test_loss,
        "Average test perplexity": test_ppl,
    }

    print("TRAIN AND TEST RESULTS")
    for k, v in results_dict.items():
        print(f"{k}: {v}\n")

    # Save training and test performance
    with open(os.path.join(config.CKPT_DIR, "train_test_results.txt"), 'w') as f:
        for k, v in results_dict.items():
            f.write(f'{k}: {v}\n')

    ### Get embeddings from model
    # best_model_pth = config.MLM_MODEL_PATH + "/best_model"
    # model = AutoModel.from_pretrained(best_model_pth)
    # tokenizer = AutoTokenizer.from_pretrained(best_model_pth)
    # model.eval().to(device)
    # random_seq = "WPIQMVYSLGQHADYMQWFTIMPPPIEMIFVWHNCTQHDYSFRERAGEVDQARMKTEMAR"
    # inputs = tokenizer(random_seq, return_tensors='pt')
    # inputs = {k: v.to(device) for k, v in inputs.items()}
    # inputs = inputs['input_ids']
    # print(inputs)
    # with torch.no_grad():
    #     outputs = model(inputs).last_hidden_state
    # print(outputs)
    # print(outputs.size())
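
    ### Example `config` module (sketch, commented out like the block above)
    # The script imports a `config` module whose contents are not shown here. The
    # sketch below is an assumed, minimal illustration: only LEARNING_RATE,
    # NUM_EPOCHS, CKPT_DIR and MLM_MODEL_PATH are actually referenced in this file,
    # the values are placeholders, and get_dataloaders(config) may require extra
    # fields (e.g. batch size or dataset paths) that are not listed.
    #
    # # config.py
    # LEARNING_RATE = 1e-4              # placeholder value
    # NUM_EPOCHS = 10                   # placeholder value
    # CKPT_DIR = "checkpoints"          # where per-epoch and best-model checkpoints are written
    # MLM_MODEL_PATH = "checkpoints"    # path used by the embedding-extraction block above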