File size: 3,233 Bytes
8644233
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
239fed5
8644233
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
import streamlit as st
import os

from Article import Article
from VectorDB import VectorDB
from QuestionAnswer import QuestionAnswer

from transformers import AutoTokenizer, AutoModelForQuestionAnswering
from sentence_transformers import models, SentenceTransformer


# Extractive QA "reader": given a question and a context passage, predicts
# the answer span. Tokenizer must match the reader checkpoint.
reader = AutoModelForQuestionAnswering.from_pretrained('Pennywise881/distilbert-base-uncased-finetuned-squad-v2')
tokenizer = AutoTokenizer.from_pretrained('Pennywise881/distilbert-base-uncased-finetuned-squad-v2')

# Dense "retriever": transformer encoder + mean pooling, used to embed
# article paragraphs and questions for vector search.
distilbert = models.Transformer("Pennywise881/distilbert-base-uncased-mnr-squadv2")
pooler = models.Pooling(
    distilbert.get_word_embedding_dimension(),
    pooling_mode_mean_tokens=True
)

# NOTE(review): 'retreiver' is a typo for 'retriever'; kept as-is because
# the name is referenced by the rest of this file.
retreiver = SentenceTransformer(modules=[distilbert, pooler])

# One-time per-session init: whether an article has been loaded, its name,
# the vector-DB handle, and the accumulated question/answer history.
if 'found_article' not in st.session_state:
    st.session_state.found_article = False
    st.session_state.article_name = ''
    st.session_state.db = None
    st.session_state.qas = []

st.write("""
    # Wiki Chat V2
""")
# Single placeholder reused for both the article-name and the question input,
# so the second input replaces the first in place.
placeholder = st.empty()

def get_article(retreiver):
    """Prompt for a Wikipedia article name, fetch it, and index it for Q&A.

    On success the article's paragraphs are upserted into the vector DB,
    session state is updated so subsequent reruns go straight to the Q&A
    flow, and the question prompt is shown immediately. On failure a
    message is rendered instead.

    Args:
        retreiver: SentenceTransformer used by VectorDB to embed text.
    """
    article_name = placeholder.text_input("Enter the name of a Wikipedia article")

    if article_name:
        article = Article()
        article_data = article.get_article_data(article_name=article_name)

        if len(article_data) > 0:
            # Fail with a visible message (not a raw KeyError traceback)
            # when the required env var is missing.
            API_KEY = os.environ.get('API_KEY')
            if API_KEY is None:
                st.error('API_KEY environment variable is not set.')
                return

            db = VectorDB(retreiver=retreiver, API_KEY=API_KEY)
            db.upsert_data(article_data=article_data)

            # Record session state BEFORE handing off to the Q&A flow, so a
            # rerun triggered from inside ask_questions sees a consistent
            # session (previously state was set only after ask_questions).
            st.session_state.found_article = True
            st.session_state.article_name = article_name
            st.session_state.db = db

            ask_questions(article_name=article_name, db=db)
        else:
            st.write(f'Sorry, could not find Wikipedia article: {article_name}')

def ask_questions(article_name, db : VectorDB):
    """Prompt for a question about *article_name*, answer it, and render history.

    The top vector-DB match supplies the context passage; the reader model
    extracts candidate answer spans, which are joined into one answer string
    and appended to the session Q&A history. The full history is rendered on
    every rerun.

    Args:
        article_name: Name of the loaded Wikipedia article (shown in the prompt).
        db: VectorDB holding the article's embedded paragraphs.
    """
    question = placeholder.text_input(f"Ask questions about '{article_name}'", '')
    st.header("Questions and Answers:")

    if question:
        contexts = db.get_contexts(question.lower())
        top_match = contexts['matches'][0]

        qa = QuestionAnswer(
            {
                'question': question.lower(),
                'context': top_match['metadata']['text'],
            },
            reader, tokenizer, 'cpu',
        )
        results = qa.get_results()

        # Join candidate spans directly; replaces the manual
        # append-then-strip-trailing-comma construction.
        answer = ", ".join(r['text'] for r in results)

        st.session_state.qas.append(
            {
                'question': question,
                'answer': answer,
                'section': top_match['metadata']['section'],
                'para': top_match['id']
            }
        )

    # Render the history unconditionally: previously this loop was nested
    # inside `if question:`, so past answers vanished whenever the question
    # input was empty.
    for entry in st.session_state.qas:
        st.text(
            "Question: " + entry['question'] + '\n' +
            "Answer: " + entry['answer'] + '\n' +
            "Section: " + entry['section'] + '\n' +
            "Paragraph #: " + entry['para']
        )

# Route the rerun: no article loaded yet -> article search; otherwise Q&A.
# Use truthiness (`not`) instead of the non-idiomatic `== False` comparison.
if not st.session_state.found_article:
    get_article(retreiver)
else:
    ask_questions(st.session_state.article_name, st.session_state.db)