Phoenix21 committed on
Commit
42a8215
·
1 Parent(s): 3e57ba8

updated app.py for complete and coherent response

Browse files
Files changed (1) hide show
  1. app.py +68 -141
app.py CHANGED
@@ -2,8 +2,6 @@ import os
2
  import logging
3
  import re
4
  from langchain.vectorstores import Chroma
5
- from langchain_core.output_parsers import StrOutputParser
6
- from langchain_core.runnables import RunnablePassthrough
7
  from langchain_huggingface import HuggingFaceEmbeddings
8
  from langchain.text_splitter import RecursiveCharacterTextSplitter
9
  from langchain_groq import ChatGroq
@@ -15,42 +13,35 @@ import gradio as gr
15
  import pandas as pd
16
  import json
17
 
18
- # Enable logging for debugging
19
  logging.basicConfig(level=logging.INFO)
20
  logger = logging.getLogger(__name__)
21
 
22
- # Function to clean the API key
23
  def clean_api_key(key):
24
  return ''.join(c for c in key if ord(c) < 128)
25
 
26
- # Load the GROQ API key from environment variables (set as a secret in the Space)
27
  api_key = os.getenv("GROQ_API_KEY")
28
  if not api_key:
29
- logger.error("GROQ_API_KEY environment variable is not set. Please add it as a secret.")
30
  raise ValueError("GROQ_API_KEY environment variable is not set. Please add it as a secret.")
31
- api_key = clean_api_key(api_key).strip() # Clean and strip whitespace
32
 
33
- # Function to clean text by removing non-ASCII characters
34
  def clean_text(text):
35
  return text.encode("ascii", errors="ignore").decode()
36
 
37
- # Function to load and clean documents from multiple file formats
38
  def load_documents(file_paths):
39
  docs = []
40
  for file_path in file_paths:
41
  ext = os.path.splitext(file_path)[-1].lower()
42
  try:
43
  if ext == ".csv":
44
- # Handle CSV files
45
  with open(file_path, 'rb') as f:
46
  result = chardet.detect(f.read())
47
  encoding = result['encoding']
48
  data = pd.read_csv(file_path, encoding=encoding)
49
- for index, row in data.iterrows():
50
  content = clean_text(row.to_string())
51
  docs.append(Document(page_content=content, metadata={"source": file_path}))
52
  elif ext == ".json":
53
- # Handle JSON files
54
  with open(file_path, 'r', encoding='utf-8') as f:
55
  data = json.load(f)
56
  if isinstance(data, list):
@@ -61,7 +52,6 @@ def load_documents(file_paths):
61
  content = clean_text(json.dumps(data))
62
  docs.append(Document(page_content=content, metadata={"source": file_path}))
63
  elif ext == ".txt":
64
- # Handle TXT files
65
  with open(file_path, 'r', encoding='utf-8') as f:
66
  content = clean_text(f.read())
67
  docs.append(Document(page_content=content, metadata={"source": file_path}))
@@ -69,178 +59,115 @@ def load_documents(file_paths):
69
  logger.warning(f"Unsupported file format: {file_path}")
70
  except Exception as e:
71
  logger.error(f"Error processing file {file_path}: {e}")
72
- logger.debug("Exception details:", exc_info=True)
73
  return docs
74
 
75
- # Function to ensure the response ends with complete sentences
76
  def ensure_complete_sentences(text):
77
- # Use regex to find all complete sentences
78
  sentences = re.findall(r'[^.!?]*[.!?]', text)
79
  if sentences:
80
- # Join all complete sentences to form the complete answer
81
- return ' '.join(sentences).strip()
82
- return text # Return as is if no complete sentence is found
83
 
84
- # Function to check if input is valid
85
  def is_valid_input(text):
86
- """
87
- Checks if the input text is meaningful.
88
- Returns True if the text contains alphabetic characters and is of sufficient length.
89
- """
90
  if not text or text.strip() == "":
91
  return False
92
- # Regex to check for at least one alphabetic character
93
  if not re.search('[A-Za-z]', text):
94
  return False
95
- # Additional check: minimum length
96
  if len(text.strip()) < 5:
97
  return False
98
  return True
99
 
