Muhammad Haris commited on
Commit
25abf8c
·
verified ·
1 Parent(s): 0ed6995

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +69 -50
app.py CHANGED
@@ -1,63 +1,82 @@
1
- import gradio as gr
2
  import pandas as pd
 
3
  from sklearn.metrics.pairwise import cosine_similarity
4
- import numpy as np
5
- import re
6
- import os
7
  import gdown
8
- from sentence_transformers import SentenceTransformer
 
9
 
10
- # Download the file
11
  file_id = '1P3Nz6f3KG0m0kO_2pEfnVIhgP8Bvkl4v'
12
  url = f'https://drive.google.com/uc?id={file_id}'
13
  excel_file_path = os.path.join(os.path.expanduser("~"), 'medical_data.csv')
14
 
15
- gdown.download(url, excel_file_path, quiet=False)
16
-
17
- # Read the CSV file into a DataFrame using 'latin1' encoding
18
  try:
19
  medical_df = pd.read_csv(excel_file_path, encoding='utf-8')
20
  except UnicodeDecodeError:
21
  medical_df = pd.read_csv(excel_file_path, encoding='latin1')
22
 
23
- def remove_digits_with_dot(input_string):
24
- # Define a regex pattern to match digits with a dot at the beginning of the string
25
- pattern = re.compile(r'^\d+\.')
26
-
27
- # Use sub() method to replace the matched pattern with an empty string
28
- result_string = re.sub(pattern, '', input_string)
29
-
30
- return result_string
31
-
32
- medical_df["Questions"] = medical_df["Questions"].apply(remove_digits_with_dot)
33
-
34
- medical_df = medical_df[medical_df["Answers"].notna()]
35
-
36
- # Initialize SentenceTransformer model directly
37
- model_name_or_path = "hkunlp/instructor-large"
38
- model = SentenceTransformer(model_name_or_path)
39
-
40
- # Encode answers to create embeddings
41
- corpus = medical_df["Answers"].tolist()
42
- answer_embeddings = model.encode(corpus)
43
-
44
- def get_answer(query):
45
- # Encode query to get query embedding
46
- query_embedding = model.encode([query])
47
-
48
- # Compute cosine similarity between query embedding and answer embeddings
49
- similarities = cosine_similarity(query_embedding, answer_embeddings)
50
-
51
- # Get index of the answer with highest similarity
52
- retrieved_doc_id = np.argmax(similarities)
53
-
54
- # Retrieve corresponding question, answer, and references
55
- q = medical_df.iloc[retrieved_doc_id]["Questions"]
56
- a = medical_df.iloc[retrieved_doc_id]["Answers"]
57
- r = medical_df.iloc[retrieved_doc_id]["References"]
58
-
59
- return (q, a, r)
60
-
61
- # Gradio interface setup
62
- iface = gr.Interface(fn=get_answer, inputs=gr.inputs.Textbox(), outputs=["text", "text", "text"], title="Medical QA System")
63
- iface.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
  import pandas as pd
3
+ from sklearn.feature_extraction.text import TfidfVectorizer
4
  from sklearn.metrics.pairwise import cosine_similarity
5
+ from transformers import GPT2Tokenizer, GPT2LMHeadModel
6
+ from sentence_transformers import SentenceTransformer, util
7
+ import torch
8
  import gdown
9
+ import os
10
+
11
 
 
12
  file_id = '1P3Nz6f3KG0m0kO_2pEfnVIhgP8Bvkl4v'
13
  url = f'https://drive.google.com/uc?id={file_id}'
14
  excel_file_path = os.path.join(os.path.expanduser("~"), 'medical_data.csv')
15
 
16
+ # Read the CSV file
 
 
17
  try:
18
  medical_df = pd.read_csv(excel_file_path, encoding='utf-8')
19
  except UnicodeDecodeError:
20
  medical_df = pd.read_csv(excel_file_path, encoding='latin1')
21
 
22
+ # TF-IDF Vectorization
23
+ vectorizer = TfidfVectorizer(stop_words='english')
24
+ X_tfidf = vectorizer.fit_transform(medical_df['Questions'])
25
+
26
+ # Load pre-trained GPT-2 model and tokenizer
27
+ model_name = "sshleifer/tiny-gpt2"
28
+ tokenizer = GPT2Tokenizer.from_pretrained(model_name)
29
+ model = GPT2LMHeadModel.from_pretrained(model_name)
30
+
31
+ # Load pre-trained Sentence Transformer model
32
+ sbert_model_name = "paraphrase-MiniLM-L6-v2"
33
+ sbert_model = SentenceTransformer(sbert_model_name)
34
+
35
+
36
+ # Function to answer medical questions using a combination of TF-IDF, LLM, and semantic similarity
37
+ def get_medical_response(question, vectorizer, X_tfidf, model, tokenizer, sbert_model, medical_df):
38
+ # TF-IDF Cosine Similarity
39
+ question_vector = vectorizer.transform([question])
40
+ tfidf_similarities = cosine_similarity(question_vector, X_tfidf).flatten()
41
+
42
+ # Find the most similar question using semantic similarity
43
+ question_embedding = sbert_model.encode(question, convert_to_tensor=True)
44
+ similarities = util.pytorch_cos_sim(question_embedding, sbert_model.encode(medical_df['Questions'].tolist(), convert_to_tensor=True)).flatten()
45
+ max_sim_index = similarities.argmax().item()
46
+
47
+ # LLM response generation
48
+ input_text = "DiBot: " + medical_df.iloc[max_sim_index]['Questions']
49
+ input_ids = tokenizer.encode(input_text, return_tensors="pt")
50
+ attention_mask = torch.ones(input_ids.shape, dtype=torch.long)
51
+ pad_token_id = tokenizer.eos_token_id
52
+ lm_output = model.generate(input_ids, max_length=150, num_return_sequences=1, no_repeat_ngram_size=2, attention_mask=attention_mask, pad_token_id=pad_token_id)
53
+ lm_generated_response = tokenizer.decode(lm_output[0], skip_special_tokens=True)
54
+
55
+ # Compare similarities and choose the best response
56
+ if tfidf_similarities.max() > 0.5:
57
+ tfidf_index = tfidf_similarities.argmax()
58
+ return medical_df.iloc[tfidf_index]['Answers']
59
+ else:
60
+ return lm_generated_response
61
+
62
+ # Streamlit UI
63
+ st.title("DiBot")
64
+
65
+ if "messages" not in st.session_state:
66
+ st.session_state.messages = []
67
+
68
+ for message in st.session_state.messages:
69
+ with st.chat_message(message["role"]):
70
+ st.markdown(message["content"])
71
+
72
+ user_input = st.chat_input("You:")
73
+
74
+ if user_input:
75
+ response = get_medical_response(user_input, vectorizer, X_tfidf, model, tokenizer, sbert_model, medical_df)
76
+ st.session_state.messages.append({"role": "user", "content": user_input})
77
+ st.session_state.messages.append({"role": "assistant", "content": response})
78
+
79
+ # Display the chat messages
80
+ for message in st.session_state.messages:
81
+ with st.chat_message(message["role"]):
82
+ st.markdown(message["content"])