File size: 6,891 Bytes
e96ad26
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
# -*- coding: utf-8 -*-
"""

Created on Sun Dec 19 21:10:27 2021



@author: Deepak.Reji

"""

import streamlit as st
from GoogleNews import GoogleNews

#import torch
#from transformers import DistilBertTokenizer, DistilBertForQuestionAnswering
from transformers import AutoModelForQuestionAnswering, AutoTokenizer, pipeline
from PIL import Image

#model_path = './covid_qa_distillbert'

#%%
# Prefer the resource cache when this Streamlit build offers it; otherwise fall
# back to the legacy decorator the app was originally written against.
_cache_model = (
    st.cache_resource
    if hasattr(st, "cache_resource")
    else st.cache(allow_output_mutation=True)
)


@_cache_model
def qna():
    """Build the question-answering pipeline once and reuse it across reruns.

    Returns:
        The ``transformers`` ``question-answering`` pipeline created for the
        ``shainahub/covid_qa_distillbert`` checkpoint (used both as model and
        as tokenizer).
    """
    model_name = 'shainahub/covid_qa_distillbert'
    return pipeline('question-answering', model=model_name, tokenizer=model_name)

@st.cache(allow_output_mutation=True)
def answergen(context, question, model_path='./covid_qa_distillbert'):
    """Extract an answer span from *context* with a DistilBert QA model.

    The tokenizer and model are loaded from *model_path*, the question/context
    pair is encoded, and the tokens between the highest-scoring start and end
    positions are decoded back to text.

    Args:
        context: Passage to search for the answer.
        question: Question to answer from the passage.
        model_path: Checkpoint location passed to ``from_pretrained``. The
            default is the path from the commented-out module-level setting.

    Returns:
        The decoded answer string. It is empty when the predicted end position
        precedes the predicted start position.
    """
    # The module-level imports for these names are commented out, so import
    # them here; otherwise any call fails with NameError.
    import torch
    from transformers import DistilBertForQuestionAnswering, DistilBertTokenizer

    tokenizer = DistilBertTokenizer.from_pretrained(model_path, return_token_type_ids=True)
    model = DistilBertForQuestionAnswering.from_pretrained(model_path, return_dict=False)

    encoding = tokenizer.encode_plus(question, context)
    input_ids, attention_mask = encoding["input_ids"], encoding["attention_mask"]

    # Inference only: no gradients are needed for the forward pass.
    with torch.no_grad():
        start_scores, end_scores = model(
            torch.tensor([input_ids]),
            attention_mask=torch.tensor([attention_mask]),
        )

    # Plain ints make the list slice below independent of tensor indexing.
    start = int(torch.argmax(start_scores))
    end = int(torch.argmax(end_scores))
    ans_tokens = input_ids[start:end + 1]

    answer_tokens = tokenizer.convert_ids_to_tokens(ans_tokens, skip_special_tokens=True)
    return tokenizer.convert_tokens_to_string(answer_tokens)

def news():
    """Fetch Google News results for the 'Covid19' search term.

    A single client is configured with English language, US region and a
    seven-day period. Previously three clients were created in a row and only
    the last one (period only) was used, so the language/region settings were
    silently dropped.

    Returns:
        Whatever ``GoogleNews.result()`` yields after the ``get_news`` call.
    """
    googlenews = GoogleNews(lang='en', region='US', period='7d')
    googlenews.get_news('Covid19')
    return googlenews.result()

#%%
# Preset examples offered in the select box: option label -> (context, question).
_EXAMPLES = {
    'example1': (
        'Coronavirus disease 2019 (COVID-19) is a contagious disease caused by severe acute respiratory syndrome coronavirus 2 (SARS-CoV-2). The first known case was identified in Wuhan, China, in December 2019.[7] The disease has since spread worldwide, leading to an ongoing pandemic.',
        'What is COVID-19?',
    ),
    'example2': (
        'The first known infections from SARS-CoV-2 were discovered in Wuhan, China. The original source of viral transmission to humans remains unclear, as does whether the virus became pathogenic before or after the spillover event.',
        'Where was COVID-19 first discovered?',
    ),
    'example3': (
        'Long COVID, also known as post-COVID-19 syndrome, post-acute sequelae of COVID-19 (PASC), or chronic COVID syndrome (CCS) is a condition characterized by long-term sequelae appearing or persisting after the typical convalescence period of COVID-19. Long COVID can affect nearly every organ system, with sequelae including respiratory system disorders, nervous system and neurocognitive disorders, mental health disorders, metabolic disorders, cardiovascular disorders, gastrointestinal disorders, malaise, fatigue, musculoskeletal pain, and anemia. A wide range of symptoms are commonly reported, including fatigue, headaches, shortness of breath, anosmia (loss of smell), parosmia (distorted smell), muscle weakness, low fever and cognitive dysfunction.',
        'What is Post-COVID syndrome?',
    ),
}


