wiki-chat-v2 / app.py
Pennywise881's picture
Update app.py
239fed5
raw
history blame
3.23 kB
import streamlit as st
import os
from Article import Article
from VectorDB import VectorDB
from QuestionAnswer import QuestionAnswer
from transformers import AutoTokenizer, AutoModelForQuestionAnswering
from sentence_transformers import models, SentenceTransformer
# --- Model setup -----------------------------------------------------------
# The extractive QA model ("reader") and its tokenizer are loaded from the
# same checkpoint, so the identifier is named once and reused.
# NOTE(review): Streamlit re-executes this script on user interaction, so
# these loads presumably repeat on every rerun -- consider caching; confirm
# against the Streamlit version in use.
_READER_CHECKPOINT = 'Pennywise881/distilbert-base-uncased-finetuned-squad-v2'
_RETRIEVER_CHECKPOINT = "Pennywise881/distilbert-base-uncased-mnr-squadv2"

reader = AutoModelForQuestionAnswering.from_pretrained(_READER_CHECKPOINT)
tokenizer = AutoTokenizer.from_pretrained(_READER_CHECKPOINT)

# Sentence encoder: a transformer backbone followed by mean pooling over the
# token embeddings, sized to the backbone's word-embedding dimension.
distilbert = models.Transformer(_RETRIEVER_CHECKPOINT)
pooler = models.Pooling(
    distilbert.get_word_embedding_dimension(),
    pooling_mode_mean_tokens=True,
)

# The (misspelled) name `retreiver` is kept because the rest of the script
# and the VectorDB keyword argument refer to it by this spelling.
retreiver = SentenceTransformer(modules=[distilbert, pooler])
# --- Session state and page skeleton ---------------------------------------
# Seed the per-session values on the first run only. The presence of
# 'found_article' is used as the marker that seeding has already happened.
if 'found_article' not in st.session_state:
    for _key, _initial in (
        ('found_article', False),  # True once an article has been indexed
        ('article_name', ''),      # name of the indexed article
        ('db', None),              # VectorDB handle for the indexed article
        ('qas', []),               # history of question/answer records
    ):
        st.session_state[_key] = _initial

# Page title.
st.write("\n# Wiki Chat V2\n")

# Single slot shared by both text inputs: the article-name prompt and the
# question prompt are each rendered into this placeholder.
placeholder = st.empty()
def get_article(retreiver):
    """Prompt for a Wikipedia article name, index the article, and start Q&A.

    The article-name input is rendered into the shared ``placeholder`` slot.
    When a name is entered, the article data is fetched through ``Article``,
    upserted into a ``VectorDB`` built with ``retreiver`` and the ``API_KEY``
    environment variable, and the question prompt is shown via
    ``ask_questions``.

    Args:
        retreiver: Encoder object handed to ``VectorDB`` as its ``retreiver``
            argument (spelling kept to match that keyword).

    Side effects:
        On success, sets ``found_article``, ``article_name`` and ``db`` in
        ``st.session_state``. On failure, writes a message to the page and
        leaves the session state untouched.
    """
    article_name = placeholder.text_input("Enter the name of a Wikipedia article")
    if not article_name:
        # Nothing typed yet; wait for the next rerun.
        return

    article_data = Article().get_article_data(article_name=article_name)
    # Treat both "no data object" and "empty data" as article-not-found; the
    # explicit len() check avoids relying on the container's truthiness.
    if article_data is None or len(article_data) == 0:
        st.write(f'Sorry, could not find Wikipedia article: {article_name}')
        return

    # A missing key used to surface as a raw KeyError traceback; report it
    # as a readable error instead.
    api_key = os.environ.get('API_KEY')
    if not api_key:
        st.error("The 'API_KEY' environment variable is not set; cannot index the article.")
        return

    db = VectorDB(retreiver=retreiver, API_KEY=api_key)
    db.upsert_data(article_data=article_data)

    # Persist the indexed article before rendering the Q&A view, so a failure
    # while rendering does not force the article to be indexed again on the
    # next rerun.
    st.session_state.found_article = True
    st.session_state.article_name = article_name
    st.session_state.db = db

    ask_questions(article_name=article_name, db=db)
def ask_questions(article_name, db: VectorDB):
    """Render the question prompt, answer a new question, and show history.

    The question input replaces whatever is in the shared ``placeholder``
    slot. When a question is entered, the lower-cased question is sent to
    ``db.get_contexts``; the first match supplies the context passed to
    ``QuestionAnswer`` (run with the module-level ``reader`` and
    ``tokenizer`` on ``'cpu'``). The resulting answer texts are joined with
    ``", "`` and appended to ``st.session_state.qas`` together with the
    match's section and id. All stored records are then printed.

    Args:
        article_name: Name of the article, shown in the input label.
        db: Vector store queried for context passages.
    """
    question = placeholder.text_input(f"Ask questions about '{article_name}'", '')
    st.header("Questions and Answers:")

    if question:
        normalized_question = question.lower()
        matches = db.get_contexts(normalized_question)['matches']

        if not matches:
            # Indexing matches[0] on an empty result used to raise IndexError.
            st.warning(f"Sorry, could not find a relevant passage for: {question}")
        else:
            best_match = matches[0]
            metadata = best_match['metadata']

            qa = QuestionAnswer(
                {'question': normalized_question, 'context': metadata['text']},
                reader,
                tokenizer,
                'cpu',
            )
            answer = ", ".join(result['text'] for result in qa.get_results())

            st.session_state.qas.append(
                {
                    'question': question,  # stored as typed, not lower-cased
                    'answer': answer,
                    'section': metadata['section'],
                    'para': best_match['id'],
                }
            )

    # The history is shown on every run, including runs without a new question.
    for entry in st.session_state.qas:
        # f-string formatting also tolerates a non-string id or section,
        # which plain '+' concatenation would reject with a TypeError.
        st.text(
            f"Question: {entry['question']}\n"
            f"Answer: {entry['answer']}\n"
            f"Section: {entry['section']}\n"
            f"Paragraph #: {entry['para']}"
        )
# --- Entry point -------------------------------------------------------------
# Once an article has been indexed in this session, go straight to the Q&A
# view with the stored article and database; otherwise show the article prompt.
if st.session_state.found_article:
    ask_questions(st.session_state.article_name, st.session_state.db)
else:
    get_article(retreiver)