Aidan Phillips commited on
Commit
7159c31
·
1 Parent(s): f97ee3d

bug fix, threshold tweak

Browse files
Files changed (2) hide show
  1. categories/fluency.py +6 -10
  2. scorer.ipynb +12 -31
categories/fluency.py CHANGED
@@ -61,7 +61,6 @@ def pseudo_perplexity(text, threshold=20, max_len=128):
61
  word_groups.append(current_group)
62
 
63
  loss_values = []
64
- tok_loss = []
65
  for group in word_groups:
66
  if group[0] == 0 or group[-1] == len(input_ids) - 1:
67
  continue # skip [CLS] and [SEP]
@@ -80,14 +79,11 @@ def pseudo_perplexity(text, threshold=20, max_len=128):
80
  true_token_id = input_ids[i].item()
81
  prob = probs[true_token_id].item()
82
  log_probs.append(np.log(prob + 1e-12))
83
- tok_loss.append(-np.log(prob + 1e-12))
84
 
85
  word_loss = -np.sum(log_probs) / len(log_probs)
86
  word = tokenizer.decode(input_ids[group[0]])
87
  word_loss -= 0.6 * __get_word_pr_score(word)
88
  loss_values.append(word_loss)
89
-
90
- # print(loss_values)
91
 
92
  errors = []
93
  for i, l in enumerate(loss_values):
@@ -99,12 +95,10 @@ def pseudo_perplexity(text, threshold=20, max_len=128):
99
  "message": f"Perplexity {l} over threshold {threshold}"
100
  })
101
 
102
- # print(tok_loss)
103
- s_ppl = np.mean(tok_loss)
104
- # print(s_ppl)
105
 
106
  res = {
107
- "score": __fluency_score_from_ppl(s_ppl),
108
  "errors": errors
109
  }
110
 
@@ -129,7 +123,6 @@ def grammar_errors(text) -> tuple[int, list[str]]:
129
  """
130
 
131
  matches = tool.check(text)
132
- grammar_score = len(matches)/len(text.split())
133
 
134
  r = []
135
  for match in matches:
@@ -150,7 +143,10 @@ def grammar_errors(text) -> tuple[int, list[str]]:
150
  r.append({"start": start, "end": end, "message": match.message})
151
 
152
  struct_err = __check_structural_grammar(text)
153
- r.extend(struct_err)
 
 
 
154
 
155
  res = {
156
  "score": __grammar_score_from_prob(grammar_score),
 
61
  word_groups.append(current_group)
62
 
63
  loss_values = []
 
64
  for group in word_groups:
65
  if group[0] == 0 or group[-1] == len(input_ids) - 1:
66
  continue # skip [CLS] and [SEP]
 
79
  true_token_id = input_ids[i].item()
80
  prob = probs[true_token_id].item()
81
  log_probs.append(np.log(prob + 1e-12))
 
82
 
83
  word_loss = -np.sum(log_probs) / len(log_probs)
84
  word = tokenizer.decode(input_ids[group[0]])
85
  word_loss -= 0.6 * __get_word_pr_score(word)
86
  loss_values.append(word_loss)
 
 
87
 
88
  errors = []
89
  for i, l in enumerate(loss_values):
 
95
  "message": f"Perplexity {l} over threshold {threshold}"
96
  })
97
 
98
+ error_rate = len(errors) / len(loss_values)
 
 
99
 
100
  res = {
101
+ "score": __grammar_score_from_prob(error_rate),
102
  "errors": errors
103
  }
104
 
 
123
  """
124
 
125
  matches = tool.check(text)
 
126
 
127
  r = []
128
  for match in matches:
 
143
  r.append({"start": start, "end": end, "message": match.message})
144
 
145
  struct_err = __check_structural_grammar(text)
146
+ for e in struct_err:
147
+ r.append(e)
148
+
149
+ grammar_score = len(r) / len(text.split())
150
 
151
  res = {
152
  "score": __grammar_score_from_prob(grammar_score),
scorer.ipynb CHANGED
@@ -11,32 +11,14 @@
11
  },