100
- # Initialize the LLM using ChatGroq with GROQ's API
101
  def initialize_llm(model, temperature, max_tokens):
102
- try:
103
- # Allocate a portion of tokens for the prompt
104
- prompt_allocation = int(max_tokens * 0.2)
105
- response_max_tokens = max_tokens - prompt_allocation
106
- if response_max_tokens <= 50:
107
- raise ValueError("max_tokens is too small to allocate for the response.")
108
-
109
- llm = ChatGroq(
110
- model=model,
111
- temperature=temperature,
112
- max_tokens=response_max_tokens,
113
- api_key=api_key
114
- )
115
- logger.info("LLM initialized successfully.")
116
- return llm
117
- except Exception as e:
118
- logger.error(f"Error initializing LLM: {e}")
119
- raise
120
-
121
- # Create the RAG pipeline
122
  def create_rag_pipeline(file_paths, model, temperature, max_tokens):
123
- try:
124
- llm = initialize_llm(model, temperature, max_tokens)
125
- docs = load_documents(file_paths)
126
- if not docs:
127
- logger.warning("No documents were loaded. Please check your file paths and formats.")
128
- return None, "No documents were loaded. Please check your file paths and formats."
129
-
130
- text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
131
- splits = text_splitter.split_documents(docs)
132
-
133
- # Initialize the embedding model
134
- embedding_model = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
135
-
136
- # Use a temporary directory for Chroma vectorstore
137
- vectorstore = Chroma.from_documents(
138
- documents=splits,
139
- embedding=embedding_model,
140
- persist_directory="/tmp/chroma_db"
141
- )
142
- vectorstore.persist() # Save the database to disk
143
- logger.info("Vectorstore initialized and persisted successfully.")
144
-
145
- retriever = vectorstore.as_retriever()
146
-
147
- custom_prompt_template = PromptTemplate(
148
- input_variables=["context", "question"],
149
- template="""
150
- You are an AI assistant with expertise in daily wellness. Your aim is to provide detailed and comprehensive solutions regarding daily wellness topics without unnecessary verbosity.
151
- Context:
152
- {context}
153
- Question:
154
- {question}
155
- Provide a thorough and complete answer, including relevant examples and a suggested schedule. Ensure that the response does not end abruptly.
156
- """
157
- )
158
-
159
- rag_chain = RetrievalQA.from_chain_type(
160
- llm=llm,
161
- chain_type="stuff",
162
- retriever=retriever,
163
- chain_type_kwargs={"prompt": custom_prompt_template}
164
- )
165
- logger.info("RAG pipeline created successfully.")
166
- return rag_chain, "Pipeline created successfully."
167
- except Exception as e:
168
- logger.error(f"Error creating RAG pipeline: {e}")
169
- logger.debug("Exception details:", exc_info=True)
170
- return None, f"Error creating RAG pipeline: {e}"
171
-
172
- # Initialize the RAG pipeline once at startup
173
  file_paths = ['AIChatbot.csv']
174
  model = "llama3-8b-8192"
175
  temperature = 0.7
176
  max_tokens = 500
177
-
178
  rag_chain, message = create_rag_pipeline(file_paths, model, temperature, max_tokens)
179
- if rag_chain is None:
180
- logger.error("Failed to initialize RAG pipeline at startup.")
181
 
182
- # Function to answer questions with input validation and post-processing
183
  def answer_question(model, temperature, max_tokens, question):
184
- # Validate input
185
  if not is_valid_input(question):
186
- logger.info("Received invalid input from user.")
187
- return "Please provide a valid question or input containing meaningful text."
188
-
189
  if rag_chain is None:
190
- logger.error("RAG pipeline is not initialized.")
191
  return "The system is currently unavailable. Please try again later."
192
-
193
  try:
194
  answer = rag_chain.run(question)
195
- logger.info("Question answered successfully.")
196
- # Post-process to ensure the answer ends with complete sentences
197
  complete_answer = ensure_complete_sentences(answer)
198
  return complete_answer
199
  except Exception as e_inner:
200
- logger.error(f"Error during RAG pipeline execution: {e_inner}")
201
- logger.debug("Exception details:", exc_info=True)
202
- return f"Error during RAG pipeline execution: {e_inner}"
203
 
