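# Wiki Chat V2: a Streamlit app that answers questions about a Wikipedia article
# with a retriever-reader setup: article paragraphs are embedded and stored in a
# vector database, the most relevant paragraph is retrieved for each question,
# and an extractive QA model pulls the answer span out of it.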
import streamlit as st
import os
from Article import Article
from VectorDB import VectorDB
from QuestionAnswer import QuestionAnswer
from transformers import AutoTokenizer, AutoModelForQuestionAnswering
from sentence_transformers import models, SentenceTransformer
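
# Reader: extractive QA model (DistilBERT fine-tuned on SQuAD v2) that extracts
# answer spans from a retrieved paragraph.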
reader = AutoModelForQuestionAnswering.from_pretrained('Pennywise881/distilbert-base-uncased-finetuned-squad-v2')
tokenizer = AutoTokenizer.from_pretrained('Pennywise881/distilbert-base-uncased-finetuned-squad-v2')
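
# Retriever: DistilBERT encoder with mean pooling, used to embed article
# paragraphs and questions for similarity search.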
distilbert = models.Transformer("Pennywise881/distilbert-base-uncased-mnr-squadv2")
pooler = models.Pooling(
    distilbert.get_word_embedding_dimension(),
    pooling_mode_mean_tokens=True
)
retreiver = SentenceTransformer(modules=[distilbert, pooler])
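
# Initialize per-session state on the first run.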
if 'found_article' not in st.session_state:
    st.session_state.found_article = False
    st.session_state.article_name = ''
    st.session_state.db = None
    st.session_state.qas = []
st.write("""
# Wiki Chat V2
""")
placeholder = st.empty()
def get_article(retreiver):
    """Look up a Wikipedia article and index its paragraphs in the vector DB."""
    article_name = placeholder.text_input("Enter the name of a Wikipedia article")

    if article_name:
        article = Article()
        article_data = article.get_article_data(article_name=article_name)

        if len(article_data) > 0:
            # Embed the article paragraphs and upsert them into the vector database.
            API_KEY = os.environ['API_KEY']
            db = VectorDB(retreiver=retreiver, API_KEY=API_KEY)
            db.upsert_data(article_data=article_data)

            ask_questions(article_name=article_name, db=db)

            st.session_state.found_article = True
            st.session_state.article_name = article_name
            st.session_state.db = db
        else:
            st.write(f'Sorry, could not find Wikipedia article: {article_name}')
def ask_questions(article_name, db: VectorDB):
    """Answer questions about the current article using retrieved contexts."""
    question = placeholder.text_input(f"Ask questions about '{article_name}'", '')
    st.header("Questions and Answers:")

    if question:
        # Retrieve the paragraphs most similar to the question.
        contexts = db.get_contexts(question.lower())

        data = {
            'question': question.lower(),
            'context': contexts['matches'][0]['metadata']['text']
        }

        # Run the extractive reader over the top-ranked paragraph.
        qa = QuestionAnswer(data, reader, tokenizer, 'cpu')
        results = qa.get_results()

        paragraph_index = contexts['matches'][0]['id']
        section = contexts['matches'][0]['metadata']['section']

        # Join the extracted answer spans into a single comma-separated string.
        answer = ', '.join(r['text'] for r in results)

        st.session_state.qas.append(
            {
                'question': question,
                'answer': answer,
                'section': section,
                'para': paragraph_index
            }
        )

    # Render the running history of questions and answers.
    if len(st.session_state.qas) > 0:
        for data in st.session_state.qas:
            st.text(
                "Question: " + data['question'] + '\n' +
                "Answer: " + data['answer'] + '\n' +
                "Section: " + data['section'] + '\n' +
                "Paragraph #: " + str(data['para'])
            )
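
# Route to the article-lookup view or the question view depending on session state.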
if not st.session_state.found_article:
    get_article(retreiver)
else:
    ask_questions(st.session_state.article_name, st.session_state.db)