Phoenix21 committed on
Commit
5ec7b71
·
verified ·
1 Parent(s): 6458905

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +216 -0
app.py ADDED
@@ -0,0 +1,216 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Install necessary libraries in Colab
2
+ # !pip install datasets langchain_community smolagents chardet gradio pandas nltk scikit-learn rank_bm25
3
+
4
+ # Import required modules
5
+ import os
6
+ import getpass
7
+ import pandas as pd
8
+ import chardet
9
+ import re
10
+ from langchain.docstore.document import Document
11
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
12
+ from langchain_community.retrievers import BM25Retriever
13
+ # from smolagents import Tool, HfApiModel, CodeAgent
14
+ from smolagents import CodeAgent, HfApiModel, DuckDuckGoSearchTool, ManagedAgent
15
+ from smolagents.agents import ToolCallingAgent
16
+ from smolagents import Tool, HfApiModel, TransformersModel, LiteLLMModel
17
+ from typing import Optional
18
+ import gradio as gr
19
+ import logging
20
+ from nltk.corpus import words
21
+ from sklearn.feature_extraction.text import TfidfVectorizer
22
+ from sklearn.metrics.pairwise import cosine_similarity
23
+
24
+
25
# Reuse a non-empty GROQ_API_KEY from the environment; otherwise prompt for it.
if os.environ.get('GROQ_API_KEY'):
    print("GROQ_API_KEY is already set.")
else:
    os.environ['GROQ_API_KEY'] = getpass.getpass('Enter GROQ_API_KEY: ')

# Module-level logger, configured at INFO so load/split progress is reported.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
32
+
33
+ # Load NLTK word list for valid word checks
34
try:
    _word_list = words.words()
except LookupError:
    # The "words" corpus is not available locally: fetch it, then read again.
    import nltk
    nltk.download('words')
    _word_list = words.words()
# A set gives constant-time membership checks in is_valid_input.
english_words = set(_word_list)
40
+
41
+ # Define allowed topics for health and wellness
42
ALLOWED_TOPICS = [
    "mental health", "physical health", "fitness",
    "nutrition", "exercise", "mindfulness",
    "sleep", "stress management", "wellness",
    "relaxation", "healthy lifestyle", "self-care",
    "meditation", "diet", "hydration",
    "breathing techniques", "yoga", "stress relief",
    "emotional health", "spiritual health", "healthy habits",
]
65
+
66
def is_valid_input(query):
    """
    Decide whether a user question is usable.

    Returns a ``(is_valid, message)`` tuple. The input is rejected when it is
    empty or whitespace-only, when it is shorter than two characters after
    stripping, or when none of its word tokens appears in ``english_words``.
    """
    stripped = query.strip() if query else ""

    if stripped == "":
        return False, "Input cannot be empty. Please provide a meaningful question."

    if len(stripped) < 2:
        return False, "Input is too short. Please provide more context or details."

    # One recognised word is enough; tokens are lower-cased before the lookup.
    tokens = re.findall(r'\b\w+\b', query.lower())
    if any(token in english_words for token in tokens):
        return True, "Valid input."

    return False, "Input appears unclear. Please use valid words in your question."
84
+
85
def similarity_search(query, corpus, threshold=0.2):
    """
    Find the corpus entry most similar to ``query`` using TF-IDF vectors and
    cosine similarity.

    The vectorizer is fitted on ``corpus + [query]`` on every call, so the
    query shares the corpus vocabulary.

    Args:
        query: The user's question.
        corpus: List of reference strings to compare against.
        threshold: Minimum cosine similarity for a match to count.

    Returns:
        ``(True, best_match, score)`` when the best score reaches
        ``threshold``, otherwise ``(False, None, score)``. An empty corpus
        yields ``(False, None, 0.0)``.
    """
    # With nothing to compare against, the similarity matrix would have zero
    # rows and taking its maximum would raise; report "no match" instead.
    if not corpus:
        return False, None, 0.0

    vectorizer = TfidfVectorizer()
    tfidf_matrix = vectorizer.fit_transform(corpus + [query])

    # The query was appended last, so it is the final row of the matrix.
    query_vector = tfidf_matrix[-1]
    similarities = cosine_similarity(query_vector, tfidf_matrix[:-1]).flatten()

    best_idx = int(similarities.argmax())
    best_score = float(similarities[best_idx])
    if best_score >= threshold:
        return True, corpus[best_idx], best_score
    return False, None, best_score
98
+
99
+ # Load and process the AIChatbot.csv file
100
def load_csv(file_path):
    """
    Load a CSV file and convert every row into a ``Document``.

    The file's encoding is guessed with chardet from its raw bytes and then
    passed to ``pandas.read_csv``.

    Args:
        file_path: Path to the CSV file. It must contain a ``Question`` column.

    Returns:
        ``(documents, questions)`` where ``documents`` holds one ``Document``
        per row (the row rendered with ``to_string(index=False)``, with
        ``{"source": file_path}`` as metadata) and ``questions`` holds the
        non-null values of the ``Question`` column. If anything goes wrong the
        error is logged with its traceback and ``([], [])`` is returned.
    """
    try:
        with open(file_path, 'rb') as f:
            detection = chardet.detect(f.read())
        encoding = detection['encoding']

        data = pd.read_csv(file_path, encoding=encoding)
        questions = data['Question'].dropna().tolist()
        documents = [
            Document(page_content=row.to_string(index=False), metadata={"source": file_path})
            for _, row in data.iterrows()
        ]
        logger.info(f"Loaded {len(documents)} documents from {file_path}")
        return documents, questions
    except Exception as e:
        # Best-effort loader: the caller checks for an empty result. Log with
        # the traceback so the root cause is not reduced to a one-line message.
        logger.exception(f"Error loading CSV file: {e}")
        return [], []
