Muhammad Adnan commited on
Commit
ae6eb20
·
1 Parent(s): 3623388

Initial commit of Streamlit app

Browse files
Files changed (4) hide show
  1. app.py +147 -0
  2. data_ret.py +57 -0
  3. requirements.txt +8 -0
  4. similarity_search.py +94 -0
app.py ADDED
@@ -0,0 +1,147 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ from transformers import pipeline
3
+ from similarity_search import get_relevant_context # Import function from similarity_search.py
4
+ from bs4 import BeautifulSoup # For stripping HTML/XML tags
5
+ import spacy # Import spaCy for NLP tasks
6
+
7
+ # Load the spaCy model (make sure to download it first via 'python -m spacy download en_core_web_sm')
8
+ nlp = spacy.load("en_core_web_sm")
9
+
10
+ # Load the Roberta model for question answering
11
+ def load_qa_model():
12
+ print("Loading QA model...")
13
+ try:
14
+ qa_model = pipeline("question-answering", model="deepset/roberta-base-squad2")
15
+ print("QA model loaded.")
16
+ return qa_model
17
+ except Exception as e:
18
+ print(f"Error loading QA model: {e}")
19
+ raise RuntimeError("Failed to load the QA model.")
20
+
21
+ # Function to clean the context text (remove HTML tags and optional stop words)
22
+ def clean_text(context, remove_stop_words=False):
23
+ # Remove HTML/XML tags
24
+ clean_context = BeautifulSoup(context, "html.parser").get_text()
25
+
26
+ if remove_stop_words:
27
+ stop_words = set(["the", "a", "an", "of", "and", "to", "in", "for", "on", "at", "by", "with", "about", "as", "from"])
28
+ clean_context = " ".join([word for word in clean_context.split() if word.lower() not in stop_words])
29
+
30
+ return clean_context
31
+
32
+ # Function to extract proper nouns or pronouns from the question for context retrieval
33
+ def extract_topic_from_question(question):
34
+ # Process the text with spaCy
35
+ doc = nlp(question)
36
+
37
+ # Define pronouns to exclude manually if necessary
38
+ excluded_pronouns = ['I', 'you', 'he', 'she', 'it', 'they', 'we', 'them', 'this', 'that', 'these', 'those']
39
+
40
+ # Extract proper nouns (PROPN) and pronouns (PRON), but exclude certain pronouns and stopwords
41
+ proper_nouns_or_pronouns = [
42
+ token.text for token in doc
43
+ if (
44
+ token.pos_ == 'PROPN' or token.pos_ == 'PRON') and token.text.lower() not in excluded_pronouns and not token.is_stop
45
+ ]
46
+
47
+ # If no proper nouns or pronouns are found, remove stopwords and return whatever is left
48
+ if not proper_nouns_or_pronouns:
49
+ remaining_tokens = [
50
+ token.text for token in doc
51
+ if not token.is_stop # Just remove stopwords, keep all other tokens
52
+ ]
53
+ return " ".join(remaining_tokens)
54
+
55
+ # Otherwise, return the proper nouns or pronouns
56
+ return " ".join(proper_nouns_or_pronouns)
57
+
58
+ # Inside the answer_question_with_context function, add debugging statements:
59
+ def answer_question_with_context(question, qa_model):
60
+ try:
61
+ print(question)
62
+ # Extract topic from question (proper nouns or pronouns)
63
+ topic = extract_topic_from_question(question)
64
+ print(f"Extracted topic (proper nouns or pronouns): {topic}" if topic else "No proper nouns or pronouns extracted.")
65
+
66
+ # Retrieve relevant context based on the extracted topic
67
+ context = get_relevant_context(question, topic)
68
+ print(f"Retrieved Context: {context}") # Debug: Show context result
69
+
70
+ if not context.strip():
71
+ return "No context found for answering.", ""
72
+
73
+ # Clean the context
74
+ context = clean_text(context, remove_stop_words=True)
75
+
76
+ # Use the QA model to extract an answer from the context
77
+ result = qa_model(question=question, context=context)
78
+ return result.get('answer', 'No answer found.'), context
79
+ except Exception as e:
80
+ print(f"Error during question answering: {e}") # Debug: Log error in terminal
81
+ return f"Error during question answering: {e}", ""
82
+
83
+ # Streamlit UI
84
+ def main():
85
+ st.title("RAG Question Answering with Context Retrieval")
86
+
87
+ # User input for the question
88
+ question = st.text_input("Enter your question:", "What is the capital of Italy?") # Default question
89
+
90
+ # Display a log update
91
+ log = st.empty()
92
+
93
+ # Button to get the answer
94
+ if st.button("Get Answer"):
95
+ if not question:
96
+ st.error("Please provide a question.")
97
+ else:
98
+ try:
99
+ # Display a loading spinner and log message for the QA model
100
+ log.text("Loading QA model...")
101
+ with st.spinner("Loading QA model... Please wait."):
102
+
103
+ # Try loading the QA model
104
+ qa_model = load_qa_model()
105
+
106
+ # Display log message for context retrieval
107
+ log.text("Retrieving context...")
108
+ with st.spinner("Retrieving context..."):
109
+
110
+ answer, context = answer_question_with_context(question, qa_model)
111
+
112
+ if not context.strip():
113
+ # If context is empty, let the user enter the context manually
114
+ st.warning("I couldn't find any relevant context for this question. Please enter it below:")
115
+ context = st.text_area("Enter your context here:", "", height=200, max_chars=1000)
116
+ if not context.strip():
117
+ context = "I couldn't find any relevant context, and you didn't provide one either. Maybe next time!"
118
+
119
+ # Display the answer and context
120
+ st.subheader("Answer:")
121
+ st.write(answer) # Show the final answer
122
+
123
+ # Display the context
124
+ st.subheader("Context Used for Answering:")
125
+ st.text_area("Context:", context, height=200, max_chars=1000, key="context_input", disabled=False) # Editable context box
126
+
127
+ except Exception as e:
128
+ st.error(f"An error occurred: {e}")
129
+ log.text(f"Error: {e}") # Log error in place
130
+
131
+ # Display information about the application
132
+ st.markdown("""
133
+ ### About the Application
134
+ This is a **Retrieval-Augmented Generation (RAG)** application that answers questions by dynamically retrieving context from a dataset. Here's how it works:
135
+
136
+ 1. **Dynamic Topic Extraction**: The application analyzes the user's question and extracts key topics (such as proper nouns or pronouns) to understand the context of the query.
137
+ 2. **Context Retrieval**: Based on the extracted topic, the app searches for the most relevant documents (a few hundred) in the dataset.
138
+ 3. **Answer Generation**: Using the retrieved context, an AI model (like RoBERTa) is used to generate the most accurate answer possible. The model combines the context with its internal knowledge to provide a robust and informed response.
139
+ 4. **Customization**: If the application doesn't find enough relevant context automatically, you can manually input additional context to improve the answer.
140
+
141
+ The application leverages **Roberta-based question-answering models** to generate answers based on the context retrieved. This helps provide more accurate, context-specific answers compared to traditional approaches that rely solely on pre-trained model knowledge.
142
+
143
+ **Dataset Used**: The application dynamically pulls relevant documents from a dataset (e.g., academic papers, FAQ pages, product manuals, etc.), helping answer user questions more effectively.
144
+ """)
145
+
146
+ if __name__ == "__main__":
147
+ main()
data_ret.py ADDED
@@ -0,0 +1,57 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from datasets import load_dataset
2
+
3
+ # Load the dataset (specify split as 'train' to load the training data)
4
+ dataset = load_dataset('tom-010/google_natural_questions_answerability', split='train')
5
+
6
+ # Function to filter based on a query/topic and return relevant data
7
+ def search_relevant_data(topic="Artificial Intelligence", max_words=100, top_n=100):
8
+ # Filter the dataset based on the presence of the topic in 'question', 'answer', or 'text' fields
9
+ filtered_data = dataset.filter(
10
+ lambda x: (
11
+ (x['question'] is not None and topic.lower() in x['question'].lower()) or
12
+ (x['answer'] is not None and topic.lower() in x['answer'].lower()) or
13
+ (x['text'] is not None and topic.lower() in x['text'].lower())
14
+ )
15
+ )
16
+
17
+ # Ensure we only select up to the available number of rows
18
+ #num_to_select = min(top_n, len(filtered_data)) # Choose the minimum of top_n and available data
19
+ #filtered_data = filtered_data.select(range(num_to_select)) # Select up to 'num_to_select' rows
20
+ filtered_data = filtered_data.select(range(min(top_n, len(filtered_data))))
21
+
22
+
23
+ # Create a list to store the relevant data
24
+ relevant_documents = []
25
+
26
+ # Display and store an excerpt of the answer for each relevant entry
27
+ for entry in filtered_data:
28
+ # Check the type of 'entry' first to ensure it's a dictionary
29
+ if isinstance(entry, dict):
30
+ question = entry.get('question', '') # Accessing the 'question' field safely
31
+ answer = entry.get('answer', '') # Accessing the 'answer' field safely
32
+ text = entry.get('text', '') # Accessing the 'text' field safely
33
+
34
+ # Only store the first 'max_words' words of the answer or text
35
+ answer_excerpt = ' '.join(answer.split()[:max_words]) if answer else ""
36
+ text_excerpt = ' '.join(text.split()[:max_words]) if text else ""
37
+
38
+ # Append relevant document information to the list
39
+ relevant_documents.append({
40
+ "question": question,
41
+ "answer": answer_excerpt,
42
+ "text": text_excerpt
43
+ })
44
+
45
+ # Debugging: Print a preview of the data (optional)
46
+ #print(f"Question: {question[:20]}...") # Print first 20 chars of the question
47
+ #print(f"Answer (first {max_words} words): {answer_excerpt[:20]}...") # Print first 20 words of the answer
48
+ #print(f"Text (first {max_words} words): {text_excerpt[:20]}...") # Print first 20 words of the text
49
+ #print("-" * 50)
50
+ else:
51
+ print("Unexpected entry format:", entry)
52
+
53
+ return relevant_documents # Return the list of relevant documents
54
+
55
+ # Sample search query
56
+ #relevant_data = search_relevant_data("vatican city") # Change to the desired query/topic
57
+ #print(f"Found {len(relevant_data)} relevant documents.")
requirements.txt ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ streamlit==1.20.0
2
+ transformers==4.33.0
3
+ sentence-transformers==2.2.0
4
+ scipy==1.10.0
5
+ numpy==1.24.2
6
+ datasets==2.9.0
7
+ beautifulsoup4==4.12.0
8
+ spacy==3.5.0
similarity_search.py ADDED
@@ -0,0 +1,94 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from sentence_transformers import SentenceTransformer
2
+ from scipy.spatial.distance import cosine
3
+ import numpy as np
4
+ from data_ret import search_relevant_data # Assuming this function fetches the data from some source
5
+ import streamlit as st
6
+
7
+ # Load the Sentence Transformer model for similarity search
8
+ def load_similarity_model():
9
+ st.write("Loading similarity model...") # Show status on Streamlit
10
+ retriever_model = SentenceTransformer("all-mpnet-base-v2")
11
+ st.write("Similarity model loaded.")
12
+ return retriever_model
13
+
14
+ # Create embeddings for the retrieved documents
15
+ def create_embeddings(documents, model):
16
+ if not documents:
17
+ st.write("No documents provided for embedding.")
18
+ return np.array([]) # Return empty array if no documents
19
+
20
+ st.write(f"Creating embeddings for {len(documents)} documents...") # Show progress
21
+ embeddings = []
22
+
23
+ # Track progress of the embedding creation using Streamlit's progress bar
24
+ progress_bar = st.progress(0)
25
+ step = 1 / len(documents) # This ensures the progress bar value stays within [0.0, 1.0]
26
+
27
+ # Include 'text' in the document text along with 'question' and 'answer'
28
+ document_texts = [doc['question'] + " " + doc['answer'] + " " + doc.get('text', '') for doc in documents]
29
+
30
+ for i, doc_text in enumerate(document_texts):
31
+ embedding = model.encode(doc_text)
32
+ embeddings.append(embedding)
33
+ progress_bar.progress(i * step) # Update the progress bar within valid range
34
+
35
+ embeddings = np.array(embeddings)
36
+ st.write(f"Embeddings created with shape: {embeddings.shape}")
37
+ return embeddings
38
+
39
+ # Retrieve documents based on the question embedding
40
+ def retrieve_documents(question_embedding, document_embeddings, top_k=5):
41
+ if document_embeddings.size == 0:
42
+ st.write("No document embeddings available for retrieval.")
43
+ return []
44
+
45
+ st.write("Calculating similarities between question and documents...")
46
+ similarities = np.array([1 - cosine(question_embedding, doc_embedding) for doc_embedding in document_embeddings])
47
+
48
+ # Get indices of top K similarities (highest similarity first)
49
+ top_indices = similarities.argsort()[-top_k:][::-1] # Sort in descending order
50
+ return top_indices
51
+
52
+ # Main function to get the context from the most relevant documents based on topic and question
53
+ def get_relevant_context(question, topic):
54
+ try:
55
+ st.write("Searching for relevant documents based on the topic...")
56
+ relevant_documents = search_relevant_data(topic) # Use dynamic topic for search query
57
+
58
+ st.write(f"Found {len(relevant_documents)} relevant documents.")
59
+
60
+ if not relevant_documents:
61
+ return "No relevant documents found."
62
+
63
+ retriever_model = load_similarity_model() # Load the similarity model
64
+
65
+ # Create document embeddings and show progress
66
+ document_embeddings = create_embeddings(relevant_documents, retriever_model)
67
+
68
+ if document_embeddings.size == 0:
69
+ return "No embeddings created for relevant documents."
70
+
71
+ st.write("Generating question embedding and retrieving relevant documents...")
72
+ question_embedding = retriever_model.encode(question)
73
+ relevant_doc_indices = retrieve_documents(question_embedding, document_embeddings)
74
+
75
+ if len(relevant_doc_indices) == 0:
76
+ return "No relevant documents found after embedding."
77
+
78
+ # Extract context from the top relevant documents
79
+ contexts = []
80
+ for idx in relevant_doc_indices:
81
+ doc = relevant_documents[idx]
82
+ context = doc.get('answer', '') + " " + doc.get('text', '')
83
+ if context.strip():
84
+ contexts.append(context)
85
+
86
+ if not contexts:
87
+ return "No valid contexts available for answering."
88
+
89
+ # Return the combined context for question answering
90
+ return " ".join(contexts)
91
+
92
+ except Exception as e:
93
+ st.write(f"Error processing question: {str(e)}")
94
+ return f"Error: {str(e)}"