# Wiki Chat V2 — Streamlit app: load a Wikipedia article into a vector DB
# and answer questions about it with a DistilBERT reader/retriever pair.
import streamlit as st
import os

from Article import Article
from VectorDB import VectorDB
from QuestionAnswer import QuestionAnswer
from transformers import AutoTokenizer, AutoModelForQuestionAnswering
from sentence_transformers import models, SentenceTransformer

# Extractive QA reader + its tokenizer (SQuAD v2 fine-tune).
reader = AutoModelForQuestionAnswering.from_pretrained('Pennywise881/distilbert-base-uncased-finetuned-squad-v2')
tokenizer = AutoTokenizer.from_pretrained('Pennywise881/distilbert-base-uncased-finetuned-squad-v2')

# Dense retriever: DistilBERT encoder + mean pooling over word embeddings.
# TODO(review): Streamlit re-runs this script on every interaction, so these
# models are re-instantiated each rerun — consider st.cache_resource (or
# st.experimental_singleton on older Streamlit); confirm the installed version.
distilbert = models.Transformer("Pennywise881/distilbert-base-uncased-mnr-squadv2")
pooler = models.Pooling(
    distilbert.get_word_embedding_dimension(),
    pooling_mode_mean_tokens=True,
)
# NOTE: 'retreiver' spelling kept — the rest of the script references it.
retreiver = SentenceTransformer(modules=[distilbert, pooler])
# One-time session-state initialisation. Streamlit re-runs the whole script
# on every interaction, so guard on a sentinel key to keep state across reruns.
if 'found_article' not in st.session_state:
    st.session_state.found_article = False
    st.session_state.article_name = ''
    st.session_state.db = None   # VectorDB instance once an article is loaded
    st.session_state.qas = []    # history of {question, answer, section, para} dicts

st.write("""
# Wiki Chat V2
""")

# Single slot reused for both the article-search and question inputs, so only
# one text box is visible at a time.
placeholder = st.empty()
def get_article(retreiver):
    """Prompt for a Wikipedia article name and, if found, index it.

    Renders a text input in the shared placeholder; when the user submits a
    name, fetches the article, upserts its paragraphs into the vector DB and
    hands off to the Q&A flow. Session state records the loaded article so
    subsequent reruns go straight to ask_questions().

    Args:
        retreiver: SentenceTransformer the VectorDB uses to embed text
            (spelling kept from the original code base).
    """
    article_name = placeholder.text_input("Enter the name of a Wikipedia article")
    if not article_name:
        return

    article = Article()
    article_data = article.get_article_data(article_name=article_name)
    if len(article_data) > 0:
        # NOTE(review): raises KeyError if API_KEY is unset — presumably set
        # in the hosting environment; confirm.
        API_KEY = os.environ['API_KEY']
        db = VectorDB(retreiver=retreiver, API_KEY=API_KEY)
        db.upsert_data(article_data=article_data)

        # Record state *before* delegating so an interaction inside
        # ask_questions() cannot rerun the script with the article lost.
        st.session_state.found_article = True
        st.session_state.article_name = article_name
        st.session_state.db = db

        ask_questions(article_name=article_name, db=db)
    else:
        st.write(f'Sorry, could not find Wikipedia article: {article_name}')
def ask_questions(article_name, db: VectorDB):
    """Render the Q&A view: take a question, answer it, replay the history.

    Retrieves the best-matching paragraph for the (lower-cased) question from
    the vector DB, runs the extractive reader over it, appends the result to
    the session Q&A history, then re-renders the full history.

    Args:
        article_name: Name of the loaded Wikipedia article (shown in the prompt).
        db: VectorDB holding the article's embedded paragraphs.
    """
    question = placeholder.text_input(f"Ask questions about '{article_name}'", '')
    st.header("Questions and Answers:")

    if question:
        contexts = db.get_contexts(question.lower())
        # Use only the top retrieval hit as the reader's context.
        top_match = contexts['matches'][0]
        data = {
            'question': question.lower(),
            'context': top_match['metadata']['text'],
        }
        qa = QuestionAnswer(data, reader, tokenizer, 'cpu')
        results = qa.get_results()

        # Join the extracted answer spans into one comma-separated string
        # (replaces the manual "+= ', '" loop plus trailing-slice trim).
        answer = ", ".join(r['text'] for r in results)

        st.session_state.qas.append(
            {
                'question': question,
                'answer': answer,
                'section': top_match['metadata']['section'],
                'para': top_match['id'],
            }
        )

    # Replay the whole Q&A history on every rerun.
    for entry in st.session_state.qas:
        st.text(
            "Question: " + entry['question'] + '\n' +
            "Answer: " + entry['answer'] + '\n' +
            "Section: " + entry['section'] + '\n' +
            # The id is presumably a str (vector-DB ids usually are), but
            # coerce defensively: '+' would raise TypeError on an int id.
            "Paragraph #: " + str(entry['para'])
        )
# Entry point of each Streamlit rerun: show the article search until an
# article has been loaded, then switch to the Q&A view.
if not st.session_state.found_article:
    get_article(retreiver)
else:
    ask_questions(st.session_state.article_name, st.session_state.db)