QA / app.py
huriacane33's picture
Update app.py
09ded81 verified
raw
history blame
2.31 kB
import re

import pandas as pd
import streamlit as st
from transformers import AutoModelForQuestionAnswering, AutoTokenizer, pipeline
# Build the extractive question-answering pipeline (cached by Streamlit).
@st.cache_resource
def load_qa_pipeline():
    """Create a question-answering pipeline for deepset/roberta-base-squad2.

    The tokenizer and model are loaded with ``from_pretrained`` and wrapped in
    a ``transformers`` "question-answering" pipeline. ``st.cache_resource``
    keeps the returned object so script reruns reuse it instead of reloading
    the weights.

    Returns:
        The ``transformers`` question-answering pipeline object.
    """
    checkpoint = "deepset/roberta-base-squad2"
    qa_tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    qa_model = AutoModelForQuestionAnswering.from_pretrained(checkpoint)
    qa_pipe = pipeline("question-answering", model=qa_model, tokenizer=qa_tokenizer)
    return qa_pipe


qa_pipeline = load_qa_pipeline()
# Read the SOP knowledge base that user questions are matched against.
@st.cache_data
def load_sop_dataset():
    """Return the SOP records stored in ``dataset.csv`` as a DataFrame.

    The path is relative, so it is resolved against the current working
    directory. The retrieval step reads the frame's ``text`` column.
    ``st.cache_data`` caches the result across script reruns.

    Returns:
        pandas.DataFrame parsed from the CSV file.
    """
    csv_path = "dataset.csv"  # Ensure this file is uploaded to your Hugging Face Space
    sop_frame = pd.read_csv(csv_path)
    return sop_frame


dataset = load_sop_dataset()
# Utility function to find the most relevant context
def find_best_context(question, dataset):
    """Return the dataset text sharing the most distinct words with *question*.

    Both strings are lower-cased and split into runs of word characters, so
    punctuation does not block a match ("password?" matches "password.").
    Rows whose ``text`` value is not a string (for example NaN from an empty
    CSV cell) are skipped instead of raising. On a tie the earliest row wins.

    Args:
        question: Free-text question entered by the user.
        dataset: DataFrame with a ``text`` column holding candidate contexts.

    Returns:
        The best-matching text, or None when no row shares any word with the
        question (including when the question has no words at all).
    """

    def _tokens(text):
        # Word-character tokens; strips surrounding punctuation that
        # str.split() would have left attached to each word.
        return set(re.findall(r"\w+", text.lower()))

    # The question's tokens are loop-invariant, so compute them once.
    question_words = _tokens(question)
    if not question_words:
        return None

    best_score = 0
    best_context = None
    for context_text in dataset["text"]:
        # pandas reads empty CSV cells as NaN (a float); there is nothing to match.
        if not isinstance(context_text, str):
            continue
        overlap = len(question_words & _tokens(context_text))
        # Strict '>' keeps the first row among equally good matches.
        if overlap > best_score:
            best_score = overlap
            best_context = context_text
    return best_context
# Streamlit UI
st.title("SOP Question Answering AI")
st.markdown("Ask any question about Standard Operating Procedures:")

# User input
question = st.text_area("Enter your question:", "")

# Generate answer
if st.button("Get Answer"):
    # Strip first so whitespace-only input is reported as empty rather than
    # being sent to retrieval and misreported as "no relevant context".
    cleaned_question = question.strip()
    if not cleaned_question:
        st.warning("Please enter a question.")
    else:
        with st.spinner("Finding the best context..."):
            # Automatically find the most relevant context (word-overlap heuristic).
            context = find_best_context(cleaned_question, dataset)

        if context:
            with st.spinner("Answering your question..."):
                result = qa_pipeline(question=cleaned_question, context=context)
            st.success("Answer:")
            st.write(result["answer"])
            st.write("Confidence Score:", result["score"])
        else:
            st.warning("No relevant context found. Please try rephrasing your question.")