"""Streamlit app: pick a Wikipedia article, then ask questions about it.

The script first prompts for an article name, loads the article into a
vector database, and from then on answers questions using the best-matching
paragraph as context for an extractive question-answering model.
"""
import streamlit as st
import os

from Article import Article
from VectorDB import VectorDB
from QuestionAnswer import QuestionAnswer

from transformers import AutoTokenizer, AutoModelForQuestionAnswering
from sentence_transformers import models, SentenceTransformer

# Hugging Face model identifiers, named once so the reader model and its
# tokenizer cannot drift apart.
READER_MODEL = 'Pennywise881/distilbert-base-uncased-finetuned-squad-v2'
RETRIEVER_MODEL = 'Pennywise881/distilbert-base-uncased-mnr-squadv2'

# NOTE(review): Streamlit re-executes this script on every user interaction,
# so these loads run on each rerun. Consider wrapping them in a cached
# loader (e.g. st.cache_resource) if the installed Streamlit supports it.
reader = AutoModelForQuestionAnswering.from_pretrained(READER_MODEL)
tokenizer = AutoTokenizer.from_pretrained(READER_MODEL)

# Sentence encoder: transformer word embeddings followed by mean pooling.
distilbert = models.Transformer(RETRIEVER_MODEL)
pooler = models.Pooling(
    distilbert.get_word_embedding_dimension(),
    pooling_mode_mean_tokens=True,
)

# The misspelling "retreiver" is kept on purpose: it is the keyword name
# that VectorDB is called with below.
retreiver = SentenceTransformer(modules=[distilbert, pooler])

# Initialise per-session state exactly once.
if 'found_article' not in st.session_state:
    st.session_state.found_article = False
    st.session_state.article_name = ''
    st.session_state.db = None
    st.session_state.qas = []

st.write(""" # Wiki Chat V2 """)

# Single slot shared by both text inputs so that the question box replaces
# the article-name box once an article has been loaded.
placeholder = st.empty()


def get_article(retreiver):
    """Prompt for an article name and, if found, load it and start Q&A.

    Args:
        retreiver: Sentence encoder handed to ``VectorDB``.

    On success the article data is upserted into a new ``VectorDB``, the
    question prompt is shown, and the session state records the article
    name and database. If the article has no data, or the ``API_KEY``
    environment variable is not set, a message is shown instead.
    """
    article_name = placeholder.text_input("Enter the name of a Wikipedia article")
    if not article_name:
        return

    article_data = Article().get_article_data(article_name=article_name)
    if len(article_data) == 0:
        st.write(f'Sorry, could not find Wikipedia article: {article_name}')
        return

    # Report a missing key in the page instead of crashing with KeyError.
    api_key = os.environ.get('API_KEY')
    if not api_key:
        st.error("The API_KEY environment variable is not set.")
        return

    db = VectorDB(retreiver=retreiver, API_KEY=api_key)
    db.upsert_data(article_data=article_data)
    ask_questions(article_name=article_name, db=db)

    st.session_state.found_article = True
    st.session_state.article_name = article_name
    st.session_state.db = db


def _answer_question(question, db):
    """Answer one question from the first matching context returned by *db*.

    Returns a history entry with the keys ``question``, ``answer``,
    ``section`` and ``para``, or ``None`` when the database returns no
    matches.
    """
    lowered = question.lower()
    matches = db.get_contexts(lowered)['matches']
    if not matches:
        return None

    # Only the first match is used (presumably the most relevant one;
    # verify against VectorDB.get_contexts).
    best = matches[0]
    qa = QuestionAnswer(
        {'question': lowered, 'context': best['metadata']['text']},
        reader,
        tokenizer,
        'cpu',
    )
    return {
        'question': question,
        'answer': ", ".join(r['text'] for r in qa.get_results()),
        'section': best['metadata']['section'],
        'para': best['id'],
    }


def ask_questions(article_name, db: VectorDB):
    """Show the question prompt, answer a submitted question, list history.

    Args:
        article_name: Article title shown in the prompt label.
        db: Vector database queried for a context paragraph.

    Each answered question is appended to ``st.session_state.qas``, and
    every stored entry is rendered below the header.
    """
    question = placeholder.text_input(f"Ask questions about '{article_name}'", '')
    st.header("Questions and Answers:")

    if question:
        entry = _answer_question(question, db)
        if entry is None:
            # Previously an empty match list raised IndexError here.
            st.write("Sorry, could not find a relevant passage for that question.")
        else:
            st.session_state.qas.append(entry)

    for data in st.session_state.qas:
        # f-string formatting keeps this working if the paragraph id is
        # not a str.
        st.text(
            f"Question: {data['question']}\n"
            f"Answer: {data['answer']}\n"
            f"Section: {data['section']}\n"
            f"Paragraph #: {data['para']}"
        )


# Entry point: ask for an article until one is loaded, then take questions.
if not st.session_state.found_article:
    get_article(retreiver)
else:
    ask_questions(st.session_state.article_name, st.session_state.db)