204
- # Gradio Interface (no feedback)
205
  def gradio_interface(model, temperature, max_tokens, question):
206
  return answer_question(model, temperature, max_tokens, question)
207
 
208
- # Define Gradio UI
209
  interface = gr.Interface(
210
  fn=gradio_interface,
211
  inputs=[
212
- gr.Textbox(
213
- label="Model Name",
214
- value=model,
215
- placeholder="e.g., llama3-8b-8192"
216
- ),
217
- gr.Slider(
218
- label="Temperature",
219
- minimum=0,
220
- maximum=1,
221
- step=0.01,
222
- value=temperature,
223
- info="Controls the randomness of the response. Higher values make output more random."
224
- ),
225
- gr.Slider(
226
- label="Max Tokens",
227
- minimum=200,
228
- maximum=2048,
229
- step=1,
230
- value=max_tokens,
231
- info="Determines the maximum number of tokens in the response."
232
- ),
233
- gr.Textbox(
234
- label="Question",
235
- placeholder="e.g., What is box breathing and how does it help reduce anxiety?"
236
- )
237
  ],
238
  outputs="text",
239
  title="Daily Wellness AI",
240
- description="Ask questions about daily wellness and get detailed solutions.",
241
  examples=[
242
  ["llama3-8b-8192", 0.7, 500, "What is box breathing and how does it help reduce anxiety?"],
243
- ["llama3-8b-8192", 0.6, 600, "Provide a daily wellness schedule incorporating box breathing techniques."]
244
  ],
245
  allow_flagging="never"
246
  )
 
2
  import logging
3
  import re
4
  from langchain.vectorstores import Chroma
 
 
5
  from langchain_huggingface import HuggingFaceEmbeddings
6
  from langchain.text_splitter import RecursiveCharacterTextSplitter
7
  from langchain_groq import ChatGroq
 
13
  import pandas as pd
14
  import json
15
 
 
16
  logging.basicConfig(level=logging.INFO)
17
  logger = logging.getLogger(__name__)
18
 
 
19
  def clean_api_key(key):
20
  return ''.join(c for c in key if ord(c) < 128)
21
 
22
+ # Load the GROQ API key
23
  api_key = os.getenv("GROQ_API_KEY")
24
  if not api_key:
 
25
  raise ValueError("GROQ_API_KEY environment variable is not set. Please add it as a secret.")
26
+ api_key = clean_api_key(api_key).strip()
27
 
 
28
  def clean_text(text):
29
  return text.encode("ascii", errors="ignore").decode()
30
 
 
31
  def load_documents(file_paths):
32
  docs = []
33
  for file_path in file_paths:
34
  ext = os.path.splitext(file_path)[-1].lower()
35
  try:
36
  if ext == ".csv":
 
37
  with open(file_path, 'rb') as f:
38
  result = chardet.detect(f.read())
39
  encoding = result['encoding']
40
  data = pd.read_csv(file_path, encoding=encoding)
41
+ for _, row in data.iterrows():
42
  content = clean_text(row.to_string())
43
  docs.append(Document(page_content=content, metadata={"source": file_path}))
44
  elif ext == ".json":
 
45
  with open(file_path, 'r', encoding='utf-8') as f:
46
  data = json.load(f)
47
  if isinstance(data, list):
 
52
  content = clean_text(json.dumps(data))
53
  docs.append(Document(page_content=content, metadata={"source": file_path}))
54
  elif ext == ".txt":
 
55
  with open(file_path, 'r', encoding='utf-8') as f:
56
  content = clean_text(f.read())
57
  docs.append(Document(page_content=content, metadata={"source": file_path}))
 
59
  logger.warning(f"Unsupported file format: {file_path}")
60
  except Exception as e:
61
  logger.error(f"Error processing file {file_path}: {e}")
 
62
  return docs
63
 
 
64
  def ensure_complete_sentences(text):
 
65
  sentences = re.findall(r'[^.!?]*[.!?]', text)
66
  if sentences:
67
+ return ' '.join(s.strip() for s in sentences)
68
+ return text
 
69
 
 
70
  def is_valid_input(text):
 
 
 
 
71
  if not text or text.strip() == "":
72
  return False
 
73
  if not re.search('[A-Za-z]', text):
74
  return False
 
