only run reasoning for non-teacher utterances
Browse files- handler.py +3 -3
handler.py
CHANGED
@@ -256,11 +256,11 @@ class ReasoningModel:
|
|
256 |
self.model = BertForSequenceClassification.from_pretrained(path)
|
257 |
self.model.to(self.device)
|
258 |
|
259 |
-
def run_inference(self, transcript, min_num_words=8):
|
260 |
self.model.eval()
|
261 |
with torch.no_grad():
|
262 |
for i, utt in enumerate(transcript.utterances):
|
263 |
-
if utt.get_num_words() >= min_num_words:
|
264 |
instance = self.input_builder.build_inputs([], utt.text,
|
265 |
max_length=self.max_length,
|
266 |
input_str=True)
|
@@ -430,7 +430,7 @@ class EndpointHandler():
|
|
430 |
# Reasoning
|
431 |
reasoning_model = ReasoningModel(
|
432 |
self.device, self.tokenizer, self.input_builder)
|
433 |
-
reasoning_model.run_inference(transcript)
|
434 |
|
435 |
# Question
|
436 |
question_model = QuestionModel(
|
|
|
256 |
self.model = BertForSequenceClassification.from_pretrained(path)
|
257 |
self.model.to(self.device)
|
258 |
|
259 |
+
def run_inference(self, transcript, min_num_words=8, uptake_speaker=None):
|
260 |
self.model.eval()
|
261 |
with torch.no_grad():
|
262 |
for i, utt in enumerate(transcript.utterances):
|
263 |
+
if utt.get_num_words() >= min_num_words and utt.speaker != uptake_speaker:
|
264 |
instance = self.input_builder.build_inputs([], utt.text,
|
265 |
max_length=self.max_length,
|
266 |
input_str=True)
|
|
|
430 |
# Reasoning
|
431 |
reasoning_model = ReasoningModel(
|
432 |
self.device, self.tokenizer, self.input_builder)
|
433 |
+
reasoning_model.run_inference(transcript, uptake_speaker=uptake_speaker)
|
434 |
|
435 |
# Question
|
436 |
question_model = QuestionModel(
|