File size: 1,403 Bytes
1306f0a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
import torch
import torch.nn as nn
from transformers import AutoModel

class MultitaskCodeSimilarityModel(nn.Module):
    """Shared encoder with a classification head and an explanation decoder.

    The hidden state at the first sequence position of the encoder's last
    layer is used as a pooled summary of the input. That summary feeds
    (a) a linear classifier and (b) a small GRU decoder that produces
    per-step vocabulary logits for an explanation.

    Args:
        config: Model config. ``hidden_size`` and ``num_labels`` are read
            here, and the whole object is handed to ``AutoModel.from_config``
            (which builds the architecture without loading pretrained weights).
        tokenizer: Tokenizer whose ``len()`` sets the size of the explanation
            output vocabulary.
    """

    def __init__(self, config, tokenizer):
        super().__init__()
        width = config.hidden_size

        self.config = config
        self.tokenizer = tokenizer

        # NOTE: sub-modules are created in this exact order so parameter
        # initialisation consumes the RNG in the same sequence and the
        # state_dict keys stay stable for existing checkpoints.
        self.encoder = AutoModel.from_config(config)
        self.classifier = nn.Linear(width, config.num_labels)

        # Explanation branch: project summary -> GRU over time -> vocab logits.
        # (Despite its name, ``decoder_embedding`` is a Linear projection of the
        # pooled vector, not a token-embedding table.)
        self.decoder_embedding = nn.Linear(width, width)
        self.decoder = nn.GRU(width, width, batch_first=True)
        self.explanation_head = nn.Linear(width, len(tokenizer))

    def forward(self, input_ids, attention_mask, explanation_ids=None, explanation_mask=None):
        """Encode the input and run the classifier and, optionally, the decoder.

        Args:
            input_ids: Token ids, passed straight through to the encoder.
            attention_mask: Attention mask, passed straight through to the encoder.
            explanation_ids: Optional explanation token ids. Only
                ``explanation_ids.size(1)`` is used, as the number of decoder
                steps; the token values themselves are never read.
            explanation_mask: Accepted for interface compatibility; not used
                by this method.

        Returns:
            ``(logits, explanation_logits)``. ``logits`` is the classifier output
            for the pooled summary. ``explanation_logits`` is ``None`` when
            ``explanation_ids`` is ``None``; otherwise it is the explanation-head
            output for every decoder step.
        """
        encoded = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
        # First position of the last hidden layer acts as the pooled summary
        # (assumes last_hidden_state is (batch, seq, hidden) — TODO confirm).
        summary = encoded.last_hidden_state[:, 0]
        class_logits = self.classifier(summary)

        if explanation_ids is None:
            return class_logits, None

        return class_logits, self._decode_explanation(summary, explanation_ids.size(1))

    def _decode_explanation(self, summary, num_steps):
        """Return explanation-head logits for ``num_steps`` steps decoded from ``summary``."""
        # The same projected summary vector is repeated at every time step
        # (``expand`` makes a view, not a copy). No previous tokens are fed back,
        # and no initial hidden state is passed, so the GRU starts from zeros.
        step_inputs = self.decoder_embedding(summary)[:, None, :].expand(-1, num_steps, -1)
        gru_outputs, _final_hidden = self.decoder(step_inputs)
        return self.explanation_head(gru_outputs)