hylee commited on
Commit
3f1e400
·
1 Parent(s): 8e1280d

integrate focusing question and math terms

Browse files
Files changed (2) hide show
  1. handler.py +81 -4
  2. utils.py +282 -0
handler.py CHANGED
@@ -19,6 +19,7 @@ transformers.logging.set_verbosity_debug()
19
  UPTAKE_MODEL = 'ddemszky/uptake-model'
20
  REASONING_MODEL = 'ddemszky/student-reasoning'
21
  QUESTION_MODEL = 'ddemszky/question-detection'
 
22
 
23
 
24
  class Utterance:
@@ -36,11 +37,14 @@ class Utterance:
36
  self.timestamp = [starttime, endtime]
37
  self.unit_measure = None
38
  self.aggregate_unit_measure = endtime
 
 
39
 
40
  # moments
41
  self.uptake = None
42
  self.reasoning = None
43
  self.question = None
 
44
 
45
  def get_clean_text(self, remove_punct=False):
46
  if remove_punct:
@@ -60,6 +64,9 @@ class Utterance:
60
  'uptake': self.uptake,
61
  'reasoning': self.reasoning,
62
  'question': self.question,
 
 
 
63
  **self.props
64
  }
65
 
@@ -69,10 +76,12 @@ class Utterance:
69
  'text': self.text,
70
  'role': self.role,
71
  'timestamp': self.timestamp,
72
- 'moments': {'reasoning': True if self.reasoning else False, 'questioning': True if self.question else False, 'uptake': True if self.uptake else False},
73
  'unitMeasure': self.unit_measure,
74
  'aggregateUnitMeasure': self.aggregate_unit_measure,
75
- 'wordCount': self.word_count
 
 
76
  }
77
 
78
  def __repr__(self):
@@ -311,6 +320,67 @@ class UptakeModel:
311
  return_pooler_output=False)
312
  return output
313
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
314
 
315
  class EndpointHandler():
316
  def __init__(self, path="."):
@@ -358,14 +428,21 @@ class EndpointHandler():
358
  question_model = QuestionModel(
359
  self.device, self.tokenizer, self.input_builder)
360
  question_model.run_inference(transcript)
 
 
 
 
 
 
 
 
361
  transcript.update_utterance_roles(uptake_speaker)
362
  transcript.calculate_aggregate_word_count()
363
  return_dict = {'talkDistribution': None, 'talkLength': None, 'talkMoments': None, 'commonTopWords': None, 'uptakeTopWords': None}
364
  talk_dist, talk_len = transcript.get_talk_distribution_and_length(uptake_speaker)
365
  return_dict['talkDistribution'] = talk_dist
366
  return_dict['talkLength'] = talk_len
367
- talk_timeline = transcript.get_talk_timeline()
368
- talk_moments = talk_timeline
369
  return_dict['talkMoments'] = talk_moments
370
  word_cloud, uptake_word_cloud = transcript.get_word_cloud_dicts()
371
  return_dict['commonTopWords'] = word_cloud
 
19
  UPTAKE_MODEL = 'ddemszky/uptake-model'
20
  REASONING_MODEL = 'ddemszky/student-reasoning'
21
  QUESTION_MODEL = 'ddemszky/question-detection'
22
+ FOCUSING_QUESTION_MODEL = 'ddemszky/focusing-questions'
23
 
24
 
25
  class Utterance:
 
37
  self.timestamp = [starttime, endtime]
38
  self.unit_measure = None
39
  self.aggregate_unit_measure = endtime
40
+ self.num_math_terms = None
41
+ self.math_terms = None
42
 
43
  # moments
44
  self.uptake = None
45
  self.reasoning = None
46
  self.question = None
47
+ self.focusing_question = None
48
 
49
  def get_clean_text(self, remove_punct=False):
50
  if remove_punct:
 
64
  'uptake': self.uptake,
65
  'reasoning': self.reasoning,
66
  'question': self.question,
67
+ 'focusingQuestion': self.focusing_question,
68
+ 'numMathTerms': self.num_math_terms,
69
+ 'mathTerms': self.math_terms,
70
  **self.props
71
  }
72
 
 
76
  'text': self.text,
77
  'role': self.role,
78
  'timestamp': self.timestamp,
