File size: 4,770 Bytes
fb0aba9 b71e2da 36c82f7 09a170f b522a7f 33be775 35d2f71 33be775 35d2f71 0b82552 db80f93 c0ce392 db80f93 f24004c 85e0b67 8e5b6d9 6fd0e0b db80f93 c0ce392 db80f93 c0ce392 db80f93 c0ce392 f24004c 6fd0e0b db80f93 1a9e7d4 d4f0372 e564623 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 |
import streamlit as st
import pandas as pd
from PIL import Image
from transformers import pipeline, AutoTokenizer
from datasets import load_dataset, Dataset
import os
from supabase import create_client, Client
# Supabase credentials are read from the environment; ``os.environ.get``
# yields None when a variable is not set.
url: str = os.environ.get("SUPABASE_URL")
key: str = os.environ.get("SUPABASE_KEY")
supabase: Client = create_client(url, key)


@st.cache_resource
def _load_text_classifier(model_name):
    """Build the text-classification pipeline for *model_name* once and reuse it."""
    return pipeline('text-classification', model=model_name)


@st.cache_resource
def _load_tokenizer(model_name):
    """Load the tokenizer for *model_name* once and reuse it."""
    return AutoTokenizer.from_pretrained(model_name)


# Streamlit re-executes this script from the top on every widget interaction,
# so the model objects go through cached loaders instead of being rebuilt on
# each rerun.
pipe = _load_text_classifier('tiedaar/short-answer-classification')
bleurt_pipe = _load_text_classifier("vaiibhavgupta/finetuned-bleurt-large")
longformer_pipe = _load_text_classifier('tiedaar/short-answer-classification-longformer')
tokenizer = _load_tokenizer("vaiibhavgupta/finetuned-bleurt-large")

# Rows without a question are dropped, so every remaining row can be shown.
subsections = pd.read_csv('tp-subsections.csv').dropna(axis=0, how='any', subset=['question'])
def reset():
    """Clear the subsection selection and the typed answer.

    Both keys back widgets (the ``ind`` slider and the ``student_answer``
    text input), so each is reset to a falsy value of the widget's own type
    instead of the boolean ``False``. The script's ``if`` checks on these keys
    still see them as "nothing selected" / "nothing typed".
    """
    st.session_state.ind = 0
    st.session_state.student_answer = ''
def submit():
    """Insert the current question, answer and model verdicts into Supabase.

    Reads ``passage``, ``question``, ``mpnet_res``, ``bleurt_res`` and
    ``is_correct`` from module scope, plus the slider position and the typed
    answer from ``st.session_state``, and writes them as one row of the
    ``automatic-short-answer-scoring`` table.
    """
    record = {
        "subsection": st.session_state.ind,
        "source": passage,
        "question": question,
        "answer": st.session_state.student_answer,
        "mpnet_response": mpnet_res,
        "bleurt_response": bleurt_res,
        "correct_response": is_correct,
    }
    target_table = supabase.table('automatic-short-answer-scoring')
    # The result is unpacked into two values exactly as before; neither is
    # used afterwards.
    data, count = target_table.insert(record).execute()
def get_mpnet(candidate, reference):
    """Return the label that ``pipe`` assigns to a candidate/reference pair.

    The two strings are joined with a literal ``</s>`` separator and passed
    to the pipeline as a single text; the ``label`` of the first result is
    returned.
    """
    paired_text = '</s>'.join((candidate, reference))
    top_prediction = pipe(paired_text)[0]
    return top_prediction['label']
def get_longform(candidate, reference):
    """Return the label that ``longformer_pipe`` assigns to a candidate/reference pair.

    The two strings are joined with a literal ``</s>`` separator and passed
    to the pipeline as a single text; the ``label`` of the first result is
    returned.
    """
    paired_text = '</s>'.join((candidate, reference))
    top_prediction = longformer_pipe(paired_text)[0]
    return top_prediction['label']
def get_bleurt(candidate, reference, threshold=0.7):
    """Turn the ``bleurt_pipe`` score for a candidate/reference pair into a label.

    Args:
        candidate: The answer being judged.
        reference: The reference answer it is compared against.
        threshold: Cut-off applied to the score. Defaults to 0.7, the value
            that was previously hard-coded.

    Returns:
        ``'correct_answer'`` when the ``score`` of the first pipeline result
        is strictly greater than ``threshold``, otherwise
        ``'incorrect_answer'``.
    """
    # The pair is sent as one text, joined by the tokenizer's own separator.
    text = candidate + tokenizer.sep_token + reference
    score = bleurt_pipe(text)[0]['score']
    return 'correct_answer' if score > threshold else 'incorrect_answer'
def _show_verdict(model_name, verdict, student_answer, reference_answer):
    """Render one model's verdict inside the currently active container.

    Nothing is drawn when ``verdict`` is neither ``'correct_answer'`` nor
    ``'incorrect_answer'``.
    """
    if verdict == 'correct_answer':
        st.subheader(f'{model_name} says yes!')
        st.image(Image.open('congratulations-meme.jpeg'))
        st.write('Yay, you got it right!')
    elif verdict == 'incorrect_answer':
        st.subheader(f'{model_name} says no!')
        st.write('Nope, you said', student_answer)
        st.write('A better answer would have been: ', reference_answer)


st.title('iTELL Short Answer Scoring Demo')
st.image(Image.open('learlabaialoe.JPG'))
st.subheader('This is a demonstration of the iTELL short answer scoring model. You will be provided with a passage from the textbook Think Python and a question. Please provide a short answer to the question.')

# Valid iloc positions run from 0 to len(subsections) - 1, so the slider stops
# one short of the row count; using the row count itself raised IndexError at
# the top of the range.
st.slider('Use this slider to choose your subsection.', min_value=0, max_value=len(subsections) - 1, key='ind')

# NOTE(review): position 0 is falsy, so it acts as the "nothing selected"
# state and the row at position 0 is never displayed -- confirm this is
# intended.
if st.session_state.ind:
    row = subsections.iloc[st.session_state.ind]
    passage = row['clean_text']
    question = row['question']
    answer = row['answer']

    st.markdown('---')
    st.header('Passage')
    st.write(passage)
    st.markdown('---')
    st.header('Question')
    st.write(question)

    # NOTE(review): st.radio preselects the first option when no ``index`` is
    # given, so ``is_correct`` is always a non-empty string here and the
    # ``else`` branch below never runs -- confirm whether the radio was meant
    # to start unselected.
    is_correct = st.radio("Are you writing a correct answer?", ["Yes", "No"])
    st.text_input("Write your answer here", key='student_answer')

    if st.session_state.student_answer:
        if is_correct:
            # These results stay at module scope because the ``submit``
            # callback reads them from there.
            mpnet_res = get_mpnet(st.session_state.student_answer, answer)
            bleurt_res = get_bleurt(st.session_state.student_answer, answer)
            longform_res = get_longform(st.session_state.student_answer, answer)

            verdicts = (
                ('MPnet', mpnet_res),
                ('Bleurt', bleurt_res),
                ('Longformer', longform_res),
            )
            # One column per model, in the same left-to-right order as before.
            for column, (model_name, verdict) in zip(st.columns(3), verdicts):
                with column:
                    _show_verdict(model_name, verdict, st.session_state.student_answer, answer)

            st.button('Reset', on_click=reset)
            st.button('Submit', on_click=submit)
        else:
            st.write("Please indicate whether you are providing a correct answer.")
|