File size: 3,527 Bytes
26d475a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
import json
import pandas as pd
import datasets
import numpy as np
import evaluate
import torch
from transformers import AutoModel, DistilBertForSequenceClassification
from transformers.modeling_outputs import SequenceClassifierOutput
from typing import Optional

# Separator inserted between the two sentences of a pair when they are joined
# into a single text string (see format_dataset).
SEP_TOKEN = '[SEP]'
# Numeric target assigned to each string label of the dataset; used by
# format_dataset to build the 'label' column.
LABEL2NUM = {'entailment': 1, 'neutral': 0.5, 'contradiction': 0}

def format_dataset(arr):
    """Turn sentence-pair records into a shuffled text/label DataFrame.

    Args:
        arr: Re-iterable collection of mappings, each providing
            ``'sentence1'``, ``'sentence2'`` and ``'label'`` entries. The
            label must be one of the keys of ``LABEL2NUM``.

    Returns:
        A ``pandas.DataFrame`` with a ``'text'`` column (the two sentences
        joined by ``SEP_TOKEN``) and a ``'label'`` column (the numeric value
        looked up in ``LABEL2NUM``). Rows are shuffled with a fixed seed and
        the index is rebuilt from zero.

    Raises:
        KeyError: If a record lacks one of the expected keys or carries a
            label that is not in ``LABEL2NUM``.
    """
    columns = {
        'text': [
            SEP_TOKEN.join((record['sentence1'], record['sentence2']))
            for record in arr
        ],
        'label': [LABEL2NUM[record['label']] for record in arr],
    }
    frame = pd.DataFrame(columns)
    # frac=1 samples every row, i.e. a full shuffle; the seed keeps the order
    # reproducible between runs.
    shuffled = frame.sample(frac=1, random_state=42)
    return shuffled.reset_index(drop=True)

# Load dataset
def load_dataset(path, test_size=512):
    """Load a JSON Lines sentence-pair file and split it into train/test sets.

    Each non-blank line of the file is parsed as one JSON record, the records
    are converted and shuffled by ``format_dataset``, and the shuffled rows
    are split: the first ``test_size`` rows become the test split and the
    remainder the train split. The first ten rows of each split are printed
    as a quick sanity check.

    Args:
        path: Path of the JSON Lines file to read.
        test_size: Number of rows reserved for the test split. Defaults to
            512, since little test data is needed.

    Returns:
        A ``datasets.DatasetDict`` with ``"train"`` and ``"test"`` entries,
        each holding the ``"text"`` and ``"label"`` columns.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
        OSError: If the file cannot be opened.
    """
    records = []
    # Read as UTF-8 explicitly: the platform default encoding is not
    # guaranteed to be UTF-8 and would mangle or reject non-ASCII text.
    with open(path, encoding='utf-8') as f:
        for line in f:
            # Lines keep their trailing newline, so a blank line is only
            # falsy after stripping; unstripped it would reach json.loads
            # and raise.
            line = line.strip()
            if line:
                records.append(json.loads(line))
    df = format_dataset(records)

    # Split dataset into train and test. We do not need much test data.
    df_test = df.iloc[:test_size, :]
    df_train = df.iloc[test_size:, :]
    print(df_train[:10])
    print(df_test[:10])

    return datasets.DatasetDict({
        "train": datasets.Dataset.from_pandas(df_train[["text", "label"]]),
        "test": datasets.Dataset.from_pandas(df_test[["text", "label"]]),
    })


class ConsistentSentenceRegressor(DistilBertForSequenceClassification):
    """DistilBERT sequence model configured to regress a single score.

    A pretrained encoder is loaded, its config is switched to the
    ``"regression"`` problem type with one label, and a single-output linear
    head is attached. ``forward`` returns one value per input sequence.
    """

    def __init__(self, freeze_bert=True,
                 model_name='line-corporation/line-distilbert-base-japanese'):
        """Load the pretrained encoder and attach the regression head.

        Args:
            freeze_bert: When true, every parameter of the encoder has
                ``requires_grad`` set to ``False`` so that only the layers
                outside the encoder are trained.
            model_name: Checkpoint identifier passed to
                ``AutoModel.from_pretrained``.
        """
        base_model = AutoModel.from_pretrained(model_name)

        # Reuse the checkpoint's config, reconfigured for one-value regression.
        config = base_model.config
        config.problem_type = "regression"
        config.num_labels = 1
        super().__init__(config=config)

        # Use the pretrained encoder loaded above as the backbone.
        self.distilbert = base_model

        # Replace the classifier with a single-neuron linear layer for regression
        self.classifier = torch.nn.Linear(config.dim, config.num_labels)

        if freeze_bert:
            for param in self.distilbert.parameters():
                param.requires_grad = False

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ):
        """Run the parent forward pass and return the squeezed logits.

        All arguments except ``return_dict`` are forwarded unchanged to the
        parent ``forward``.

        Args:
            input_ids: Token id tensor, or ``None`` when ``inputs_embeds``
                is supplied instead.
            attention_mask: Optional attention mask.
            head_mask: Optional attention-head mask.
            inputs_embeds: Optional pre-computed input embeddings.
            labels: Optional regression targets, passed to the parent.
            output_attentions: Forwarded to the parent.
            output_hidden_states: Forwarded to the parent.
            return_dict: Accepted for signature compatibility but ignored;
                this method always returns a plain tensor.

        Returns:
            The parent's logits with the trailing dimension removed.

        NOTE(review): only the logits are returned, so any loss the parent
        computes from ``labels`` is discarded. Confirm that the training code
        computes its own loss from this tensor; a trainer that expects the
        loss as part of the model output will not find it here.
        """
        outputs = super().forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            labels=labels,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            # Always request the structured output: the tuple form returned
            # for return_dict=False has no ``.logits`` attribute.
            return_dict=True,
        )
        # Remove the last dimension to match target tensor shape
        return outputs.logits.squeeze(-1)


# Set up evaluation metric

def get_metrics():
    """Build the metric callback used during evaluation.

    The ``"mse"`` metric is loaded once via ``evaluate.load`` and captured by
    the returned closure, so it is not reloaded on every evaluation.

    Returns:
        A function taking an ``eval_pred`` pair ``(predictions, labels)`` and
        returning the result of ``metric.compute`` for that pair.
    """
    metric = evaluate.load("mse")

    def compute_metrics(eval_pred):
        """Compute the metric for one ``(predictions, labels)`` pair."""
        predictions, labels = eval_pred
        # Flatten both sides to one value per example so that a trailing
        # singleton dimension (e.g. predictions shaped (batch, 1)) or plain
        # lists are accepted. Already-flat arrays are left unchanged.
        predictions = np.asarray(predictions).reshape(-1)
        labels = np.asarray(labels).reshape(-1)
        return metric.compute(predictions=predictions, references=labels)

    return compute_metrics