import torch
import torch.nn as nn
from transformers import AutoModel


class MultitaskCodeSimilarityModel(nn.Module):
    """Joint model: classifies code similarity and generates a token-level
    explanation conditioned on the pooled encoder representation."""

    def __init__(self, config, tokenizer):
        super().__init__()
        self.config = config
        self.tokenizer = tokenizer
        # Randomly initialized transformer encoder built from the config.
        self.encoder = AutoModel.from_config(config)
        # Similarity classification head over the pooled [CLS] representation.
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)
        # For explanation generation: project the pooled vector into the
        # decoder input space, then decode with a single-layer GRU.
        self.decoder_embedding = nn.Linear(config.hidden_size, config.hidden_size)
        self.decoder = nn.GRU(
            input_size=config.hidden_size,
            hidden_size=config.hidden_size,
            batch_first=True,
        )
        # Per-step distribution over the tokenizer's full vocabulary.
        self.explanation_head = nn.Linear(config.hidden_size, len(tokenizer))

    def forward(self, input_ids, attention_mask, explanation_ids=None,
                explanation_mask=None):
        # Note: explanation_mask is accepted for API symmetry but unused here.
        outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
        # Pool the hidden state of the first ([CLS]) token.
        pooled = outputs.last_hidden_state[:, 0]
        logits = self.classifier(pooled)

        explanation_logits = None
        if explanation_ids is not None:
            # The same projected pooled vector is fed at every decoder step;
            # explanation_ids supplies only the target sequence length here.
            # .contiguous() materializes the expanded view before the GRU.
            decoder_input = (
                self.decoder_embedding(pooled)
                .unsqueeze(1)
                .expand(-1, explanation_ids.size(1), -1)
                .contiguous()
            )
            decoder_outputs, _ = self.decoder(decoder_input)
            explanation_logits = self.explanation_head(decoder_outputs)

        return logits, explanation_logits
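

# Usage sketch (an assumption, not part of the original): build the model from a
# config/tokenizer pair and run one forward pass on dummy inputs. The checkpoint
# name "microsoft/codebert-base" and num_labels=2 are illustrative choices; any
# BERT-style encoder whose first token acts as a [CLS] summary should work.
if __name__ == "__main__":
    from transformers import AutoConfig, AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("microsoft/codebert-base")
    config = AutoConfig.from_pretrained("microsoft/codebert-base")
    config.num_labels = 2  # e.g. similar / not similar

    model = MultitaskCodeSimilarityModel(config, tokenizer)
    model.eval()

    batch = tokenizer(
        ["def add(a, b): return a + b", "def sum2(x, y): return x + y"],
        padding=True,
        return_tensors="pt",
    )
    # Dummy explanation targets: batch of 2, target length 8. Only the length
    # is consumed by forward(), so random token ids suffice for a smoke test.
    explanation_ids = torch.randint(len(tokenizer), (2, 8))

    with torch.no_grad():
        logits, explanation_logits = model(
            input_ids=batch["input_ids"],
            attention_mask=batch["attention_mask"],
            explanation_ids=explanation_ids,
        )
    print(logits.shape)              # (2, num_labels)
    print(explanation_logits.shape)  # (2, 8, len(tokenizer))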