import re
from difflib import Differ

import torch


class PllScore:
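    """Pseudo-log-likelihood (PLL) scorer built on a masked language model.

    Words of interest are wrapped in '<' and '>' marks; `compute` sums the
    log-probabilities of the remaining context tokens, each predicted with
    that single token masked.
    """
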
    def __init__(
        self, 
        language_model  # LanguageModel class instance
    ) -> None:

        self.tokenizer = language_model.initTokenizer()
        self.model = language_model.initModel()
        self.model.eval()  # switch to evaluation mode (disables dropout)

        self.logSoftmax = torch.nn.LogSoftmax(dim=-1)

    def sentIsCorrect(
        self, 
        sent: str
    ) -> bool:
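        """Check that `sent` marks its words of interest correctly.

        The sentence must contain balanced, non-nested, non-empty '<...>'
        marks and at least one word outside the marks.
        """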

        # Assume the sentence is well formed until a check fails
        is_correct = True

        # Check that marks exist and that '<' and '>' counts match
        open_mark = sent.count("<")
        close_mark = sent.count(">")
        total_mark = open_mark + close_mark
        if (total_mark == 0) or (open_mark != close_mark):
            is_correct = False

        # Check for doubled marks (i.e. '<<' or '>>')
        if is_correct:
            left_twin = sent.count("<<")
            right_twin = sent.count(">>")
            if left_twin + right_twin > 0:
                is_correct = False

        if is_correct:
            # Check balanced symbols '<' and '>'
            stack = []
            for c in sent:
                if c == '<':
                    stack.append('<')
                elif c == '>':
                    if len(stack) == 0:
                        is_correct = False
                        break

                    if stack.pop() != "<":
                        is_correct = False
                        break

            if len(stack) > 0:
                is_correct = False

        if is_correct:
            for w in re.findall(r"<.*?>", sent):
                # Check for empty interest words
                word = w.replace("<", "").replace(">", "").strip()
                if not word:
                    is_correct = False
                    break

                # Check for marks nested inside others (i.e. <this is a <sentence>>)
                word = w.strip()[1:-1]  # Drop the first and last mark
                if '<' in word or '>' in word:
                    is_correct = False
                    break
        
        if is_correct:
            # Check that there is at least one word outside the marks. Sentences such as
            # '<this is a sent>' or '<this> <is a sent>' are not allowed.
            outside_words = re.sub(r"<.*?>", "", sent.replace("<", " < ").replace(">", " > "))
            outside_words = [w for w in outside_words.split() if w != ""]
            if not outside_words:
                is_correct = False

        return is_correct

    def compute(
        self, 
        sent: str
    ) -> float:
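        """Return the pseudo-log-likelihood of the context tokens of `sent`.

        Each token outside the '<...>' marks is masked one at a time and the
        log-probability of the original token at that position is accumulated.
        """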

        assert self.sentIsCorrect(sent), f"Error: The sentence '{sent}' does not have the correct format!"

        outside_words = re.sub(r"<.*?>", "", sent.replace("<", " < ").replace(">", " > "))
        outside_words = [w for w in outside_words.split() if w != ""]
        all_words = [w.strip() for w in sent.replace("<", " ").replace(">", " ").split() if w != ""]

        tks_id_outside_words = self.tokenizer.encode(
            " ".join(outside_words), 
            add_special_tokens=False, 
            truncation=True
        )
        tks_id_all_words = self.tokenizer.encode(
            " ".join(all_words), 
            add_special_tokens=False,
            truncation=True
        )

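        # Align the context-only token ids against the full-sentence token ids.
        # Differ marks with '+' the tokens that appear only in the full sentence,
        # i.e. those belonging to the words of interest; they are never masked.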
        diff = [
            (tk[0], tk[2:])
            for tk in Differ().compare(
                [str(t) for t in tks_id_outside_words],
                [str(t) for t in tks_id_all_words],
            )
        ]

        cls_tk_id = self.tokenizer.cls_token_id
        sep_tk_id = self.tokenizer.sep_token_id
        mask_tk_id = self.tokenizer.mask_token_id

        all_sent_masked = []
        all_tks_id_masked = []
        all_tks_position_masked = []

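        # For every context token, build one copy of the sentence with that single
        # token replaced by [MASK], remembering the original token id and its
        # position (offset by 1 to account for the leading [CLS]).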
        for i in range(len(diff)):
            current_sent_masked = [cls_tk_id]
            add_sent = True
            for j, (mark, tk_id) in enumerate(diff):
                if j == i:
                    if mark == '+':
                        add_sent = False
                        break
                    else:
                        current_sent_masked.append(mask_tk_id)
                        all_tks_id_masked.append(int(tk_id))
                        all_tks_position_masked.append(i+1)
                else:
                    current_sent_masked.append(int(tk_id))

            if add_sent:
                current_sent_masked.append(sep_tk_id)
                all_sent_masked.append(current_sent_masked)
        
        input_ids = torch.tensor(all_sent_masked)
        attention_mask = torch.ones_like(input_ids)

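        # Score all single-mask variants in one batch and convert the logits
        # into log-probabilities over the vocabulary.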
        with torch.no_grad():
            out = self.model(input_ids, attention_mask)
            logits = out.logits
            outputs = self.logSoftmax(logits)

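        # Pseudo-log-likelihood: sum the log-probability of each original token
        # at the position where it was masked.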
        pll_score = 0
        for out, tk_pos, tk_id in zip(outputs, all_tks_position_masked, all_tks_id_masked):
            probabilities = out[tk_pos]
            tk_prob = probabilities[tk_id]
            pll_score += tk_prob.item()

        return pll_score
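

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only). PllScore expects an object exposing
# initTokenizer() and initModel(); the _DemoLanguageModel below is a
# hypothetical stand-in built on Hugging Face `transformers`, not the
# project's own LanguageModel class, and the model name is just an example.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from transformers import AutoModelForMaskedLM, AutoTokenizer

    class _DemoLanguageModel:
        """Hypothetical adapter exposing the interface PllScore relies on."""

        def __init__(self, model_name: str = "bert-base-uncased") -> None:
            self.model_name = model_name

        def initTokenizer(self):
            return AutoTokenizer.from_pretrained(self.model_name)

        def initModel(self):
            return AutoModelForMaskedLM.from_pretrained(self.model_name)

    scorer = PllScore(_DemoLanguageModel())

    # Words of interest go between '<' and '>'; the score is the summed
    # log-probability of the surrounding context tokens.
    print(scorer.compute("The <doctor> walked into the room"))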