79
+ 'moments': {'reasoning': True if self.reasoning else False, 'questioning': True if self.question else False, 'uptake': True if self.uptake else False, 'focusingQuestion': True if self.focusing_question else False},
80
  'unitMeasure': self.unit_measure,
81
  'aggregateUnitMeasure': self.aggregate_unit_measure,
82
+ 'wordCount': self.word_count,
83
+ 'numMathTerms': self.num_math_terms,
84
+ 'mathTerms': self.math_terms
85
  }
86
 
87
  def __repr__(self):
 
320
  return_pooler_output=False)
321
  return output
322
 
323
+ class FocusingQuestionModel:
324
+ def __init__(self, device, tokenizer, input_builder, max_length=128, path=FOCUSING_QUESTION_MODEL):
325
+ print("Loading models...")
326
+ self.device = device
327
+ self.tokenizer = tokenizer
328
+ self.input_builder = input_builder
329
+ self.model = BertForSequenceClassification.from_pretrained(path)
330
+ self.model.to(self.device)
331
+ self.max_length = max_length
332
+
333
+ def run_inference(self, transcript, min_focusing_words=0, uptake_speaker=None):
334
+ self.model.eval()
335
+ with torch.no_grad():
336
+ for i, utt in enumerate(transcript.utterances):
337
+ if utt.speaker != uptake_speaker or uptake_speaker is None:
338
+ utt.focusing_question = None
339
+ continue
340
+ if utt.get_num_words() < min_focusing_words:
341
+ utt.focusing_question = None
342
+ continue
343
+ instance = self.input_builder.build_inputs([], utt.text, max_length=self.max_length, input_str=True)
344
+ output = self.get_prediction(instance)
345
+ utt.focusing_question = np.argmax(output["logits"][0].tolist())
346
+
347
+ def get_prediction(self, instance):
348
+ instance["attention_mask"] = [[1] * len(instance["input_ids"])]
349
+ for key in ["input_ids", "token_type_ids", "attention_mask"]:
350
+ instance[key] = torch.tensor(
351
+ instance[key]).unsqueeze(0) # Batch size = 1
352
+ instance[key].to(self.device)
353
+
354
+ output = self.model(input_ids=instance["input_ids"],
355
+ attention_mask=instance["attention_mask"],
356
+ token_type_ids=instance["token_type_ids"])
357
+ return output
358
+
359
+ def load_math_terms():
360
+ math_terms = []
361
+ math_terms_dict = {}
362
+ for term in MATH_WORDS:
363
+ if term in MATH_PREFIXES:
364
+ math_terms_dict[f"(^|[^a-zA-Z]){term}(s|es)?([^a-zA-Z]|$)"] = term
365
+ math_terms.append(f"(^|[^a-zA-Z]){term}(s|es)?([^a-zA-Z]|$)")
366
+ else:
367
+ math_terms_dict[f"(^|[^a-zA-Z]){term}([^a-zA-Z]|$)"] = term
368
+ math_terms.append(f"(^|[^a-zA-Z]){term}([^a-zA-Z]|$)")
369
+ return math_terms, math_terms_dict
370
+
371
+ def run_math_density(transcript):
372
+ math_terms, math_terms_dict = load_math_terms()
373
+ for i, utt in enumerate(transcript.utterances):
374
+ found_math_terms = set()
375
+ text = utt.get_clean_text(remove_punct=False)
376
+ num_math_terms = 0
377
+ for term in math_terms:
378
+ count = len(re.findall(term, text))
379
+ if count > 0:
380
+ found_math_terms.add(math_terms_dict[term])
381
+ num_math_terms += count
382
+ utt.num_math_terms = num_math_terms
383
+ utt.math_terms = list(found_math_terms)
384
 
385
  class EndpointHandler():
386
  def __init__(self, path="."):
 
428
  question_model = QuestionModel(
429
  self.device, self.tokenizer, self.input_builder)
430
  question_model.run_inference(transcript)
431
+
432
+ # Focusing Question
433
+ focusing_question_model = FocusingQuestionModel(
434
+ self.device, self.tokenizer, self.input_builder)
435
+ focusing_question_model.run_inference(transcript, uptake_speaker=uptake_speaker)
436
+
437
+ run_math_density(transcript)
438
+
439
  transcript.update_utterance_roles(uptake_speaker)
