zamalali committed on
Commit
553b5fc
·
1 Parent(s): eb3e47a

Refactor conversation function to improve image handling and add descriptive comments

Browse files
Files changed (1) hide show
  1. app.py +29 -10
app.py CHANGED
@@ -204,6 +204,7 @@ def conversation(
204
  hf_token,
205
  model_path,
206
  ):
 
207
  if hf_token.strip() != "" and model_path.strip() != "":
208
  llm = HuggingFaceEndpoint(
209
  repo_id=model_path,
@@ -219,6 +220,7 @@ def conversation(
219
  huggingfacehub_api_token=os.getenv("P_HF_TOKEN", "None"),
220
  )
221
 
 
222
  text_collection = vectordb_client.get_collection(
223
  "text_db", embedding_function=sentence_transformer_ef
224
  )
@@ -226,23 +228,36 @@ def conversation(
226
  "image_db", embedding_function=sentence_transformer_ef
227
  )
228
 
 
229
  results = text_collection.query(
230
  query_texts=[msg], include=["documents"], n_results=num_context
231
  )["documents"][0]
 
 
232
  similar_images = image_collection.query(
233
  query_texts=[msg],
234
  include=["metadatas", "distances", "documents"],
235
  n_results=img_context,
236
  )
237
- img_links = [i["image"] for i in similar_images["metadatas"][0]]
238
-
239
- images_and_locs = [
240
- Image.open(io.BytesIO(base64.b64decode(i[1])))
241
- for i in zip(similar_images["distances"][0], img_links)
242
- ]
243
- img_desc = "\n".join(similar_images["documents"][0])
244
- if len(img_links) == 0:
245
- img_desc = "No Images Are Provided"
 
 
 
 
 
 
 
 
 
 
246
  template = """
247
  Context:
248
  {context}
@@ -258,11 +273,15 @@ def conversation(
258
  """
259
  prompt = PromptTemplate(template=template, input_variables=["context", "question"])
260
  context = "\n\n".join(results)
261
- # references = [gr.Textbox(i, visible=True, interactive=False) for i in results]
 
262
  response = llm(prompt.format(context=context, question=msg, images=img_desc))
 
 
263
  return history + [(msg, response)], results, images_and_locs
264
 
265
 
 
266
  def check_validity_and_llm(session_states):
267
  if session_states.get("processed", False) == True:
268
  return gr.Tabs(selected=2)
 
204
  hf_token,
205
  model_path,
206
  ):
207
+ # Initialize the LLM based on inputs
208
  if hf_token.strip() != "" and model_path.strip() != "":
209
  llm = HuggingFaceEndpoint(
210
  repo_id=model_path,
 
220
  huggingfacehub_api_token=os.getenv("P_HF_TOKEN", "None"),
221
  )
222
 
223
+ # Retrieve collections from vector database
224
  text_collection = vectordb_client.get_collection(
225
  "text_db", embedding_function=sentence_transformer_ef
226
  )
 
228
  "image_db", embedding_function=sentence_transformer_ef
229
  )
230
 
231
+ # Query text context
232
  results = text_collection.query(
233
  query_texts=[msg], include=["documents"], n_results=num_context
234
  )["documents"][0]
235
+
236
+ # Query image context
237
  similar_images = image_collection.query(
238
  query_texts=[msg],
239
  include=["metadatas", "distances", "documents"],
240
  n_results=img_context,
241
  )
242
+
243
+ # Initialize image links and descriptions
244
+ img_links = similar_images["metadatas"][0] if similar_images["metadatas"] else []
245
+ images_and_locs = []
246
+
247
+ for distance, link in zip(similar_images["distances"][0], img_links):
248
+ try:
249
+ img = Image.open(io.BytesIO(base64.b64decode(link["image"])))
250
+ caption = f"Distance: {distance:.2f}"
251
+ images_and_locs.append((img, caption))
252
+ except Exception as e:
253
+ print(f"Error decoding image: {e}")
254
+
255
+ # Handle case where no images are found
256
+ if not images_and_locs:
257
+ images_and_locs = [("path/to/placeholder/image.jpg", "No images found")]
258
+
259
+ # Prepare prompt for the LLM
260
+ img_desc = "\n".join(similar_images["documents"][0]) if images_and_locs else "No Images Are Provided"
261
  template = """
262
  Context:
263
  {context}
 
273
  """
274
  prompt = PromptTemplate(template=template, input_variables=["context", "question"])
275
  context = "\n\n".join(results)
276
+
277
+ # Generate response
278
  response = llm(prompt.format(context=context, question=msg, images=img_desc))
279
+
280
+ # Return updated history, text results, and image locations
281
  return history + [(msg, response)], results, images_and_locs
282
 
283
 
284
+
285
  def check_validity_and_llm(session_states):
286
  if session_states.get("processed", False) == True:
287
  return gr.Tabs(selected=2)