Aidan Phillips committed
Commit 7159c31 · Parent(s): f97ee3d
bug fix, threshold tweak
Browse files:
- categories/fluency.py +6 -10
- scorer.ipynb +12 -31
categories/fluency.py CHANGED

@@ -61,7 +61,6 @@ def pseudo_perplexity(text, threshold=20, max_len=128):
         word_groups.append(current_group)
 
     loss_values = []
-    tok_loss = []
     for group in word_groups:
         if group[0] == 0 or group[-1] == len(input_ids) - 1:
             continue  # skip [CLS] and [SEP]
@@ -80,14 +79,11 @@ def pseudo_perplexity(text, threshold=20, max_len=128):
             true_token_id = input_ids[i].item()
             prob = probs[true_token_id].item()
             log_probs.append(np.log(prob + 1e-12))
-            tok_loss.append(-np.log(prob + 1e-12))
 
         word_loss = -np.sum(log_probs) / len(log_probs)
         word = tokenizer.decode(input_ids[group[0]])
         word_loss -= 0.6 * __get_word_pr_score(word)
         loss_values.append(word_loss)
-
-    # print(loss_values)
 
     errors = []
     for i, l in enumerate(loss_values):
@@ -99,12 +95,10 @@ def pseudo_perplexity(text, threshold=20, max_len=128):
                 "message": f"Perplexity {l} over threshold {threshold}"
             })
 
-
-    s_ppl = np.mean(tok_loss)
-    # print(s_ppl)
+    error_rate = len(errors) / len(loss_values)
 
     res = {
-        "score":
+        "score": __grammar_score_from_prob(error_rate),
         "errors": errors
     }
 
@@ -129,7 +123,6 @@ def grammar_errors(text) -> tuple[int, list[str]]:
     """
 
    matches = tool.check(text)
-    grammar_score = len(matches)/len(text.split())
 
    r = []
    for match in matches:
@@ -150,7 +143,10 @@ def grammar_errors(text) -> tuple[int, list[str]]:
        r.append({"start": start, "end": end, "message": match.message})
 
    struct_err = __check_structural_grammar(text)
-
+    for e in struct_err:
+        r.append(e)
+
+    grammar_score = len(r) / len(text.split())
 
    res = {
        "score": __grammar_score_from_prob(grammar_score),
scorer.ipynb CHANGED

@@ -11,32 +11,14 @@
    },
    {
     "cell_type": "code",
-    "execution_count":
+    "execution_count": 20,
     "metadata": {},
     "outputs": [
      {
       "name": "stdout",
       "output_type": "stream",
       "text": [
-        "Sentence: The cat sat the quickly up apples banana.\n",
-        "tensor([  101, 10117, 41163, 20694, 10105, 23590, 10741, 72894, 11268, 99304,\n",
-        "        10219,   119,   102])\n",
-        "tensor([[ 0,  0],\n",
-        "        [ 0,  3],\n",
-        "        [ 4,  7],\n",
-        "        [ 8, 11],\n",
-        "        [12, 15],\n",
-        "        [16, 23],\n",
-        "        [24, 26],\n",
-        "        [27, 30],\n",
-        "        [30, 33],\n",
-        "        [34, 38],\n",
-        "        [38, 40],\n",
-        "        [40, 41],\n",
-        "        [ 0,  0]])\n",
-        "[np.float64(0.00905743383887514), np.float64(1.1257066968185931), np.float64(4.8056646935577145), np.float64(4.473408069089179), np.float64(4.732453441503642), np.float64(3.028744414819041), np.float64(5.1115574262487735), np.float64(-0.6523823890571343)]\n",
-        "[np.float64(1.7636628003080927), np.float64(6.955413759407024), np.float64(10.828562153345375), np.float64(6.228013435558396), np.float64(10.258657658689351), np.float64(6.635744767229443), np.float64(11.163667119285972), np.float64(10.499412826924114), np.float64(11.96113847381264), np.float64(10.010973250156082), np.float64(2.470404176100153)]\n",
-        "0.5208035409471965\n"
+        "Sentence: The cat sat the quickly up apples banana.\n"
        ]
       }
      ],
@@ -49,12 +31,12 @@
      "print(\"Sentence:\", s) # Print the input sentence\n",
      "\n",
      "err = grammar_errors(s) # Call the function to execute the grammar error checking\n",
-     "flu = pseudo_perplexity(s, threshold=
+     "flu = pseudo_perplexity(s, threshold=3.5) # Call the function to execute the fluency checking"
     ]
    },
    {
     "cell_type": "code",
-    "execution_count":
+    "execution_count": 21,
     "metadata": {},
     "outputs": [
      {
@@ -62,11 +44,10 @@
       "output_type": "stream",
       "text": [
        "An apostrophe may be missing.: apples banana.\n",
-       "Perplexity 4.8056646935577145 over threshold 2.5: sat\n",
-       "Perplexity 4.473408069089179 over threshold 2.5: the\n",
-       "Perplexity 4.732453441503642 over threshold 2.5: quickly\n",
-       "Perplexity
-       "Perplexity 5.1115574262487735 over threshold 2.5: apples\n"
+       "Perplexity 4.8056646935577145 over threshold 3.5: sat\n",
+       "Perplexity 4.473408069089179 over threshold 3.5: the\n",
+       "Perplexity 4.732453441503642 over threshold 3.5: quickly\n",
+       "Perplexity 5.1115574262487735 over threshold 3.5: apples\n"
       ]
      }
     ],
@@ -80,20 +61,20 @@
    },
    {
     "cell_type": "code",
-    "execution_count":
+    "execution_count": 22,
     "metadata": {},
     "outputs": [
      {
       "name": "stdout",
       "output_type": "stream",
       "text": [
-       "87.5
-       "Fluency Score:
+       "87.5 50.0\n",
+       "Fluency Score: 68.75\n"
       ]
      }
     ],
     "source": [
-     "fluency_score = 0.
+     "fluency_score = 0.5 * err[\"score\"] + 0.5 * flu[\"score\"] # Calculate the fluency score\n",
      "print(err[\"score\"], flu[\"score\"]) # Print the individual scores\n",
      "print(\"Fluency Score:\", fluency_score) # Print the fluency score"
     ]