12
  {
13
  "cell_type": "code",
14
- "execution_count": 2,
15
  "metadata": {},
16
  "outputs": [
17
  {
18
  "name": "stdout",
19
  "output_type": "stream",
20
  "text": [
21
- "Sentence: The cat sat the quickly up apples banana.\n",
22
- "tensor([ 101, 10117, 41163, 20694, 10105, 23590, 10741, 72894, 11268, 99304,\n",
23
- " 10219, 119, 102])\n",
24
- "tensor([[ 0, 0],\n",
25
- " [ 0, 3],\n",
26
- " [ 4, 7],\n",
27
- " [ 8, 11],\n",
28
- " [12, 15],\n",
29
- " [16, 23],\n",
30
- " [24, 26],\n",
31
- " [27, 30],\n",
32
- " [30, 33],\n",
33
- " [34, 38],\n",
34
- " [38, 40],\n",
35
- " [40, 41],\n",
36
- " [ 0, 0]])\n",
37
- "[np.float64(0.00905743383887514), np.float64(1.1257066968185931), np.float64(4.8056646935577145), np.float64(4.473408069089179), np.float64(4.732453441503642), np.float64(3.028744414819041), np.float64(5.1115574262487735), np.float64(-0.6523823890571343)]\n",
38
- "[np.float64(1.7636628003080927), np.float64(6.955413759407024), np.float64(10.828562153345375), np.float64(6.228013435558396), np.float64(10.258657658689351), np.float64(6.635744767229443), np.float64(11.163667119285972), np.float64(10.499412826924114), np.float64(11.96113847381264), np.float64(10.010973250156082), np.float64(2.470404176100153)]\n",
39
- "0.5208035409471965\n"
40
  ]
41
  }
42
  ],
@@ -49,12 +31,12 @@
49
  "print(\"Sentence:\", s) # Print the input sentence\n",
50
  "\n",
51
  "err = grammar_errors(s) # Call the function to execute the grammar error checking\n",
52
- "flu = pseudo_perplexity(s, threshold=2.5) # Call the function to execute the fluency checking"
53
  ]
54
  },
55
  {
56
  "cell_type": "code",
57
- "execution_count": 3,
58
  "metadata": {},
59
  "outputs": [
60
  {
@@ -62,11 +44,10 @@
62
  "output_type": "stream",
63
  "text": [
64
  "An apostrophe may be missing.: apples banana.\n",
65
- "Perplexity 4.8056646935577145 over threshold 2.5: sat\n",
66
- "Perplexity 4.473408069089179 over threshold 2.5: the\n",
67
- "Perplexity 4.732453441503642 over threshold 2.5: quickly\n",
68
- "Perplexity 3.028744414819041 over threshold 2.5: up\n",
69
- "Perplexity 5.1115574262487735 over threshold 2.5: apples\n"
70
  ]
71
  }
72
  ],
@@ -80,20 +61,20 @@
80
  },
81
  {
82
  "cell_type": "code",
83
- "execution_count": null,
84
  "metadata": {},
85
  "outputs": [
86
  {
87
  "name": "stdout",
88
  "output_type": "stream",
89
  "text": [
90
- "87.5 99.71\n",
91
- "Fluency Score: 92.384\n"
92
  ]
93
  }
94
  ],
95
  "source": [
96
- "fluency_score = 0.7 * err[\"score\"] + 0.3 * flu[\"score\"] # Calculate the fluency score\n",
97
  "print(err[\"score\"], flu[\"score\"]) # Print the individual scores\n",
98
  "print(\"Fluency Score:\", fluency_score) # Print the fluency score"
99
  ]
 
11
  },
12
  {
13
  "cell_type": "code",
14
+ "execution_count": 20,
15
  "metadata": {},
16
  "outputs": [
17
  {
18
  "name": "stdout",
19
  "output_type": "stream",
20
  "text": [
21
+ "Sentence: The cat sat the quickly up apples banana.\n"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
22
  ]
23
  }
24
  ],
 
31
  "print(\"Sentence:\", s) # Print the input sentence\n",
32
  "\n",
33
  "err = grammar_errors(s) # Call the function to execute the grammar error checking\n",
34
+ "flu = pseudo_perplexity(s, threshold=3.5) # Call the function to execute the fluency checking"
35
  ]
36
  },
37
  {
38
  "cell_type": "code",
39
+ "execution_count": 21,
40
  "metadata": {},
41
  "outputs": [
42
  {
 
44
  "output_type": "stream",
45
  "text": [
46
  "An apostrophe may be missing.: apples banana.\n",
47
+ "Perplexity 4.8056646935577145 over threshold 3.5: sat\n",
48
+ "Perplexity 4.473408069089179 over threshold 3.5: the\n",
49
+ "Perplexity 4.732453441503642 over threshold 3.5: quickly\n",
50
+ "Perplexity 5.1115574262487735 over threshold 3.5: apples\n"
 
51
  ]
52
  }
53
  ],
 
61
  },
62
  {
63
  "cell_type": "code",
64
+ "execution_count": 22,
65
  "metadata": {},
66
  "outputs": [
67
  {
68
  "name": "stdout",
69
  "output_type": "stream",
70
  "text": [
71
+ "87.5 50.0\n",
72
+ "Fluency Score: 68.75\n"
73
  ]
74
  }
75
  ],
76
  "source": [
77
+ "fluency_score = 0.5 * err[\"score\"] + 0.5 * flu[\"score\"] # Calculate the fluency score\n",
78
  "print(err[\"score\"], flu[\"score\"]) # Print the individual scores\n",
79
  "print(\"Fluency Score:\", fluency_score) # Print the fluency score"
80
  ]