import streamlit as st
from transformers import AutoTokenizer, AutoModelForQuestionAnswering
import torch

# Load the tokenizer and the fine-tuned model
model_name = 'roberta-base'
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForQuestionAnswering.from_pretrained('./results')  # Path to your fine-tuned model
model.eval()

# Set the title for the Streamlit app
st.title("Movie Trivia Question Answering")

# Text inputs for the user
context = st.text_area("Enter the context (movie-related text):")
question = st.text_area("Enter your question:")

def get_answer(context, question):
    # Encode the question/context pair (question first, as the model expects)
    inputs = tokenizer(question, context, return_tensors='pt', truncation=True, padding=True)
    input_ids = inputs['input_ids'].tolist()[0]

    # Get the model's start/end logits without tracking gradients
    with torch.no_grad():
        outputs = model(**inputs)
    answer_start_scores = outputs.start_logits
    answer_end_scores = outputs.end_logits

    # Get the most likely beginning and end of the answer span
    answer_start = int(torch.argmax(answer_start_scores))
    answer_end = int(torch.argmax(answer_end_scores)) + 1

    # Decode the selected token span back into a string
    answer = tokenizer.convert_tokens_to_string(
        tokenizer.convert_ids_to_tokens(input_ids[answer_start:answer_end])
    )
    return answer.strip()

if st.button("Get Answer"):
    if context and question:
        answer = get_answer(context, question)
        st.subheader("Answer")
        st.write(answer)
    else:
        st.warning("Please enter both context and question.")

# Optionally, add instructions or information about the app
st.write("""
Enter a movie-related context and a question about it. The model will extract the answer from the context you provide.
""")