Shreyas094 committed (verified)
Commit 480bd35 · Parent(s): 652197b

Update app.py

Files changed (1): app.py (+64, -88)
app.py CHANGED
@@ -1,6 +1,7 @@
 import os
 import logging
 import asyncio
+from typing import AsyncGenerator, Tuple
 import gradio as gr
 from huggingface_hub import InferenceClient
 from langchain.embeddings import HuggingFaceEmbeddings
@@ -55,7 +56,16 @@ def create_web_search_vectors(search_results):
     logging.info(f"Created vectors for {len(documents)} search results.")
     return FAISS.from_documents(documents, embed)
 
-async def get_response_with_search(query, system_prompt, model, use_embeddings, history=None, num_calls=3, temperature=0.2):
+def create_context(search_results, use_embeddings, query):
+    if use_embeddings:
+        web_search_database = create_web_search_vectors(search_results)
+        retriever = web_search_database.as_retriever(search_kwargs={"k": 5})
+        relevant_docs = retriever.get_relevant_documents(query)
+        return "\n".join([doc.page_content for doc in relevant_docs])
+    else:
+        return "\n".join([f"{result['title']}\n{result['body']}" for result in search_results])
+
+async def get_response_with_search(query: str, system_prompt: str, model: str, use_embeddings: bool, history=None, num_calls: int = 3, temperature: float = 0.2) -> AsyncGenerator[Tuple[str, str], None]:
     search_results = duckduckgo_search(query)
 
     if not search_results:
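
Review note: the context-building branch deleted from get_response_with_search (next hunk) now lives in the create_context helper added above. A minimal sketch of exercising the helper on its own, assuming DuckDuckGo-style result dicts with title/body/href keys; the fake results are illustrative, not from the commit:

# Hypothetical smoke test for create_context (not part of the commit).
fake_results = [
    {"title": "FAISS docs", "body": "FAISS is a vector similarity search library.",
     "href": "https://faiss.ai"},
    {"title": "LangChain docs", "body": "LangChain wraps FAISS as a vector store.",
     "href": "https://python.langchain.com"},
]

# With use_embeddings=False the helper skips the FAISS round trip and simply
# joins title/body pairs, so it runs without an embedding model loaded.
context = create_context(fake_results, use_embeddings=False, query="what is faiss?")
print(context)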
@@ -66,14 +76,7 @@ async def get_response_with_search(query, system_prompt, model, use_embeddings,
     sources = [result['href'] for result in search_results if 'href' in result]
     source_list_str = "\n".join(sources)
 
-    if use_embeddings:
-        web_search_database = create_web_search_vectors(search_results)
-        retriever = web_search_database.as_retriever(search_kwargs={"k": 5})
-        relevant_docs = retriever.get_relevant_documents(query)
-        context = "\n".join([doc.page_content for doc in relevant_docs])
-    else:
-        context = "\n".join([f"{result['title']}\n{result['body']}" for result in search_results])
-
+    context = create_context(search_results, use_embeddings, query)
     logging.info(f"Context created for query: {query}")
 
     user_message = f"""Using the following context from web search results:
@@ -81,9 +84,6 @@ async def get_response_with_search(query, system_prompt, model, use_embeddings,
 
 Write a detailed and complete research document that fulfills the following user request: '{query}'."""
 
-    client = InferenceClient(model, token=huggingface_token)
-    full_response = ""
-
     messages = [
         {"role": "system", "content": system_prompt},
         {"role": "user", "content": user_message}
@@ -92,50 +92,38 @@ Write a detailed and complete research document that fulfills the following user
     if history:
         messages = history + messages
 
-    try:
-        for call in range(num_calls):
-            try:
-                response_stream = client.chat_completion(
-                    messages=messages,
-                    max_tokens=6000,
-                    temperature=temperature,
-                    stream=True,
-                    top_p=0.8,
-                )
-
-                if response_stream is None:
-                    logging.error(f"API call {call + 1} returned None")
-                    yield "The API returned an empty response. Please try again.", ""
-                    continue
-
-                for response in response_stream:
-                    if isinstance(response, dict) and "choices" in response:
-                        for choice in response["choices"]:
-                            if "delta" in choice and "content" in choice["delta"]:
-                                chunk = choice["delta"]["content"]
-                                full_response += chunk
-                                yield full_response, ""
-                    else:
-                        logging.error(f"Unexpected response format in API call {call + 1}: {response}")
-
-                if full_response:
-                    break  # If we got a valid response, exit the loop
-
-            except Exception as e:
-                logging.error(f"Error in API call {call + 1}: {str(e)}")
-                if "422 Client Error" in str(e):
-                    logging.warning("Received 422 Client Error. Adjusting request parameters.")
-                    # You might want to adjust parameters here, e.g., reduce max_tokens
-                yield f"An error occurred during API call {call + 1}. Retrying...", ""
-
-            await asyncio.sleep(1)  # 1 second delay between calls
+    client = InferenceClient(model, token=huggingface_token)
+    full_response = ""
 
-    except asyncio.CancelledError:
-        logging.warning("The operation was cancelled.")
-        yield "The operation was cancelled. Please try again.", ""
-    except Exception as e:
-        logging.error(f"Unexpected error in get_response_with_search: {str(e)}")
-        yield f"An unexpected error occurred: {str(e)}", ""
+    for call in range(num_calls):
+        try:
+            response = await asyncio.to_thread(
+                client.chat_completion,
+                messages=messages,
+                max_tokens=6000,
+                temperature=temperature,
+                top_p=0.8,
+            )
+
+            if response is None or not isinstance(response, dict) or 'choices' not in response:
+                logging.error(f"API call {call + 1} returned an invalid response: {response}")
+                if call == num_calls - 1:
+                    yield "The API returned an invalid response. Please try again later.", ""
+                continue
+
+            new_content = response['choices'][0]['message']['content']
+            full_response += new_content
+            yield full_response, ""
+
+            if full_response:
+                break  # If we got a valid response, exit the loop
+
+        except Exception as e:
+            logging.error(f"Error in API call {call + 1}: {str(e)}")
+            if call == num_calls - 1:
+                yield f"An error occurred during API calls: {str(e)}. Please try again later.", ""
+
+        await asyncio.sleep(1)  # 1 second delay between calls
 
     if not full_response:
         logging.warning("No response generated from the model")
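
Review note: this hunk swaps the streamed chat_completion(stream=True) loop for one blocking call per attempt, moved off the event loop with asyncio.to_thread so the async generator is not blocked while the request is in flight. A self-contained sketch of that pattern; blocking_call is a hypothetical stand-in for client.chat_completion, while asyncio.to_thread itself is standard library (Python 3.9+):

import asyncio
import time

def blocking_call(x):
    # Stand-in for a synchronous, network-bound call such as
    # client.chat_completion(...); the sleep simulates latency.
    time.sleep(0.1)
    return x * 2

async def main():
    # to_thread runs the sync callable in a worker thread, keeping the
    # event loop free to serve other tasks until the result is ready.
    result = await asyncio.to_thread(blocking_call, 21)
    print(result)  # 42

asyncio.run(main())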
@@ -143,18 +131,12 @@ Write a detailed and complete research document that fulfills the following user
     else:
         yield f"{full_response}\n\nSources:\n{source_list_str}", ""
 
-async def respond(message, system_prompt, history, model, temperature, num_calls, use_embeddings):
-    logging.info(f"User Query: {message}")
-    logging.info(f"Model Used: {model}")
-    logging.info(f"Temperature: {temperature}")
-    logging.info(f"Number of API Calls: {num_calls}")
-    logging.info(f"Use Embeddings: {use_embeddings}")
-    logging.info(f"System Prompt: {system_prompt}")
-    logging.info(f"History: {history}")  # Log the history for debugging
-
-    # Convert gradio history to the format expected by get_response_with_search
+def process_history(history):
     chat_history = []
-    if history:
+    if isinstance(history, str):
+        # If history is a string (like the system prompt), add it as a system message
+        chat_history.append({"role": "system", "content": history})
+    elif isinstance(history, list):
         for entry in history:
             if isinstance(entry, (list, tuple)) and len(entry) == 2:
                 human, assistant = entry
@@ -164,10 +146,20 @@ async def respond(message, system_prompt, history, model, temperature, num_calls
             elif isinstance(entry, str):
                 # If it's a string, assume it's a user message
                 chat_history.append({"role": "user", "content": entry})
-            # Ignore any other formats
+    return chat_history
+
+async def respond(message, system_prompt, history, model, temperature, num_calls, use_embeddings):
+    logging.info(f"User Query: {message}")
+    logging.info(f"Model Used: {model}")
+    logging.info(f"Temperature: {temperature}")
+    logging.info(f"Number of API Calls: {num_calls}")
+    logging.info(f"Use Embeddings: {use_embeddings}")
+    logging.info(f"System Prompt: {system_prompt}")
+    logging.info(f"History: {history}")
+
+    chat_history = process_history(history)
 
     try:
-        full_response = ""
         async for main_content, sources in get_response_with_search(
             message,
             system_prompt,
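
Review note: history normalization moves into the new process_history helper (defined across this and the previous hunk). An illustrative call with made-up values; the output for tuple entries depends on unchanged lines elided between the hunks, which the surrounding context suggests append the pair as user/assistant messages:

# Hypothetical input for process_history (not from the commit).
gradio_history = [
    ("What is FAISS?", "FAISS is a library for vector similarity search."),
    "And LangChain?",  # bare strings are treated as user messages
]

messages = process_history(gradio_history)
# Presumably: [{"role": "user", "content": "What is FAISS?"},
#              {"role": "assistant", "content": "FAISS is a library ..."},
#              {"role": "user", "content": "And LangChain?"}]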
@@ -177,16 +169,8 @@ async def respond(message, system_prompt, history, model, temperature, num_calls
             num_calls=num_calls,
             temperature=temperature
         ):
-            if "error" in main_content.lower() or "no response" in main_content.lower():
-                # If it's an error message, yield it as is
-                yield main_content
-            else:
-                # Otherwise, yield only the new content
-                new_content = main_content[len(full_response):]
-                full_response = main_content
-                yield new_content
-
-        # Yield the sources as a separate message
+            yield main_content
+
         if sources:
             yield f"\n\nSources:\n{sources}"
 
@@ -213,16 +197,8 @@ css = """
 def create_gradio_interface():
     custom_placeholder = "Enter your question here for web search."
 
-    async def wrapped_respond(*args):
-        try:
-            async for response in respond(*args):
-                yield response
-        except Exception as e:
-            logging.error(f"Error in wrapped_respond: {str(e)}")
-            yield f"An error occurred: {str(e)}"
-
     demo = gr.ChatInterface(
-        fn=wrapped_respond,  # Use the wrapped version
+        fn=respond,
         additional_inputs_accordion=gr.Accordion(label="⚙️ Parameters", open=True, render=False),
         additional_inputs=[
             gr.Textbox(value=DEFAULT_SYSTEM_PROMPT, lines=6, label="System Prompt", placeholder="Enter your system prompt here"),
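
Review note: with wrapped_respond removed, gr.ChatInterface consumes the async generator respond directly; Gradio supports async generator chat functions and streams each yielded value. A minimal standalone sketch of that mechanism, independent of this app's additional inputs:

import gradio as gr

async def echo_stream(message, history):
    # ChatInterface re-renders the reply with each yield from an async
    # generator fn; respond() relies on the same behavior.
    partial = ""
    for ch in message:
        partial += ch
        yield partial

demo = gr.ChatInterface(fn=echo_stream)
# demo.launch()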
@@ -268,4 +244,4 @@ def create_gradio_interface():
 
 if __name__ == "__main__":
     demo = create_gradio_interface()
-    demo.launch(share=True)
+    demo.launch(share=True)
 