def _show_answer(nlp, question, context):
    """Run the QA pipeline on one question/context pair and render the result."""
    st.subheader("🎲 Answer")
    try:
        res = nlp({'question': question, 'context': context})
        answer = res['answer']
    except Exception:
        # Catch Exception rather than using a bare except, so that
        # KeyboardInterrupt/SystemExit are not swallowed here.
        st.error("Sorry! Couldn't process the request")
    else:
        st.success(answer)


def main():
    """Render the Streamlit page: sidebar, QA widgets and the news list.

    The sidebar shows the project image and description. The main area lets
    the user either pick a preset example or type a context and question,
    passes them to the pipeline returned by ``qna()``, and then lists up to
    ten items returned by ``news()``.
    """
    st.sidebar.header("COVID-19 Question Answering (CO-QA) system")

    wallpaper = Image.open('covid-removebg-preview.png')
    wallpaper = wallpaper.resize((1400,700))
    st.sidebar.image(wallpaper)

    st.sidebar.info("This project presents a COVID-19 Question Answering (CO-QA) system, which is a Web App that utilises AI and NLP techniques to answer questions about COVID-19 and post-COVID-19 from scientific articles related to COVID-19. The objective is to aid the medical community in addressing critical COVID-19-related questions.")

    st.sidebar.header("Author")
    st.sidebar.info("""This model is part of the research topic: To mitigate the COVID-19's effects on a variety of populations (with pre-existing chronic health problems and structural vulnerabilities). The first step in this regard is to mine the literature for COVID-19 related topics and to answer any questions related to COVID-19 

                    \n -Shaina Raza""")

    # Each fragment ending in a src attribute carries a trailing space so the
    # following alt attribute is not fused onto it after concatenation.
    st.markdown("<div align='center'><br>"
                "<img src='https://img.shields.io/badge/Domain-Medical-blue' "
                "alt='API stability' height='25'/>"
                "<img src='https://img.shields.io/badge/Model-covid__qa__distillbert-red' "
                "alt='API stability' height='25'/>"
                "<img src='https://img.shields.io/badge/Web%20App-Streamlit-yellow' "
                "alt='API stability' height='25'/></div>", unsafe_allow_html=True)

    st.write(" ")

    nlp = qna()
    select_input = st.radio("Select the input",
                            ('Select from the examples', 'Type in your questions'))

    if select_input == 'Select from the examples':
        options = st.selectbox('Select the preset examples',
                               ['', 'example1', 'example2', 'example3'])

        # The empty entry is the "nothing chosen yet" placeholder.
        if options in _EXAMPLES:
            context, question = _EXAMPLES[options]

            st.subheader('🎲 Question')
            st.write(question)

            st.subheader('🎲 Context')
            st.write(context)

            _show_answer(nlp, question, context)

    elif select_input == 'Type in your questions':
        context = st.text_area("Enter your context or paragraph", "")
        question = st.text_input("Enter the Question", "")

        if st.button("Run QnA") and question != "":
            _show_answer(nlp, question, context)

    st.header("Latest Covid News")
    try:
        latest_news = news()
    except Exception:
        # A news-feed failure should not take down the QA page.
        st.warning("Couldn't fetch the latest news right now")
        latest_news = []

    for item in (latest_news or [])[:10]:
        st.write(item.get('title', ''))
        st.write(item.get('link', ''))
        st.write(" ")
                            
# Build the page only when this file is run as a script, not when imported.
if __name__ == "__main__":
    main()