Bootcamp_analysis / handler.py
YaHi's picture
Upload handler.py
c7828f5 verified
from typing import Dict, List, Any
from scipy.special import softmax
import numpy as np
import weakref
from utils import (
clean_str,
clean_str_nopunct,
MultiHeadModel,
BertInputBuilder,
get_num_words,
preprocess_transcript_for_eliciting,
preprocess_raw_files,
post_processing_output_json,
compute_student_engagement,
compute_talk_time,
gpt4_filtering_selection
)
import torch
from transformers import BertTokenizer, BertForSequenceClassification, AutoModelForSequenceClassification, AutoTokenizer
# Default checkpoint identifiers. Each one is the default `path` argument that
# the matching model wrapper class in this file passes to `from_pretrained`.
UPTAKE_MODEL='ddemszky/uptake-model'
QUESTION_MODEL ='ddemszky/question-detection'
ELICITING_MODEL = 'YaHi/teacher_electra_small'
class UptakeUtterance:
    """A single transcript turn that can carry uptake and question scores.

    Besides the named fields, any extra keyword arguments are kept in
    ``props`` and merged into the dictionary produced by ``to_dict``.
    """

    def __init__(self, speaker, text, uid=None,
                 transcript=None, starttime=None, endtime=None, **kwargs):
        # Who said what, and the caller-supplied identifier.
        self.speaker, self.text, self.uid = speaker, text, uid
        # Timing information, stored exactly as given.
        self.starttime, self.endtime = starttime, endtime
        # Extra caller-supplied properties.
        self.props = kwargs
        # Filled in later by the inference code.
        self.prev_utt = None
        self.uptake = None
        self.question = None
        # The owning transcript is only weakly referenced.
        self.transcript = None
        if transcript:
            self.transcript = weakref.ref(transcript)

    def get_clean_text(self, remove_punct=False):
        """Return the text cleaned with or without punctuation removal."""
        cleaner = clean_str_nopunct if remove_punct else clean_str
        return cleaner(self.text)

    def get_num_words(self):
        """Return the word count of the text, or 0 when there is no text."""
        return 0 if self.text is None else get_num_words(self.text)

    def to_dict(self):
        """Return the utterance as a plain dict; ``props`` entries win on key clashes."""
        payload = {}
        payload['speaker'] = self.speaker
        payload['text'] = self.text
        payload['prev_utt'] = self.prev_utt
        payload['uid'] = self.uid
        payload['starttime'] = self.starttime
        payload['endtime'] = self.endtime
        payload['uptake'] = self.uptake
        payload['question'] = self.question
        payload.update(self.props)
        return payload

    def __repr__(self):
        head = f"Utterance(speaker='{self.speaker}',"
        middle = f"text='{self.text}', prev_utt='{self.prev_utt}', uid={self.uid},"
        tail = f"starttime={self.starttime}, endtime={self.endtime}, props={self.props})"
        return head + middle + tail
class UptakeTranscript:
    """An ordered collection of utterances plus free-form transcript parameters."""

    def __init__(self, **kwargs):
        self.utterances = []
        # Arbitrary keyword arguments, merged into ``to_dict`` output.
        self.params = kwargs

    def add_utterance(self, utterance):
        """Append ``utterance`` after pointing it (weakly) back at this transcript."""
        back_reference = weakref.ref(self)
        utterance.transcript = back_reference
        self.utterances.append(utterance)

    def get_idx(self, idx):
        """Return the utterance at position ``idx``, or None when ``idx`` is past the end."""
        if idx < len(self.utterances):
            return self.utterances[idx]
        return None

    def get_uid(self, uid):
        """Return the first utterance whose ``uid`` equals ``uid``, or None."""
        matches = (candidate for candidate in self.utterances if candidate.uid == uid)
        return next(matches, None)

    def length(self):
        """Return how many utterances the transcript holds."""
        return len(self.utterances)

    def to_dict(self):
        """Return the transcript as a plain dict; ``params`` entries win on key clashes."""
        serialized = []
        for utterance in self.utterances:
            serialized.append(utterance.to_dict())
        result = {'utterances': serialized}
        result.update(self.params)
        return result

    def __repr__(self):
        return "Transcript(utterances={0}, custom_params={1})".format(self.utterances, self.params)
