import lightning.pytorch as pl
import torch
from torch import nn
from transformers import AutoConfig, AutoModel, get_linear_schedule_with_warmup
from transformers.models.bert.modeling_bert import BertLMPredictionHead

from loss import CL_loss


class CL_model(pl.LightningModule):
    def __init__(
        self, n_batches=None, n_epochs=None, lr=None, mlm_weight=None, **kwargs
    ):
        super().__init__()

        ## Params
        self.n_batches = n_batches
        self.n_epochs = n_epochs
        self.lr = lr
        self.mlm_weight = mlm_weight
        # self.first_neg_idx = 0
        self.config = AutoConfig.from_pretrained("emilyalsentzer/Bio_ClinicalBERT")

        ## Encoder
        self.bert = AutoModel.from_pretrained(
            "emilyalsentzer/Bio_ClinicalBERT", return_dict=True
        )

        # Optionally freeze the encoder. kwargs["unfreeze"] is a ratio in
        # [0.0, 1.0]; only the last `ratio * num_parameter_tensors` parameter
        # tensors (counted per named parameter, not per transformer layer)
        # stay trainable. 0.0 leaves the whole encoder trainable.
        self.bert_layer_num = sum(1 for _ in self.bert.named_parameters())
        self.num_unfreeze_layer = self.bert_layer_num
        self.ratio_unfreeze_layer = 0.0
        if kwargs:
            for key, value in kwargs.items():
                if key == "unfreeze" and isinstance(value, float):
                    assert (
                        0.0 <= value <= 1.0
                    ), "ValueError: unfreeze must be a ratio between 0.0 and 1.0"
                    self.ratio_unfreeze_layer = value
        if self.ratio_unfreeze_layer > 0.0:
            self.num_unfreeze_layer = int(
                self.bert_layer_num * self.ratio_unfreeze_layer
            )
            for param in list(self.bert.parameters())[: -self.num_unfreeze_layer]:
                param.requires_grad = False

        # MLM head (freshly initialized from the config, not the pretrained
        # head) and the projection head for the contrastive branch.
        self.lm_head = BertLMPredictionHead(self.config)
        self.projector = nn.Linear(self.bert.config.hidden_size, 128)
        print("Model Initialized!")

        ## Losses
        self.cl_loss = CL_loss()
        self.mlm_loss = nn.CrossEntropyLoss()

        ## Logs
        self.train_loss, self.val_loss = [], []
        self.train_cl_loss, self.val_cl_loss = [], []
        self.train_mlm_loss, self.val_mlm_loss = [], []
        self.training_step_outputs, self.validation_step_outputs = [], []

    def forward(self, input_ids, attention_mask, mlm_ids, eval=False):
        # Contrastive branch: encode the unmasked sequence; the pooled [CLS]
        # embedding is projected to 128-d for the contrastive loss.
        unmasked = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        cls = unmasked.pooler_output
        if eval:
            return cls
        output = self.projector(cls)

        # MLM branch: encode the masked sequence and predict vocabulary
        # logits at every position, flattened to (batch * seq_len, vocab).
        masked = self.bert(input_ids=mlm_ids, attention_mask=attention_mask)
        pred = self.lm_head(masked.last_hidden_state)
        pred = pred.view(-1, self.config.vocab_size)
        return cls, output, pred

    def training_step(self, batch, batch_idx):
        tags = batch["tags"]
        input_ids = batch["input_ids"]
        attention_mask = batch["attention_mask"]
        mlm_ids = batch["mlm_ids"]
        # Flatten labels to match pred; CrossEntropyLoss ignores index -100
        # by default.
        mlm_labels = batch["mlm_labels"].reshape(-1)

        cls, output, pred = self(input_ids, attention_mask, mlm_ids)
        loss_cl = self.cl_loss(output, tags)
        loss_mlm = self.mlm_loss(pred, mlm_labels)
        loss = loss_cl + self.mlm_weight * loss_mlm

        # Detach before storing so the autograd graph is not kept alive for
        # the whole epoch.
        logs = {
            "loss": loss.detach(),
            "loss_cl": loss_cl.detach(),
            "loss_mlm": loss_mlm.detach(),
        }
        self.training_step_outputs.append(logs)
        self.log("train_loss", loss, prog_bar=True, logger=True, sync_dist=True)
        return loss

    @staticmethod
    def _epoch_mean(outputs, key):
        # Average a per-step scalar over the epoch.
        return torch.stack([x[key] for x in outputs]).mean().cpu().numpy()

    def on_train_epoch_end(self):
        avg_loss = self._epoch_mean(self.training_step_outputs, "loss")
        self.train_loss.append(avg_loss)
        avg_cl_loss = self._epoch_mean(self.training_step_outputs, "loss_cl")
        self.train_cl_loss.append(avg_cl_loss)
        avg_mlm_loss = self._epoch_mean(self.training_step_outputs, "loss_mlm")
        self.train_mlm_loss.append(avg_mlm_loss)
        print(
            "train_epoch:", self.current_epoch,
            "avg_loss:", avg_loss,
            "avg_cl_loss:", avg_cl_loss,
            "avg_mlm_loss:", avg_mlm_loss,
        )
        self.training_step_outputs.clear()

    def validation_step(self, batch, batch_idx):
        tags = batch["tags"]
        input_ids = batch["input_ids"]
        attention_mask = batch["attention_mask"]
        mlm_ids = batch["mlm_ids"]
        mlm_labels = batch["mlm_labels"].reshape(-1)

        cls, output, pred = self(input_ids, attention_mask, mlm_ids)
        loss_cl = self.cl_loss(output, tags)
        loss_mlm = self.mlm_loss(pred, mlm_labels)
        loss = loss_cl + self.mlm_weight * loss_mlm

        logs = {
            "loss": loss.detach(),
            "loss_cl": loss_cl.detach(),
            "loss_mlm": loss_mlm.detach(),
        }
        self.validation_step_outputs.append(logs)
        self.log("validation_loss", loss, prog_bar=True, logger=True, sync_dist=True)
        return loss

    def on_validation_epoch_end(self):
        avg_loss = self._epoch_mean(self.validation_step_outputs, "loss")
        self.val_loss.append(avg_loss)
        avg_cl_loss = self._epoch_mean(self.validation_step_outputs, "loss_cl")
        self.val_cl_loss.append(avg_cl_loss)
        avg_mlm_loss = self._epoch_mean(self.validation_step_outputs, "loss_mlm")
        self.val_mlm_loss.append(avg_mlm_loss)
        print(
            "val_epoch:", self.current_epoch,
            "avg_loss:", avg_loss,
            "avg_cl_loss:", avg_cl_loss,
            "avg_mlm_loss:", avg_mlm_loss,
        )
        self.validation_step_outputs.clear()

    def configure_optimizers(self):
        # Optimize only the unfrozen parameters. torch.optim.AdamW replaces
        # the deprecated transformers.AdamW.
        self.trainable_params = [
            param for param in self.parameters() if param.requires_grad
        ]
        optimizer = torch.optim.AdamW(self.trainable_params, lr=self.lr)

        # Linear warmup over the first third of an epoch, then linear decay.
        # num_training_steps is the total step count, warmup included.
        warmup_steps = self.n_batches // 3
        total_steps = self.n_batches * self.n_epochs
        scheduler = get_linear_schedule_with_warmup(
            optimizer, warmup_steps, total_steps
        )
        # The schedule is defined in optimizer steps, so step it every batch
        # rather than Lightning's per-epoch default.
        return [optimizer], [{"scheduler": scheduler, "interval": "step"}]
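# --- Usage sketch (illustrative only) ---
# A minimal example of driving this module with a Lightning Trainer.
# `make_dataloaders` is a hypothetical helper: any DataLoader whose batches
# carry the keys "tags", "input_ids", "attention_mask", "mlm_ids", and
# "mlm_labels" (as consumed by training_step) would work. Hyperparameter
# values below are placeholders, not recommendations.
if __name__ == "__main__":
    n_batches, n_epochs = 100, 3  # n_batches = optimizer steps per epoch
    model = CL_model(
        n_batches=n_batches,
        n_epochs=n_epochs,
        lr=2e-5,
        mlm_weight=0.1,
        unfreeze=0.5,  # keep the last half of the encoder's parameter tensors trainable
    )
    trainer = pl.Trainer(max_epochs=n_epochs, accelerator="auto", devices=1)
    # train_loader, val_loader = make_dataloaders(...)  # hypothetical
    # trainer.fit(model, train_loader, val_loader)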