440
  transcript.calculate_aggregate_word_count()
441
  return_dict = {'talkDistribution': None, 'talkLength': None, 'talkMoments': None, 'commonTopWords': None, 'uptakeTopWords': None}
442
  talk_dist, talk_len = transcript.get_talk_distribution_and_length(uptake_speaker)
443
  return_dict['talkDistribution'] = talk_dist
444
  return_dict['talkLength'] = talk_len
445
+ talk_moments = transcript.get_talk_timeline()
 
446
  return_dict['talkMoments'] = talk_moments
447
  word_cloud, uptake_word_cloud = transcript.get_word_cloud_dicts()
448
  return_dict['commonTopWords'] = word_cloud
utils.py CHANGED
@@ -13,6 +13,288 @@ punct_chars.sort()
13
  punctuation = ''.join(punct_chars)
14
  replace = re.compile('[%s]' % re.escape(punctuation))
15
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16
  def get_num_words(text):
17
  if not isinstance(text, str):
18
  print("%s is not a string" % text)
 
13
  punctuation = ''.join(punct_chars)
14
  replace = re.compile('[%s]' % re.escape(punctuation))
15
 
16
+ MATH_PREFIXES = [
17
+ "sum",
18
+ "arc",
19
+ "mass",
20
+ "digit",
21
+ "graph",
22
+ "liter",
23
+ "gram",
24
+ "add",
25
+ "angle",
26
+ "scale",
27
+ "data",
28
+ "array",
29
+ "ruler",
30
+ "meter",
31
+ "total",
32
+ "unit",
33
+ "prism",
34
+ "median",
35
+ "ratio",
36
+ "area",
37
+ ]
38
+
39
+ MATH_WORDS = [
40
+ "absolute value",
41
+ "area",
42
+ "average",
43
+ "base of",
44
+ "box plot",
45
+ "categorical",
46
+ "coefficient",
47
+ "common factor",
48
+ "common multiple",
49
+ "compose",
50
+ "coordinate",
51
+ "cubed",
52
+ "decompose",
53
+ "dependent variable",
54
+ "distribution",
55
+ "dot plot",
56
+ "double number line diagram",
57
+ "equivalent",
58
+ "equivalent expression",
59
+ "ratio",
60
+ "exponent",
61
+ "frequency",
62
+ "greatest common factor",
63
+ "gcd",
64
+ "height of",
65
+ "histogram",
66
+ "independent variable",
67
+ "interquartile range",
68
+ "iqr",
69
+ "least common multiple",
70
+ "long division",
71
+ "mean absolute deviation",
72
+ "median",
73
+ "negative number",
74
+ "opposite vertex",
75
+ "parallelogram",
76
+ "percent",
77
+ "polygon",
78
+ "polyhedron",
79
+ "positive number",
80
+ "prism",
81
+ "pyramid",
82
+ "quadrant",
83
+ "quadrilateral",
84
+ "quartile",
85
+ "rational number",
86
+ "reciprocal",
87
+ "equality",
88
+ "inequality",
89
+ "squared",
90
+ "statistic",
91
+ "surface area",
92
+ "identity property",
93
+ "addend",
94
+ "unit",
95
+ "number sentence",
96
+ "make ten",
97
+ "take from ten",
98
+ "number bond",
99
+ "total",
100
+ "estimate",
101
+ "hashmark",
102
+ "meter",
103
+ "number line",
104
+ "ruler",
105
+ "centimeter",
106
+ "base ten",
107
+ "expanded form",
108
+ "hundred",
109
+ "thousand",
110
+ "place value",
111
+ "number disk",
112
+ "standard form",
113
+ "unit form",
114
+ "word form",
115
+ "tens place",
116
+ "algorithm",
117
+ "equation",
118
+ "simplif",
119
+ "addition",
120
+ "subtract",
121
+ "array",
122
+ "even number",
123
+ "odd number",
124
+ "repeated addition",
125
+ "tessellat",
126
+ "whole number",
127
+ "number path",
128
+ "rectangle",
129
+ "square",
130
+ "bar graph",
131
+ "data",
132
+ "degree",
133
+ "line plot",
134
+ "picture graph",
135
+ "scale",
136
+ "survey",
137
+ "thermometer",
138
+ "estimat",
139
+ "tape diagram",
140
+ "value",
141
+ "analog",
142
+ "angle",
143
+ "parallel",
144
+ "partition",
145
+ "pentagon",
146
+ "right angle",
147
+ "cube",
148
+ "digital",
149
+ "quarter of",
150
+ "tangram",
151
+ "circle",
152
+ "hexagon",
153
+ "half circle",
154
+ "half-circle",
155
+ "quarter circle",
156
+ "quarter-circle",
157
+ "semicircle",
158
+ "semi-circle",
159
+ "rectang",
160
+ "rhombus",
161
+ "trapezoid",
162
+ "triangle",
163
+ "commutative",
164
+ "equal group",
165
+ "distributive",
166
+ "divide",
167
+ "division",
168
+ "multipl",
169
+ "parentheses",
170
+ "quotient",
171
+ "rotate",
172
+ "unknown",
173
+ "add",
174
+ "capacity",
175
+ "continuous",
176
+ "endpoint",
177
+ "gram",
178
+ "interval",
179
+ "kilogram",
180
+ "volume",
181
+ "liter",
182
+ "milliliter",
183
+ "approximate",
184
+ "area model",
185
+ "square unit",
186
+ "unit square",
187
+ "geometr",
188
+ "equivalent fraction",
189
+ "fraction form",
190
+ "fractional unit",
191
+ "unit fraction",
192
+ "unit interval",
193
+ "measur",
194
+ "graph",
195
+ "scaled graph",
196
+ "diagonal",
197
+ "perimeter",
198
+ "regular polygon",
199
+ "tessellate",
200
+ "tetromino",
201
+ "heptagon",
202
+ "octagon",
203
+ "digit",
204
+ "expression",
205
+ "sum",
206
+ "kilometer",
207
+ "mass",
208
+ "mixed unit",
209
+ "length",
210
+ "measure",
211
+ "simplify",
212
+ "associative",
213
+ "composite",
214
+ "divisible",
215
+ "divisor",
216
+ "partial product",
217
+ "prime number",
218
+ "remainder",
219
+ "acute",
220
+ "arc",
221
+ "collinear",
222
+ "equilateral",
223
+ "intersect",
224
+ "isosceles",
225
+ "symmetry",
226
+ "line segment",
227
+ "line",
228
+ "obtuse",
229
+ "perpendicular",
230
+ "protractor",
231
+ "scalene",
232
+ "straight angle",
233
+ "supplementary angle",
234
+ "vertex",
235
+ "common denominator",
236
+ "denominator",
237
+ "fraction",
238
+ "mixed number",
239
+ "numerator",
240
+ "whole",
241
+ "decimal expanded form",
242
+ "decimal",
243
+ "hundredth",
244
+ "tenth",
245
+ "customary system of measurement",
246
+ "customary unit",
247
+ "gallon",
248
+ "metric",
249
+ "metric unit",
250
+ "ounce",
251
+ "pint",
252
+ "quart",
253
+ "convert",
254
+ "distance",
255
+ "millimeter",
256
+ "thousandth",
257
+ "hundredths",
258
+ "conversion factor",
259
+ "decimal fraction",
260
+ "multiplier",
261
+ "equivalence",
262
+ "multiple",
263
+ "product",
264
+ "benchmark fraction",
265
+ "cup",
266
+ "pound",
267
+ "yard",
268
+ "whole unit",
269
+ "decimal divisor",
270
+ "factors",
271
+ "bisect",
272
+ "cubic units",
273
+ "hierarchy",
274
+ "unit cube",
275
+ "attribute",
276
+ "kite",
277
+ "bisector",
278
+ "solid figure",
279
+ "square units",
280
+ "dimension",
281
+ "axis",
282
+ "ordered pair",
283
+ "angle measure",
284
+ "horizontal",
285
+ "vertical",
286
+ "categorical data",
287
+ "lcm",
288
+ "measure of center",
289
+ "meters per second",
290
+ "numerical",
291
+ "solution",
292
+ "unit price",
293
+ "unit rate",
294
+ "variability",
295
+ "variable",
296
+ ]
297
+
298
  def get_num_words(text):
299
  if not isinstance(text, str):
300
  print("%s is not a string" % text)