# Page header captured together with the file (not part of the program):
# wiki-chat / app.py
# Pennywise881's picture
# Update app.py
# e67127b
# raw
# history blame
# 2.58 kB
import streamlit as st
import wikipediaapi
from Article import Article
from QueryProcessor import QueryProcessor
from QuestionAnswer import QuestionAnswer
from transformers import AutoTokenizer, AutoModelForQuestionAnswering
# Identifier handed to from_pretrained() for both the QA model and its
# tokenizer (looks like a Hugging Face Hub repo id).
MODEL_ID = 'Pennywise881/distilbert-base-uncased-finetuned-squad-v2'

# NOTE(review): these run at module level, i.e. every time the script is
# executed -- confirm whether the load should be cached.
model = AutoModelForQuestionAnswering.from_pretrained(MODEL_ID)
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)

# Page title, written as a Markdown level-1 heading.
st.write("\n# Wiki Chat\n")

# Wikipedia client constructed for the 'en' language edition.
wiki_wiki = wikipediaapi.Wikipedia('en')

# Single slot shared by get_article() and ask_questions(): each of them draws
# its text input into this placeholder.
placeholder = st.empty()
# Seed the per-session state once. "found_article" doubles as the marker that
# the defaults have already been written, so later runs of the script leave
# the stored values alone.
if "found_article" not in st.session_state:
    initial_state = {
        'page': 0,
        'found_article': False,   # flips to True once an article is loaded
        'article': '',            # name of the selected article
        'conversation': [],       # question/answer dicts shown to the user
        'article_data': {},       # output of Article.get_article_data()
    }
    for state_key, state_value in initial_state.items():
        st.session_state[state_key] = state_value
def get_article():
    """Ask for a Wikipedia article name and load it into the session.

    Draws a text input into the shared ``placeholder``. Nothing happens while
    the input is empty. If the page does not exist on Wikipedia, an apology
    message is written instead. Otherwise the article name and the data
    returned by ``Article.get_article_data()`` are stored in
    ``st.session_state`` and ``ask_questions()`` is called right away.
    """
    article_name = placeholder.text_input('Enter the name of a Wikipedia article', '')

    # Guard: no input yet.
    if not article_name:
        return

    # Guard: the requested page is unknown to Wikipedia.
    if not wiki_wiki.page(article_name).exists():
        st.write(f'Sorry, could not find Wikipedia article: {article_name}')
        return

    st.session_state.found_article = True
    st.session_state.article = article_name
    # ask_questions() later reads the 'article_data', 'num_docs' and
    # 'avg_doc_len' keys from this value.
    st.session_state.article_data = Article(article_name=article_name).get_article_data()

    # Switch to the question prompt within the same run.
    ask_questions()
def ask_questions():
    """Prompt for a question about the selected article and show the Q&A history.

    Draws a text input into the shared ``placeholder`` and a section header.
    When the input is non-empty:

    * a context is obtained from ``QueryProcessor`` using the article data
      stored in ``st.session_state.article_data``;
    * ``QuestionAnswer`` is run on the question/context pair with the
      module-level ``model`` and ``tokenizer`` on ``'cpu'``;
    * the ``'text'`` field of every result is joined with ``", "`` into a
      single answer string;
    * the exchange is stored at the front of
      ``st.session_state.conversation``.

    Finally every stored exchange is rendered, most recent first.
    """
    question = placeholder.text_input(f"Ask questions about {st.session_state.article}", '')
    st.header("Questions and Answers:")

    if question:
        article_data = st.session_state.article_data
        query_processor = QueryProcessor(
            question=question,
            section_texts=article_data['article_data'],
            N=article_data['num_docs'],
            avg_doc_len=article_data['avg_doc_len'],
        )
        data = {
            'question': question,
            'context': query_processor.get_context(),
        }
        results = QuestionAnswer(data, model, tokenizer, 'cpu').get_results()

        # An empty result list yields an empty answer string.
        answer = ", ".join(r['text'] for r in results)

        # Insert at the front so the persisted list is always newest-first.
        # Appending and then reversing the stored list in place would flip the
        # older entries back and forth on every new question and scramble the
        # order from the third question on.
        st.session_state.conversation.insert(0, {'question': question, 'answer': answer})

    # Iterating an empty history is a no-op, so no length check is needed.
    for entry in st.session_state.conversation:
        st.text("Question: " + entry['question'] + "\n" + "Answer: " + entry['answer'])
# Entry point of each script run: ask for an article until one has been
# loaded, then go straight to the question prompt. found_article only ever
# holds a bool, so a plain truthiness test replaces the `== False` comparison.
if st.session_state.found_article:
    ask_questions()
else:
    get_article()