aekupor committed on
Commit
6226ff4
1 Parent(s): a7d2f40

Remove code (moved to talk move handler)

Browse files
Files changed (1) hide show
  1. handler.py +1 -153
handler.py CHANGED
@@ -6,22 +6,6 @@ from datetime import datetime
6
  import torch
7
  import spacy
8
 
9
- nlp = spacy.load("en_core_web_sm")
10
- tokenizer = nlp.tokenizer
11
- token_limit = 200
12
-
13
- class Utterance(object):
14
-
15
- def __init__(self, starttime, endtime, speaker, text,
16
- idx, prev_utterance, prev_prev_utterance):
17
- self.starttime = starttime
18
- self.endtime = endtime
19
- self.speaker = speaker
20
- self.text = text
21
- self.idx = idx
22
- self.prev_utterance = prev_utterance
23
- self.prev_prev_utterance = prev_prev_utterance
24
-
25
  class EndpointHandler():
26
  def __init__(self, path="."):
27
  print("Loading models...")
@@ -30,142 +14,6 @@ class EndpointHandler():
30
  "roberta", path, use_cuda=cuda_available
31
  )
32
 
33
- def utterance_to_str(self, utterance: Utterance) -> (List[str], str):
34
- #model utterance uses prior text
35
-
36
- doc = nlp(utterance.text)
37
- prior_text = self.get_prior_text(utterance)
38
-
39
- if len(doc) > token_limit:
40
- utterance_text_list = self.handle_long_utterances(doc)
41
- utterance_with_prior_text = []
42
- for text in utterance_text_list:
43
- utterance_with_prior_text.append([prior_text, text])
44
- return utterance_with_prior_text, 'list'
45
-
46
- else:
47
- return [prior_text, utterance.text], 'single'
48
-
49
- def format_speaker(self, speaker: str, source: str) -> str:
50
- prior_text = ''
51
- if speaker == 'student':
52
- prior_text += '***STUDENT '
53
- else:
54
- prior_text += '***SECTION_LEADER '
55
- if source == 'not chat':
56
- prior_text += '(audio)*** : '
57
- else:
58
- prior_text += '(chat)*** : '
59
- return prior_text
60
-
61
- def get_prior_text(self, utterance: Utterance) -> str:
62
- prior_text = ''
63
- if utterance.prev_utterance != None and utterance.prev_prev_utterance != None:
64
- #TODO: add in the source
65
- prior_text = '\"' + self.format_speaker(utterance.prev_prev_utterance.speaker, 'not chat') + utterance.prev_prev_utterance.text + ' \n '
66
- prior_text += self.format_speaker(utterance.prev_utterance.speaker, 'not chat') + utterance.prev_utterance.text + ' \n '
67
- else:
68
- prior_text = 'No prior utterance'
69
- return prior_text
70
-
71
- def handle_long_utterances(self, doc: str) -> List[str]:
72
- split_count = 1
73
- total_sent = len([x for x in doc.sents])
74
- sent_count = 0
75
- token_count = 0
76
- split_utterance = ''
77
- utterances = []
78
- for sent in doc.sents:
79
- # add a sentence to split
80
- split_utterance = split_utterance + ' ' + sent.text
81
- token_count += len(sent)
82
- sent_count +=1
83
- if token_count >= token_limit or sent_count == total_sent:
84
- # save utterance segment
85
- utterances.append(split_utterance)
86
-
87
- # restart count
88
- split_utterance = ''
89
- token_count = 0
90
- split_count += 1
91
-
92
- return utterances
93
-
94
- def convert_time(self, time_str):
95
- time = datetime.strptime(time_str, "%H:%M:%S.%f")
96
- return 1000 * (3600 * time.hour + 60 * time.minute + time.second) + time.microsecond / 1000
97
-
98
- def process_vtt_transcript(self, vttfile) -> List[Utterance]:
99
- """Process raw vtt file."""
100
-
101
- utterances_list = []
102
- text = ""
103
- prev_start = "00:00:00.000"
104
- prev_end = "00:00:00.000"
105
- idx = 0
106
- prev_speaker = None
107
- prev_utterance = None
108
- prev_prev_utterance = None
109
- for caption in webvtt.read(vttfile):
110
-
111
- # Get speaker
112
- check_for_speaker = caption.text.split(":")
113
- if len(check_for_speaker) > 1: # the speaker was changed or restated
114
- speaker = check_for_speaker[0]
115
- else:
116
- speaker = prev_speaker
117
-
118
- # Get utterance
119
- new_text = check_for_speaker[1] if len(check_for_speaker) > 1 else check_for_speaker[0]
120
-
121
- # If speaker was changed, start new batch
122
- if (prev_speaker is not None) and (speaker != prev_speaker):
123
- utterance = Utterance(starttime=self.convert_time(prev_start),
124
- endtime=self.convert_time(prev_end),
125
- speaker=prev_speaker,
126
- text=text.strip(),
127
- idx=idx,
128
- prev_utterance=prev_utterance,
129
- prev_prev_utterance=prev_prev_utterance)
130
-
131
- utterances_list.append(utterance)
132
-
133
- # Start new batch
134
- prev_start = caption.start
135
- text = ""
136
- prev_prev_utterance = prev_utterance
137
- prev_utterance = utterance
138
- idx+=1
139
- text += new_text + " "
140
- prev_end = caption.end
141
- prev_speaker = speaker
142
-
143
- # Append last one
144
- if prev_speaker is not None:
145
- utterance = Utterance(starttime=self.convert_time(prev_start),
146
- endtime=self.convert_time(prev_end),
147
- speaker=prev_speaker,
148
- text=text.strip(),
149
- idx=idx,
150
- prev_utterance=prev_utterance,
151
- prev_prev_utterance=prev_prev_utterance)
152
- utterances_list.append(utterance)
153
-
154
- return utterances_list
155
-
156
-
157
  def __call__(self, data_file: str) -> List[Dict[str, Any]]:
158
  ''' data_file is a str pointing to filename of type .vtt '''
159
-
160
- utterances_list = []
161
- for utterance in self.process_vtt_transcript(data_file):
162
- #TODO: filter out to only have SL utterances
163
- utterance_str, is_list = self.utterance_to_str(utterance)
164
- if is_list == 'list':
165
- utterances_list.extend(utterance_str)
166
- else:
167
- utterances_list.append(utterance_str)
168
-
169
- predictions, raw_outputs = self.model.predict(utterances_list)
170
-
171
- return predictions
 
6
  import torch
7
  import spacy
8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9
  class EndpointHandler():
10
  def __init__(self, path="."):
11
  print("Loading models...")
 
14
  "roberta", path, use_cuda=cuda_available
15
  )
16
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17
  def __call__(self, data_file: str) -> List[Dict[str, Any]]:
18
  ''' data_file is a str pointing to filename of type .vtt '''
19
+ return []