aekupor committed on
Commit
94ef6d2
1 Parent(s): 5499fc9

Remove code (moved to talk move router)

Browse files
Files changed (1) hide show
  1. handler.py +1 -158
handler.py CHANGED
@@ -10,18 +10,6 @@ nlp = spacy.load("en_core_web_sm")
10
  tokenizer = nlp.tokenizer
11
  token_limit = 200
12
 
13
- class Utterance(object):
14
-
15
- def __init__(self, starttime, endtime, speaker, text,
16
- idx, prev_utterance, prev_prev_utterance):
17
- self.starttime = starttime
18
- self.endtime = endtime
19
- self.speaker = speaker
20
- self.text = text
21
- self.idx = idx
22
- self.prev_utterance = prev_utterance
23
- self.prev_prev_utterance = prev_prev_utterance
24
-
25
  class EndpointHandler():
26
  def __init__(self, path="."):
27
  print("Loading models...")
@@ -30,151 +18,6 @@ class EndpointHandler():
30
  "roberta", path, use_cuda=cuda_available
31
  )
32
 
33
- def utterance_to_str(self, utterance: Utterance) -> (List[str], str):
34
- #revoicing using prior text and truncates end of the prior text
35
-
36
- doc = nlp(utterance.text)
37
- prior_text = self.truncate_end(self.get_prior_text(utterance))
38
-
39
- if len(doc) > token_limit:
40
- utterance_text_list = self.handle_long_utterances(doc)
41
- utterance_with_prior_text = []
42
- for text in utterance_text_list:
43
- utterance_with_prior_text.append([prior_text, text])
44
- return utterance_with_prior_text, 'list'
45
-
46
- else:
47
- return [prior_text, utterance.text], 'single'
48
-
49
- def truncate_end(self, prior_text: str) -> str:
50
- max_seq_length = 512
51
- prior_text_max_length = int(max_seq_length / 2) #divide by 2 because 2 columns
52
-
53
- if len(prior_text) > prior_text_max_length:
54
- starting_index = len(prior_text) - prior_text_max_length
55
- return prior_text[starting_index:]
56
- return prior_text
57
-
58
- def format_speaker(self, speaker: str, source: str) -> str:
59
- prior_text = ''
60
- if speaker == 'student':
61
- prior_text += '***STUDENT '
62
- else:
63
- prior_text += '***SECTION_LEADER '
64
- if source == 'not chat':
65
- prior_text += '(audio)*** : '
66
- else:
67
- prior_text += '(chat)*** : '
68
- return prior_text
69
-
70
- def get_prior_text(self, utterance: Utterance) -> str:
71
- prior_text = ''
72
- if utterance.prev_utterance != None and utterance.prev_prev_utterance != None:
73
- #TODO: add in the source
74
- prior_text = '\"' + self.format_speaker(utterance.prev_prev_utterance.speaker, 'not chat') + utterance.prev_prev_utterance.text + ' \n '
75
- prior_text += self.format_speaker(utterance.prev_utterance.speaker, 'not chat') + utterance.prev_utterance.text + ' \n '
76
- else:
77
- prior_text = 'No prior utterance'
78
- return prior_text
79
-
80
- def handle_long_utterances(self, doc: str) -> List[str]:
81
- split_count = 1
82
- total_sent = len([x for x in doc.sents])
83
- sent_count = 0
84
- token_count = 0
85
- split_utterance = ''
86
- utterances = []
87
- for sent in doc.sents:
88
- # add a sentence to split
89
- split_utterance = split_utterance + ' ' + sent.text
90
- token_count += len(sent)
91
- sent_count +=1
92
- if token_count >= token_limit or sent_count == total_sent:
93
- # save utterance segment
94
- utterances.append(split_utterance)
95
-
96
- # restart count
97
- split_utterance = ''
98
- token_count = 0
99
- split_count += 1
100
-
101
- return utterances
102
-
103
- def convert_time(self, time_str):
104
- time = datetime.strptime(time_str, "%H:%M:%S.%f")
105
- return 1000 * (3600 * time.hour + 60 * time.minute + time.second) + time.microsecond / 1000
106
-
107
- def process_vtt_transcript(self, vttfile) -> List[Utterance]:
108
- """Process raw vtt file."""
109
-
110
- utterances_list = []
111
- text = ""
112
- prev_start = "00:00:00.000"
113
- prev_end = "00:00:00.000"
114
- idx = 0
115
- prev_speaker = None
116
- prev_utterance = None
117
- prev_prev_utterance = None
118
- for caption in webvtt.read(vttfile):
119
-
120
- # Get speaker
121
- check_for_speaker = caption.text.split(":")
122
- if len(check_for_speaker) > 1: # the speaker was changed or restated
123
- speaker = check_for_speaker[0]
124
- else:
125
- speaker = prev_speaker
126
-
127
- # Get utterance
128
- new_text = check_for_speaker[1] if len(check_for_speaker) > 1 else check_for_speaker[0]
129
-
130
- # If speaker was changed, start new batch
131
- if (prev_speaker is not None) and (speaker != prev_speaker):
132
- utterance = Utterance(starttime=self.convert_time(prev_start),
133
- endtime=self.convert_time(prev_end),
134
- speaker=prev_speaker,
135
- text=text.strip(),
136
- idx=idx,
137
- prev_utterance=prev_utterance,
138
- prev_prev_utterance=prev_prev_utterance)
139
-
140
- utterances_list.append(utterance)
141
-
142
- # Start new batch
143
- prev_start = caption.start
144
- text = ""
145
- prev_prev_utterance = prev_utterance
146
- prev_utterance = utterance
147
- idx+=1
148
- text += new_text + " "
149
- prev_end = caption.end
150
- prev_speaker = speaker
151
-
152
- # Append last one
153
- if prev_speaker is not None:
154
- utterance = Utterance(starttime=self.convert_time(prev_start),
155
- endtime=self.convert_time(prev_end),
156
- speaker=prev_speaker,
157
- text=text.strip(),
158
- idx=idx,
159
- prev_utterance=prev_utterance,
160
- prev_prev_utterance=prev_prev_utterance)
161
- utterances_list.append(utterance)
162
-
163
- return utterances_list
164
-
165
-
166
  def __call__(self, data_file: str) -> List[Dict[str, Any]]:
167
  ''' data_file is a str pointing to filename of type .vtt '''
168
-
169
- utterances_list = []
170
- for utterance in self.process_vtt_transcript(data_file):
171
- #TODO: filter out to only have SL utterances
172
- utterance_str, is_list = self.utterance_to_str(utterance)
173
- if is_list == 'list':
174
- utterances_list.extend(utterance_str)
175
- else:
176
- utterances_list.append(utterance_str)
177
-
178
- predictions, raw_outputs = self.model.predict(utterances_list)
179
-
180
- return predictions
 
10
  tokenizer = nlp.tokenizer
11
  token_limit = 200
12
 
 
 
 
 
 
 
 
 
 
 
 
 
13
  class EndpointHandler():
14
  def __init__(self, path="."):
15
  print("Loading models...")
 
18
  "roberta", path, use_cuda=cuda_available
19
  )
20
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
  def __call__(self, data_file: str) -> List[Dict[str, Any]]:
22
  ''' data_file is a str pointing to filename of type .vtt '''
23
+ return []