119
+
120
+ # Load the AIChatbot.csv file
121
file_path = "AIChatbot.csv"  # This file must be present in the working environment
source_docs, corpus_questions = load_csv(file_path)
if not source_docs:
    # load_csv returns empty lists on failure, so stop here rather than build
    # a retriever over nothing.
    raise ValueError(f"Failed to load documents from {file_path}. Please check the file.")

# Chunking settings: 500-character chunks with a 50-character overlap.
_splitter_options = {
    "chunk_size": 500,
    "chunk_overlap": 50,
    "add_start_index": True,
    "strip_whitespace": True,
    "separators": ["\n\n", "\n", ".", " ", ""],
}
text_splitter = RecursiveCharacterTextSplitter(**_splitter_options)
docs_processed = text_splitter.split_documents(source_docs)
logger.info(f"Split documents into {len(docs_processed)} chunks.")
136
+
137
+ # Define the retriever tool
138
class RetrieverTool(Tool):
    """
    Agent tool that looks up the document chunk best matching a query.

    The lookup is delegated to a ``BM25Retriever`` built from the documents
    passed to the constructor.
    """

    name = "retriever"
    # NOTE(review): the text says "semantic search" while the retriever below
    # is BM25-based; the string is left unchanged because the agent reads it.
    description = "Uses semantic search to retrieve the parts of chatbot documentation most relevant to the query."
    inputs = {
        "query": {
            "type": "string",
            "description": "The query to perform. Use an affirmative tone rather than a question."
        }
    }
    output_type = "string"

    def __init__(self, docs, **kwargs):
        """Build the BM25 retriever over ``docs``; extra kwargs go to ``Tool``."""
        super().__init__(**kwargs)
        self.retriever = BM25Retriever.from_documents(docs, k=10)

    def forward(self, query: str) -> str:
        """
        Return the stripped content of the top-ranked document for ``query``.

        Raises:
            TypeError: If ``query`` is not a string.
        """
        # Explicit check instead of ``assert`` so it still runs under ``-O``.
        if not isinstance(query, str):
            raise TypeError("Search query must be a string.")

        docs = self.retriever.invoke(query)
        if not docs:
            return "No relevant information found."
        # Only the first returned document is passed back to the agent.
        return docs[0].page_content.strip()
161
+
162
# Retrieval tool backed by the chunked CSV documents.
retriever_tool = RetrieverTool(docs_processed)

# Instructions prepended to every question sent to the agent.
custom_prompt = """
You are a friendly and knowledgeable AI assistant for a daily wellness company. Your goal is to provide clear, concise, and actionable answers to the user's health and wellness-related questions. Use a warm, approachable tone to make the user feel at ease.

When answering:
1. Focus on brevity without sacrificing accuracy or helpfulness.
2. Highlight key points in an easy-to-understand manner.
3. Include examples, tips, or short step-by-step guides where relevant.
4. Format lists or steps using markdown for better readability (e.g., numbered lists, bullet points).
5. Ensure your response is self-contained, engaging, and ends with a polite closing remark.

Answer each question in a similar concise, helpful, and friendly way.
"""

# Language model and agent; the model must be available to the configured key.
model = LiteLLMModel("groq/llama3-8b-8192")
agent = CodeAgent(
    tools=[retriever_tool],
    model=model,
    max_iterations=4,
    verbose=True,
)
183
+
184
+ # Gradio interface for interacting with the RAG pipeline
185
def gradio_interface(query):
    """
    Answer one user question for the Gradio UI.

    Steps: validate the input, check that it resembles at least one question
    in ``corpus_questions`` (TF-IDF similarity, threshold 0.2), then pass the
    prompt plus question to the agent.

    Args:
        query: The text entered by the user.

    Returns:
        A string for display: a validation message, a "please refine" notice,
        the agent's answer, or a generic error message if anything raised.
    """
    try:
        is_valid, message = is_valid_input(query)
        if not is_valid:
            return message

        # Only the match flag is needed; the matched question and score are unused.
        similar, _, _ = similarity_search(query, corpus_questions, threshold=0.2)
        if not similar:
            return (
                "I'm here to assist with health and wellness-related topics. "
                "However, I couldn't find a closely related question in the dataset. "
                "Please refine your query."
            )

        answer = agent.run(f"{custom_prompt}\n\nQuestion: {query}")
        # The agent may hand back a non-string final answer; convert before
        # stripping so a successful run is not reported as an error.
        return str(answer).strip()
    except Exception as e:
        # UI boundary: never surface a traceback to the user, but keep it in the logs.
        logger.exception(f"Error during query processing: {e}")
        return "**An error occurred while processing your request. Please try again later.**"
205
+
206
# Input and output widgets for the single-question UI.
_question_box = gr.Textbox(
    label="Enter your question",
    placeholder="e.g., How does box breathing help reduce anxiety?",
)
_answer_panel = gr.Markdown(label="Answer")

interface = gr.Interface(
    fn=gradio_interface,
    inputs=_question_box,
    outputs=_answer_panel,
    title="AI Chatbot for Wellness",
    description="Ask questions based on the AIChatbot.csv file. Focus on health and wellness topics.",
    theme="compact",
)

# Start the web UI only when the file is executed as a script.
if __name__ == "__main__":
    interface.launch(debug=True)