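"""Wiki Chat: a Streamlit app that answers questions about a chosen Wikipedia
article using a DistilBERT model fine-tuned on SQuAD v2.

To run locally (assuming this file is saved as app.py):
    streamlit run app.py
"""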
import streamlit as st
import wikipediaapi
from Article import Article
from QueryProcessor import QueryProcessor
from QuestionAnswer import QuestionAnswer

from transformers import AutoTokenizer, AutoModelForQuestionAnswering

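# Load the SQuAD v2 fine-tuned DistilBERT question-answering model and its tokenizer.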
model = AutoModelForQuestionAnswering.from_pretrained('Pennywise881/distilbert-base-uncased-finetuned-squad-v2')
tokenizer = AutoTokenizer.from_pretrained('Pennywise881/distilbert-base-uncased-finetuned-squad-v2')

st.write("""
# Wiki Chat
""")

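# A single placeholder widget is reused for the article-name and question inputs,
# alongside a client for the English Wikipedia.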
placeholder = st.empty()
wiki_wiki = wikipediaapi.Wikipedia('en')

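# Initialise per-session state on the first run.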
if "found_article" not in st.session_state:
    st.session_state.page = 0
    st.session_state.found_article = False
    st.session_state.article = ''
    st.session_state.conversation = []
    st.session_state.article_data = {}


def get_article():
    """Prompt for a Wikipedia article name and load its data into session state."""
    article_name = placeholder.text_input('Enter the name of a Wikipedia article', '')

    if article_name:
        page = wiki_wiki.page(article_name)
        if page.exists():
            st.session_state.found_article = True
            st.session_state.article = article_name

            article = Article(article_name=article_name)
            st.session_state.article_data = article.get_article_data()

            ask_questions()
        else:
            st.write(f'Sorry, could not find Wikipedia article: {article_name}')

def ask_questions():
    """Answer questions about the selected article and display the conversation so far."""
    question = placeholder.text_input(f"Ask questions about {st.session_state.article}", '')
    st.header("Questions and Answers:")

    if question:
        # Rank the article's sections against the question and build a QA context.
        query_processor = QueryProcessor(
            question=question,
            section_texts=st.session_state.article_data['article_data'],
            N=st.session_state.article_data['num_docs'],
            avg_doc_len=st.session_state.article_data['avg_doc_len']
        )

        context = query_processor.get_context()

        data = {
            'question': question,
            'context': context
        }

        # Run extractive question answering over the retrieved context.
        qa = QuestionAnswer(data, model, tokenizer, 'cpu')
        results = qa.get_results()

        answer = ", ".join(r['text'] for r in results)

        # Keep the newest exchange at the top of the conversation.
        st.session_state.conversation.insert(0, {'question': question, 'answer': answer})

    if st.session_state.conversation:
        for data in st.session_state.conversation:
            st.text("Question: " + data['question'] + "\n" + "Answer: " + data['answer'])


if not st.session_state.found_article:
    get_article()
else:
    ask_questions()