aekupor committed on
Commit 638ffeb
1 Parent(s): 5cefadd

Remove logic moved into talk_move_router

Files changed (1)
  1. handler.py +1 -120
handler.py CHANGED
@@ -6,22 +6,6 @@ from datetime import datetime
 import torch
 import spacy
 
-nlp = spacy.load("en_core_web_sm")
-tokenizer = nlp.tokenizer
-token_limit = 200
-
-class Utterance(object):
-
-    def __init__(self, starttime, endtime, speaker, text,
-                 idx, prev_utterance, prev_prev_utterance):
-        self.starttime = starttime
-        self.endtime = endtime
-        self.speaker = speaker
-        self.text = text
-        self.idx = idx
-        self.prev = prev_utterance
-        self.prev_prev = prev_prev_utterance
-
 class EndpointHandler():
     def __init__(self, path="."):
         print("Loading models...")
@@ -30,109 +14,6 @@ class EndpointHandler():
             "roberta", path, use_cuda=cuda_available
         )
 
-    def utterance_to_str(self, utterance: Utterance) -> str:
-        # connecting only uses text
-        doc = nlp(utterance.text)
-        if len(doc) > token_limit:
-            return self.handle_long_utterances(doc)
-        return utterance.text
-
-    def handle_long_utterances(self, doc: str) -> List[str]:
-        split_count = 1
-        total_sent = len([x for x in doc.sents])
-        sent_count = 0
-        token_count = 0
-        split_utterance = ''
-        utterances = []
-        for sent in doc.sents:
-            # add a sentence to split
-            split_utterance = split_utterance + ' ' + sent.text
-            token_count += len(sent)
-            sent_count +=1
-            if token_count >= token_limit or sent_count == total_sent:
-                # save utterance segment
-                utterances.append(split_utterance)
-
-                # restart count
-                split_utterance = ''
-                token_count = 0
-                split_count += 1
-
-        return utterances
-
-
-    def convert_time(self, time_str):
-        time = datetime.strptime(time_str, "%H:%M:%S.%f")
-        return 1000 * (3600 * time.hour + 60 * time.minute + time.second) + time.microsecond / 1000
-
-    def process_vtt_transcript(self, vttfile) -> List[Utterance]:
-        """Process raw vtt file."""
-
-        utterances_list = []
-        text = ""
-        prev_speaker = None
-        prev_start = "00:00:00.000"
-        prev_end = "00:00:00.000"
-        idx = 0
-        prev_utterance = None
-        prev_prev_utterance = None
-        for caption in webvtt.read(vttfile):
-
-            # Get speaker
-            check_for_speaker = caption.text.split(":")
-            if len(check_for_speaker) > 1: # the speaker was changed or restated
-                speaker = check_for_speaker[0]
-            else:
-                speaker = prev_speaker
-
-            # Get utterance
-            new_text = check_for_speaker[1] if len(check_for_speaker) > 1 else check_for_speaker[0]
-
-            # If speaker was changed, start new batch
-            if (prev_speaker is not None) and (speaker != prev_speaker):
-                utterance = Utterance(starttime=self.convert_time(prev_start),
-                                      endtime=self.convert_time(prev_end),
-                                      speaker=prev_speaker,
-                                      text=text.strip(),
-                                      idx=idx,
-                                      prev_utterance=prev_utterance,
-                                      prev_prev_utterance=prev_prev_utterance)
-
-                utterances_list.append(utterance)
-
-                # Start new batch
-                prev_start = caption.start
-                text = ""
-                prev_prev_utterance = prev_utterance
-                prev_utterance = utterance
-                idx+=1
-            text += new_text + " "
-            prev_end = caption.end
-            prev_speaker = speaker
-
-        # Append last one
-        if prev_speaker is not None:
-            utterance = Utterance(starttime=self.convert_time(prev_start),
-                                  endtime=self.convert_time(prev_end),
-                                  speaker=prev_speaker,
-                                  text=text.strip(),
-                                  idx=idx,
-                                  prev_utterance=prev_utterance,
-                                  prev_prev_utterance=prev_prev_utterance)
-            utterances_list.append(utterance)
-
-        print(utterances_list)
-        return utterances_list
-
-
     def __call__(self, data_file: str) -> List[Dict[str, Any]]:
         ''' data_file is a str pointing to filename of type .vtt '''
-
-        utterances_list = []
-        for utterance in self.process_vtt_transcript(data_file):
-            #TODO: filter out to only have SL utterances
-            utterances_list.append(self.utterance_to_str(utterance))
-
-        predictions, raw_outputs = self.model.predict(utterances_list)
-
-        return predictions
+        return []
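
For reference, a minimal usage sketch of the slimmed-down handler after this commit. The import path and the .vtt filename below are assumptions for illustration, not part of the commit; per the commit message, the VTT parsing and prediction logic now lives in talk_move_router, so __call__ here simply returns an empty list.

# Hypothetical usage sketch; filename and import path are assumptions.
from handler import EndpointHandler

handler = EndpointHandler(path=".")            # loads the RoBERTa model from the repo root
predictions = handler("example_session.vtt")   # as of this commit, __call__ ignores the file
print(predictions)                             # -> []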