wiki-chat / app.py
Pennywise881's picture
Update app.py
f7c1bdd
raw
history blame
2.58 kB
import streamlit as st
import wikipediaapi
from Article import Article
from QueryProcessor import QueryProcessor
from QuestionAnswer import QuestionAnswer
from transformers import AutoTokenizer, AutoModelForQuestionAnswering
# Checkpoint identifier shared by the question-answering model and its
# tokenizer, so the two can never drift apart.
QA_CHECKPOINT = 'Pennywise881/distilbert-base-uncased-finetuned-squad-v2'

# NOTE(review): these loads sit at module level, so they execute every time
# the script body runs — consider caching them if reruns become slow.
model = AutoModelForQuestionAnswering.from_pretrained(QA_CHECKPOINT)
tokenizer = AutoTokenizer.from_pretrained(QA_CHECKPOINT)
# Wikipedia client used by get_article() to check that a title exists.
wiki_wiki = wikipediaapi.Wikipedia('en')

# Page title, written as a Markdown level-1 heading.
st.write("\n# Wiki Q & A\n")

# Single-element container: get_article() and ask_questions() both draw
# their text input into it, so one prompt replaces the other in place.
placeholder = st.empty()
# Seed the per-session state once. "found_article" doubles as the marker
# that initialisation has already happened for this session.
if "found_article" not in st.session_state:
    # Built inside the guard so each session gets its own list and dict.
    initial_state = {
        'page': 0,
        'found_article': False,
        'article': '',
        'conversation': [],
        'article_data': {},
    }
    for state_key, state_value in initial_state.items():
        st.session_state[state_key] = state_value
def get_article():
    """Prompt for a Wikipedia article title and load it for questioning.

    Draws a text input into the shared ``placeholder``. While the input is
    empty (or only whitespace) nothing else happens. If the title does not
    resolve to an existing page, an apology is written to the app. Otherwise
    the article data is fetched, stored in ``st.session_state``, and
    ``ask_questions()`` is called so the question prompt replaces the title
    prompt within the same run.

    Session state is updated only after the article data has been fetched.
    If ``Article`` raises, ``found_article`` therefore stays ``False`` and the
    next run shows the title prompt again instead of entering the question
    view with an empty ``article_data``.
    """
    article_name = placeholder.text_input('Enter the name of a Wikipedia article', '').strip()
    if not article_name:
        return

    if not wiki_wiki.page(article_name).exists():
        st.write(f'Sorry, could not find Wikipedia article: {article_name}')
        return

    # Fetch first, commit afterwards: a failure here must not leave the
    # session flagged as "article found".
    article_data = Article(article_name=article_name).get_article_data()

    st.session_state.article = article_name
    st.session_state.article_data = article_data
    st.session_state.found_article = True

    ask_questions()
def ask_questions():
    """Prompt for a question about the loaded article and show the Q&A log.

    Draws the question input into the shared ``placeholder`` and renders the
    "Questions and Answers:" header. When a question is present, a context is
    selected with ``QueryProcessor`` from the article data held in
    ``st.session_state.article_data``, ``QuestionAnswer`` is run on it with
    the module-level ``model`` and ``tokenizer`` on the CPU, and the texts of
    all returned results are joined with ", " into a single answer.

    The exchange is stored at the front of ``st.session_state.conversation``,
    so the list is always ordered newest first, and every stored exchange is
    then rendered in that order.
    """
    question = placeholder.text_input(f"Ask questions about {st.session_state.article}", '')
    st.header("Questions and Answers:")

    if question:
        article_data = st.session_state.article_data
        query_processor = QueryProcessor(
            question=question,
            section_texts=article_data['article_data'],
            N=article_data['num_docs'],
            avg_doc_len=article_data['avg_doc_len'],
        )
        context = query_processor.get_context()

        qa = QuestionAnswer({'question': question, 'context': context}, model, tokenizer, 'cpu')
        answer = ", ".join(result['text'] for result in qa.get_results())

        # Insert at the front rather than append-then-reverse: reversing the
        # whole list on every new question scrambles the older entries
        # (e.g. [3, 1, 2] after three questions) instead of keeping them
        # newest first.
        st.session_state.conversation.insert(0, {'question': question, 'answer': answer})

    for exchange in st.session_state.conversation:
        st.text("Question: " + exchange['question'] + "\n" + "Answer: " + exchange['answer'])
# Entry point for each script run: once an article has been loaded, go
# straight to the question view; until then, keep asking for a title.
if st.session_state.found_article:
    ask_questions()
else:
    get_article()