class ElicitingUtterance:
    """A single transcript turn that can carry eliciting and question labels.

    The text is stored after being passed through ``clean_str_nopunct``.
    Attributes can also be assigned with item syntax (``utt["eliciting"] = 1``).
    """

    def __init__(self, speaker, text, starttime, endtime, uid=None, transcript=None, prev_utt=None):
        """Store the fields of one utterance.

        Args:
            speaker: Speaker label for the turn.
            text: Raw utterance text; cleaned with ``clean_str_nopunct``.
            starttime: Start time of the turn, stored as given.
            endtime: End time of the turn, stored as given.
            uid: Optional identifier of the utterance.
            transcript: Optional owning transcript object.
            prev_utt: Optional text of the utterance this one is paired with.
        """
        self.speaker = speaker
        self.text = clean_str_nopunct(text)
        self.uid = uid
        # Store the reference as given. A truthiness test would be wrong here:
        # ElicitingTranscript defines __len__, so a transcript that is still
        # empty is falsy and its first utterance would lose the back-reference.
        self.transcript = transcript
        self.prev_utt = prev_utt
        # Filled in later by the inference code.
        self.eliciting = None
        self.question = None
        self.starttime = starttime
        self.endtime = endtime

    def __setitem__(self, key, value):
        """Set attribute ``key`` to ``value`` using item-assignment syntax."""
        self.__dict__[key] = value

    def get_clean_text(self, remove_punct=False):
        """Return the text cleaned with or without punctuation removal."""
        if remove_punct:
            return clean_str_nopunct(self.text)
        return clean_str(self.text)

    def to_dict(self):
        """Return the serializable fields (the transcript reference is omitted)."""
        return {
            'speaker': self.speaker,
            'text': self.text,
            'uid': self.uid,
            'prev_utt': self.prev_utt,
            'eliciting': self.eliciting,
            'question': self.question,
            'starttime': self.starttime,
            'endtime': self.endtime,
        }

    def __repr__(self):
        return f"Utterance(speaker='{self.speaker}'," \
               f"text='{self.text}', uid={self.uid}, prev_utt={self.prev_utt}, elicting={self.eliciting}, question={self.question}), starttime={self.starttime}, endtime={self.endtime})"
class ElicitingTranscript:
    """Sequence of ``ElicitingUtterance`` objects built from raw utterance dicts.

    Each utterance is paired with a "previous utterance" text chosen from the
    speaker transition, and indexing the transcript returns the tokenized
    (previous text, text) pair for that utterance.
    """

    def __init__(self, utterances: List[Dict[str, Any]], tokenizer=None):
        """Build the transcript.

        Args:
            utterances: Raw utterance records. Each record is expanded as
                keyword arguments of ``ElicitingUtterance``. Any speaker label
                containing ``'student'`` is normalized to ``'student'``.
                Records without a usable ``"speaker"`` entry are skipped. The
                input records themselves are not modified.
            tokenizer: Callable used by ``__getitem__`` to encode text pairs.

        Raises:
            Exception: whatever ``ElicitingUtterance`` raises for a record it
                cannot be built from (the record is printed first).
        """
        self.tokenizer = tokenizer
        self.utterances = []

        last_text = ""          # cleaned text of the most recently kept utterance
        last_tutor_text = ""    # cleaned text of the most recent tutor utterance
        prev_speaker = None

        for raw in utterances:
            try:
                speaker = raw["speaker"]
                if 'student' in speaker:
                    speaker = 'student'
            except (KeyError, TypeError):
                # No speaker entry, or a speaker/record of an unusable type.
                continue

            transition = (prev_speaker, speaker)
            if transition in (('tutor', 'student'), ('student', 'tutor')):
                # Speaker change: pair with the immediately preceding turn.
                prev_text = last_text
            elif transition == ('student', 'student'):
                # Consecutive student turns are paired with the last tutor turn.
                prev_text = last_tutor_text
            else:
                # First turn, tutor following tutor, or any other speaker label.
                prev_text = ""

            try:
                # Work on a copy so the caller's record keeps its original label.
                fields = {**raw, "speaker": speaker}
                utterance = ElicitingUtterance(**fields, transcript=self, prev_utt=prev_text)
            except Exception:
                # Previously this was swallowed in one branch only, which left a
                # dict in place of the utterance and failed one line later with
                # an unrelated AttributeError. Report the record and re-raise.
                print("Error constructing ElicitingUtterance from:")
                print(raw)
                raise

            if utterance.speaker == 'tutor':
                last_tutor_text = utterance.text
            last_text = utterance.text
            prev_speaker = utterance.speaker
            self.utterances.append(utterance)

    def __len__(self):
        return len(self.utterances)

    def __getitem__(self, index):
        """Return the tokenizer output for utterance ``index`` plus its metadata.

        The tokenizer is called on a one-element list holding the
        (previous text, text) pair, with ``truncation=True``.
        """
        utterance = self.utterances[index]
        output = self.tokenizer([(utterance.prev_utt, utterance.text)], truncation=True)
        output["speaker"] = utterance.speaker
        output["uid"] = utterance.uid
        output["prev_utt"] = utterance.prev_utt
        output["text"] = utterance.text
        return output

    def to_dict(self):
        """Return ``{'utterances': [...]}`` with every utterance serialized."""
        return {
            'utterances': [utterance.to_dict() for utterance in self.utterances]
        }
