Morris
update app.py
85e0b67
import streamlit as st
import pandas as pd
from PIL import Image
from transformers import pipeline, AutoTokenizer
from datasets import load_dataset, Dataset
import os
from supabase import create_client, Client
# Streamlit re-executes this whole script on every widget interaction, so
# loading the transformer models at module level would rebuild all three
# pipelines on every click. The loaders below are memoised with Streamlit's
# resource cache when the installed version provides one.


def _cache_resource(func):
    """Wrap *func* in Streamlit's cross-rerun resource cache if available.

    Prefers ``st.cache_resource``, falls back to the older
    ``st.experimental_singleton``, and returns *func* unchanged when neither
    exists (the previous, uncached behaviour).
    """
    cacher = getattr(st, 'cache_resource', None) or getattr(st, 'experimental_singleton', None)
    return func if cacher is None else cacher(func)


@_cache_resource
def _load_classifier(model_name):
    """Build a Hugging Face text-classification pipeline for *model_name*."""
    return pipeline('text-classification', model=model_name)


@_cache_resource
def _load_tokenizer(model_name):
    """Load the Hugging Face tokenizer for *model_name*."""
    return AutoTokenizer.from_pretrained(model_name)


# Supabase credentials come from the environment. os.environ.get returns None
# when a variable is unset.
# TODO(review): confirm create_client reports a clear error in that case.
url: str = os.environ.get("SUPABASE_URL")
key: str = os.environ.get("SUPABASE_KEY")
supabase: Client = create_client(url, key)

# Scoring models; the module-level names are what the get_* helpers use.
pipe = _load_classifier('tiedaar/short-answer-classification')
bleurt_pipe = _load_classifier("vaiibhavgupta/finetuned-bleurt-large")
longformer_pipe = _load_classifier('tiedaar/short-answer-classification-longformer')
# Only its separator token is used, when pairing texts for BLEURT.
tokenizer = _load_tokenizer("vaiibhavgupta/finetuned-bleurt-large")

# Passages and questions; rows without a question are dropped.
subsections = pd.read_csv('tp-subsections.csv').dropna(axis=0, how='any', subset=['question'])
def reset():
    """Clear the subsection selection and the typed answer (Reset callback).

    Writes each widget's own empty value: ``0`` for the ``ind`` slider (an
    integer slider) and ``""`` for the ``student_answer`` text input. The
    previous code wrote ``False`` to both. These values are just as falsy, so
    the ``if st.session_state...`` guards in the script behave the same, but
    their types now match what the widgets hold.
    """
    st.session_state.ind = 0
    st.session_state.student_answer = ""
def submit():
    """Store the current answer and model verdicts in Supabase (Submit callback).

    ``passage``, ``question``, ``mpnet_res``, ``bleurt_res`` and ``is_correct``
    are read as module-level globals. They are assigned by the main script body
    in the same branch that renders the Submit button.
    NOTE(review): the Longformer verdict is not part of the stored record.
    """
    record = dict(
        subsection=st.session_state.ind,
        source=passage,
        question=question,
        answer=st.session_state.student_answer,
        mpnet_response=mpnet_res,
        bleurt_response=bleurt_res,
        correct_response=is_correct,
    )
    scoring_table = supabase.table('automatic-short-answer-scoring')
    _response_data, _response_count = scoring_table.insert(record).execute()
def get_mpnet(candidate, reference):
    """Return the MPNet classifier's label for a (candidate, reference) pair.

    The two strings are joined with a literal ``'</s>'`` separator and sent to
    the module-level ``pipe`` text-classification pipeline. The label of the
    first prediction is returned.
    """
    predictions = pipe('</s>'.join((candidate, reference)))
    top_prediction = predictions[0]
    return top_prediction['label']
def get_longform(candidate, reference):
    """Return the Longformer classifier's label for a (candidate, reference) pair.

    Same pairing as the MPNet helper: candidate and reference joined by a
    literal ``'</s>'``. The result is classified by the module-level
    ``longformer_pipe``, and the first prediction's label is returned.
    """
    joined = '</s>'.join((candidate, reference))
    first = longformer_pipe(joined)[0]
    return first['label']