75
  if len(text.strip()) < 5:
76
  return False
77
  return True
78
 
 
79
  def initialize_llm(model, temperature, max_tokens):
80
+ prompt_allocation = int(max_tokens * 0.2)
81
+ response_max_tokens = max_tokens - prompt_allocation
82
+ if response_max_tokens <= 50:
83
+ raise ValueError("max_tokens too small.")
84
+ llm = ChatGroq(
85
+ model=model,
86
+ temperature=temperature,
87
+ max_tokens=response_max_tokens,
88
+ api_key=api_key
89
+ )
90
+ return llm
91
+
 
 
 
 
 
 
 
 
92
  def create_rag_pipeline(file_paths, model, temperature, max_tokens):
93
+ llm = initialize_llm(model, temperature, max_tokens)
94
+ docs = load_documents(file_paths)
95
+ if not docs:
96
+ return None, "No documents were loaded."
97
+
98
+ text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
99
+ splits = text_splitter.split_documents(docs)
100
+
101
+ embedding_model = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
102
+
103
+ vectorstore = Chroma.from_documents(
104
+ documents=splits,
105
+ embedding=embedding_model,
106
+ persist_directory="/tmp/chroma_db"
107
+ )
108
+ vectorstore.persist()
109
+
110
+ retriever = vectorstore.as_retriever()
111
+
112
+ custom_prompt_template = PromptTemplate(
113
+ input_variables=["context", "question"],
114
+ template="""
115
+ You are an AI assistant specialized in daily wellness. Provide a concise, thorough, and stand-alone answer to the user's question based on the given context. Include relevant examples or schedules where beneficial. The final answer should be coherent, self-contained, and end with a complete sentence.
116
+
117
+ Context:
118
+ {context}
119
+
120
+ Question:
121
+ {question}
122
+
123
+ Final Answer:
124
+ """
125
+ )
126
+
127
+ rag_chain = RetrievalQA.from_chain_type(
128
+ llm=llm,
129
+ chain_type="stuff",
130
+ retriever=retriever,
131
+ chain_type_kwargs={"prompt": custom_prompt_template}
132
+ )
133
+ return rag_chain, "Pipeline created successfully."
134
+
 
 
 
 
 
 
 
 
135
  file_paths = ['AIChatbot.csv']
136
  model = "llama3-8b-8192"
137
  temperature = 0.7
138
  max_tokens = 500
 
139
  rag_chain, message = create_rag_pipeline(file_paths, model, temperature, max_tokens)
 
 
140
 
 
141
  def answer_question(model, temperature, max_tokens, question):
 
142
  if not is_valid_input(question):
143
+ return "Please provide a valid, meaningful question."
 
 
144
  if rag_chain is None:
 
145
  return "The system is currently unavailable. Please try again later."
 
146
  try:
147
  answer = rag_chain.run(question)
 
 
148
  complete_answer = ensure_complete_sentences(answer)
149
  return complete_answer
150
  except Exception as e_inner:
151
+ logger.error(f"Error: {e_inner}")
152
+ return "An error occurred while processing your request."
 
153
 
 
154
  def gradio_interface(model, temperature, max_tokens, question):
155
  return answer_question(model, temperature, max_tokens, question)
156
 
 
157
  interface = gr.Interface(
158
  fn=gradio_interface,
159
  inputs=[
160
+ gr.Textbox(label="Model Name", value=model),
161
+ gr.Slider(label="Temperature", minimum=0, maximum=1, step=0.01, value=temperature),
162
+ gr.Slider(label="Max Tokens", minimum=200, maximum=2048, step=1, value=max_tokens),
163
+ gr.Textbox(label="Question", placeholder="e.g., What is box breathing and how does it help reduce anxiety?")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
164
  ],
165
  outputs="text",
166
  title="Daily Wellness AI",
167
+ description="Ask questions about daily wellness and receive a concise, complete answer.",
168
  examples=[
169
  ["llama3-8b-8192", 0.7, 500, "What is box breathing and how does it help reduce anxiety?"],
170
+ ["llama3-8b-8192", 0.6, 600, "Give me a weekly fitness schedule incorporating mindfulness exercises."]
171
  ],
172
  allow_flagging="never"
173
  )