class QuestionModel:
def __init__(self, device, tokenizer, input_builder, max_length=300, path=QUESTION_MODEL):
print("Loading models...")
self.device = device
self.tokenizer = tokenizer
self.input_builder = input_builder
self.max_length = max_length
self.model = MultiHeadModel.from_pretrained(path, head2size={"is_question": 2})
self.model.to(self.device)
def run_inference(self, transcript):
self.model.eval()
with torch.no_grad():
for i, utt in enumerate(transcript.utterances):
if utt.text is None:
utt.question = None
continue
if "?" in utt.text:
utt.question = 1
else:
text = utt.get_clean_text(remove_punct=True)
instance = self.input_builder.build_inputs([], text,
max_length=self.max_length,
input_str=True)
output = self.get_prediction(instance)
utt.question = softmax(output["is_question_logits"][0].tolist())[1]
def get_prediction(self, instance):
instance["attention_mask"] = [[1] * len(instance["input_ids"])]
for key in ["input_ids", "token_type_ids", "attention_mask"]:
instance[key] = torch.tensor(instance[key]).unsqueeze(0) # Batch size = 1
instance[key].to(self.device)
output = self.model(input_ids=instance["input_ids"].to(self.device),
attention_mask=instance["attention_mask"].to(self.device),
token_type_ids=instance["token_type_ids"].to(self.device),
return_pooler_output=False)
return output
class UptakeModel:
def __init__(self, device, tokenizer, input_builder, max_length=120, path=UPTAKE_MODEL):
print("Loading models...")
self.device = device
self.tokenizer = tokenizer
self.input_builder = input_builder
self.max_length = max_length
self.model = MultiHeadModel.from_pretrained(path, head2size={"nsp": 2})
self.model.to(self.device)
def run_inference(self, transcript, min_prev_words, uptake_speaker=None):
self.model.eval()
prev_num_words = 0
prev_utt = None
with torch.no_grad():
for i, utt in enumerate(transcript.utterances):
if ((uptake_speaker is None) or (utt.speaker == uptake_speaker)) and (prev_num_words >= min_prev_words):
textA = prev_utt.get_clean_text(remove_punct=False)
textB = utt.get_clean_text(remove_punct=False)
instance = self.input_builder.build_inputs([textA], textB,
max_length=self.max_length,
input_str=True)
output = self.get_prediction(instance)
utt.uptake = softmax(output["nsp_logits"][0].tolist())[1]
utt.prev_utt = prev_utt.text
prev_num_words = utt.get_num_words()
prev_utt = utt
def get_prediction(self, instance):
instance["attention_mask"] = [[1] * len(instance["input_ids"])]
for key in ["input_ids", "token_type_ids", "attention_mask"]:
instance[key] = torch.tensor(instance[key]).unsqueeze(0) # Batch size = 1
instance[key].to(self.device)
output = self.model(input_ids=instance["input_ids"].to(self.device),
attention_mask=instance["attention_mask"].to(self.device),
token_type_ids=instance["token_type_ids"].to(self.device),
return_pooler_output=False)
return output
class ElicitingModel:
def __init__(self, device, tokenizer, path=ELICITING_MODEL):
print("Loading teacher models...")
self.device = device
self.tokenizer = tokenizer
self.model = AutoModelForSequenceClassification.from_pretrained(path).to(self.device)
def run_inference(self, dataset):
current_batch = 0
batch_size = 64
def generator():
while current_batch < len(dataset):
yield
for _ in generator():
# check if the remaining samples are less than the batch size
if len(dataset) - current_batch < batch_size:
batch_size = len(dataset) - current_batch
to_pad = [{"input_ids": example["input_ids"][0], "attention_mask": example["attention_mask"][0]} for example in dataset]
to_pad = to_pad[current_batch:current_batch + batch_size]
batch = self.tokenizer.pad(
to_pad,
padding=True,
max_length=None,
pad_to_multiple_of=None,
return_tensors="pt",
)
inputs = batch["input_ids"].to(self.device)
attention_mask = batch["attention_mask"].to(self.device)
with torch.no_grad():
outputs = self.model(inputs, attention_mask=attention_mask)
predictions = outputs.logits.argmax(dim=-1).cpu().numpy()
for i, prediction in enumerate(predictions):
if dataset.utterances[current_batch + i].speaker == 'tutor':
dataset.utterances[current_batch + i]["eliciting"] = prediction
current_batch += batch_size
class EndpointHandler():
    """Entry point that runs the transcript models and posts the result to a webhook."""

    # Destination of the finished analysis payload.
    WEBHOOK_URL = 'https://schoolhouse.world/api/webhooks/stanford-ai-feedback-highlights'
    # Upper bound, in seconds, on the webhook request. Without a timeout
    # `requests.post` can block indefinitely on an unresponsive server.
    WEBHOOK_TIMEOUT_SECONDS = 30

    def __init__(self, path="."):
        """Load the tokenizers and the three model wrappers.

        Args:
            path: Accepted for interface compatibility; not used here.
        """
        print("Loading models...")
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
        self.input_builder = BertInputBuilder(tokenizer=self.tokenizer)
        self.uptake_model = UptakeModel(self.device, self.tokenizer, self.input_builder)
        self.question_model = QuestionModel(self.device, self.tokenizer, self.input_builder)
        # NOTE(review): `eliciting_tokenizer` is loaded but never used in this
        # class; the eliciting model and transcript receive the BERT tokenizer
        # instead. Confirm whether that is intentional before changing it.
        self.eliciting_tokenizer = AutoTokenizer.from_pretrained(ELICITING_MODEL)
        self.eliciting_model = ElicitingModel(self.device, self.tokenizer, path=ELICITING_MODEL)

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """
        data args:
            inputs (:obj: `list`):
                List of dicts, where each dict represents an utterance; each utterance object must have a `speaker`,
                `text` and `uid` and can include list of custom properties
            parameters (:obj: `dict`)
                Read here: `session_uuid`, `session_type`, `event_id`,
                `focus_concept` (when there are model outputs) and, for
                non-eliciting sessions, `uptake_min_num_words` plus the
                optional `filename` and `uptake_speaker`.
        Return:
            A :obj:`dict` with `metrics`, `model_outputs` and `event_id`;
            will be serialized and returned. The same dict is posted to the
            webhook; a failed post is printed and does not prevent the return.
        Raises:
            ValueError: if `parameters` is missing from `data`.
        """
        # get inputs (both entries are removed from `data`)
        utterances = data.pop("inputs", data)
        params = data.pop("parameters", None)  #TODO: make sure that it includes everything required
        if params is None:
            # Fail with a clear message instead of "'NoneType' object is not
            # subscriptable" on the first lookup below.
            raise ValueError("`parameters` is required in the request payload")
        print(params["session_uuid"])

        # pre-processing
        utterances = preprocess_raw_files(utterances, params)

        # compute student engagement and talk time metrics
        num_students_engaged, num_students_engaged_talk_only = compute_student_engagement(utterances)
        tutor_talk_time = compute_talk_time(utterances)

        #TODO: make sure there is some routing going on here based on what session we are at
        if params["session_type"] == "eliciting":
            # pre-processing for eliciting
            utterances_elicting = preprocess_transcript_for_eliciting(utterances)
            eliciting_transcript = ElicitingTranscript(utterances_elicting, tokenizer=self.tokenizer)
            self.eliciting_model.run_inference(eliciting_transcript)
            # Question
            self.question_model.run_inference(eliciting_transcript)
            transcript_output = eliciting_transcript
        else:
            uptake_transcript = UptakeTranscript(filename=params.pop("filename", None))
            for utt in utterances:
                uptake_transcript.add_utterance(UptakeUtterance(**utt))
            # Uptake
            self.uptake_model.run_inference(uptake_transcript, min_prev_words=params['uptake_min_num_words'],
                                            uptake_speaker=params.pop("uptake_speaker", None))
            # Question
            self.question_model.run_inference(uptake_transcript)
            transcript_output = uptake_transcript

        # post-processing
        model_outputs = post_processing_output_json(transcript_output.to_dict(), params["session_uuid"], params["session_type"])

        final_output = {}
        final_output["metrics"] = {"num_students_engaged": num_students_engaged,
                                   "num_students_engaged_talk_only": num_students_engaged_talk_only,
                                   "tutor_talk_time": tutor_talk_time}
        if len(model_outputs) > 0:
            model_outputs = gpt4_filtering_selection(model_outputs, params["session_type"], params["focus_concept"])
        final_output["model_outputs"] = model_outputs
        final_output["event_id"] = params["event_id"]

        # Deliver the result. The HTTP status was never checked, so delivery
        # was already best-effort; transport errors are treated the same way
        # so that they do not discard the computed result.
        import requests
        try:
            response = requests.post(self.WEBHOOK_URL, json=final_output,
                                     timeout=self.WEBHOOK_TIMEOUT_SECONDS)
            print("Post request sent, here is the response: ", response)
        except requests.RequestException as err:
            print("Post request failed: ", err)
        return final_output