import torch
import torch.nn as nn
from transformers import AutoModel

class MultitaskCodeSimilarityModel(nn.Module):
    """Multitask model: code-similarity classification plus explanation generation."""

    def __init__(self, config, tokenizer):
        super().__init__()
        self.config = config
        self.tokenizer = tokenizer

        # Encoder built from the config alone (randomly initialized weights);
        # load pretrained weights separately if needed.
        self.encoder = AutoModel.from_config(config)
        # Classification head over the pooled first-token representation.
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)

        # Lightweight explanation decoder: project the pooled vector, run it
        # through a GRU, and map hidden states to the tokenizer vocabulary.
        self.decoder_embedding = nn.Linear(config.hidden_size, config.hidden_size)
        self.decoder = nn.GRU(
            input_size=config.hidden_size,
            hidden_size=config.hidden_size,
            batch_first=True,
        )
        self.explanation_head = nn.Linear(config.hidden_size, len(tokenizer))

    def forward(self, input_ids, attention_mask, explanation_ids=None, explanation_mask=None):
        outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
        # Use the first token's hidden state ([CLS]-style pooling) for classification.
        pooled = outputs.last_hidden_state[:, 0]
        logits = self.classifier(pooled)

        explanation_logits = None
        if explanation_ids is not None:
            # Repeat the projected pooled vector across the explanation length and decode;
            # the GRU conditions only on the encoder summary, not on the gold explanation tokens.
            decoder_input = (
                self.decoder_embedding(pooled)
                .unsqueeze(1)
                .expand(-1, explanation_ids.size(1), -1)
            )
            decoder_outputs, _ = self.decoder(decoder_input)
            explanation_logits = self.explanation_head(decoder_outputs)

        return logits, explanation_logits
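

if __name__ == "__main__":
    # Minimal usage sketch. The checkpoint "microsoft/codebert-base" and the sample
    # inputs below are illustrative assumptions, not part of the model definition.
    # Note that AutoModel.from_config gives a randomly initialized encoder, so the
    # outputs here are only meant to check shapes and wiring.
    from transformers import AutoConfig, AutoTokenizer

    config = AutoConfig.from_pretrained("microsoft/codebert-base", num_labels=2)
    tokenizer = AutoTokenizer.from_pretrained("microsoft/codebert-base")
    model = MultitaskCodeSimilarityModel(config, tokenizer)

    # A single code pair encoded as one sequence.
    code_pair = tokenizer(
        "def add(a, b): return a + b",
        "def sum_two(x, y): return x + y",
        padding=True,
        truncation=True,
        return_tensors="pt",
    )
    # A reference explanation for the explanation-generation head.
    explanation = tokenizer(
        ["both functions add two numbers"],
        padding=True,
        truncation=True,
        return_tensors="pt",
    )

    logits, explanation_logits = model(
        input_ids=code_pair["input_ids"],
        attention_mask=code_pair["attention_mask"],
        explanation_ids=explanation["input_ids"],
        explanation_mask=explanation["attention_mask"],
    )
    # logits: (batch, num_labels); explanation_logits: (batch, expl_len, vocab_size)
    print(logits.shape, explanation_logits.shape)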