def get_bleurt(candidate, reference, *, threshold=0.7):
    """Judge a candidate answer against a reference with the fine-tuned BLEURT model.

    The texts are joined with the BLEURT tokenizer's separator token and
    classified by ``bleurt_pipe``. The score of the first prediction is
    compared with ``threshold``.

    Args:
        candidate: the student's answer.
        reference: the reference answer.
        threshold: scores strictly greater than this count as correct.
            Defaults to 0.7, the previously hard-coded cut-off.

    Returns:
        ``'correct_answer'`` if the score exceeds ``threshold``, otherwise
        ``'incorrect_answer'``.
    """
    text = candidate + tokenizer.sep_token + reference
    score = bleurt_pipe(text)[0]['score']
    return 'correct_answer' if score > threshold else 'incorrect_answer'
def _render_verdict(model_label, verdict, student_answer, reference):
    """Render one model's verdict on the student's answer in the current container.

    Args:
        model_label: name shown in the heading, e.g. ``'MPnet'``.
        verdict: label returned by a scoring helper. Only ``'correct_answer'``
            and ``'incorrect_answer'`` render anything.
        student_answer: the text the student typed.
        reference: the dataset's reference answer, shown when the verdict is
            incorrect.
    """
    if verdict == 'correct_answer':
        st.subheader(model_label + ' says yes!')
        st.image(Image.open('congratulations-meme.jpeg'))
        st.write('Yay, you got it right!')
    elif verdict == 'incorrect_answer':
        st.subheader(model_label + ' says no!')
        st.write('Nope, you said', student_answer)
        st.write('A better answer would have been: ', reference)


st.title('iTELL Short Answer Scoring Demo')
st.image(Image.open('learlabaialoe.JPG'))
st.subheader('This is a demonstration of the iTELL short answer scoring model. You will be provided with a passage from the textbook Think Python and a question. Please provide a short answer to the question.')

# Slider position 0 means "nothing selected" (the guard below skips falsy
# values), so the last valid position is len(subsections) - 1. The previous
# max_value of len(subsections) pointed one past the last row, and .iloc
# raised IndexError at the slider's end.
# NOTE(review): because 0 doubles as the sentinel, the first row of
# `subsections` can never be shown. Fixing that would shift the 'subsection'
# value stored by submit(), so it is left unchanged here.
st.slider(
    'Use this slider to choose your subsection.',
    min_value=0,
    max_value=len(subsections) - 1,
    key='ind',
)
if st.session_state.ind:
    # passage, question, the model results and is_correct stay module-level
    # globals because submit() reads them when Submit is clicked.
    row = subsections.iloc[st.session_state.ind]
    passage = row['clean_text']
    question = row['question']
    answer = row['answer']
    st.markdown('---')
    st.header('Passage')
    st.write(passage)
    st.markdown('---')
    st.header('Question')
    st.write(question)
    is_correct = st.radio("Are you writing a correct answer?", ["Yes", "No"])
    st.text_input("Write your answer here", key='student_answer')
    if st.session_state.student_answer:
        # st.radio returns the chosen label, "Yes" or "No". Both are truthy, so
        # the old `if is_correct:` could never reach the prompt below.
        # Comparing with None expresses the intended "has the user chosen?"
        # check. With the default selection the radio always returns a label.
        # TODO(review): newer Streamlit versions accept index=None, making the
        # radio return None until the user picks — confirm before relying on it.
        if is_correct is not None:
            student_answer = st.session_state.student_answer
            mpnet_res = get_mpnet(student_answer, answer)
            bleurt_res = get_bleurt(student_answer, answer)
            longform_res = get_longform(student_answer, answer)
            verdicts = (
                ('MPnet', mpnet_res),
                ('Bleurt', bleurt_res),
                ('Longformer', longform_res),
            )
            for column, (model_label, verdict) in zip(st.columns(3), verdicts):
                with column:
                    _render_verdict(model_label, verdict, student_answer, answer)
            st.button('Reset', on_click=reset)
            st.button('Submit', on_click=submit)
        else:
            st.write("Please indicate whether you are providing a correct answer.")