Nymbo committed on
Commit 4fa442d · verified · 1 Parent(s): 4db9e4f

Update app.py

Files changed (1)
  1. app.py +195 -358
app.py CHANGED
@@ -6,41 +6,24 @@ import base64
6
  from PIL import Image
7
  import io
8
 
9
- # Import smolagents components
10
- from smolagents import CodeAgent, Tool
11
- from smolagents.models import InferenceClientModel as SmolInferenceClientModel # Alias to avoid conflict
12
 
13
  ACCESS_TOKEN = os.getenv("HF_TOKEN")
14
  print("Access token loaded.")
15
 
16
- # --- Smolagents Setup for Image Generation ---
17
- print("Initializing smolagents components for image generation...")
18
  try:
19
  image_generation_tool = Tool.from_space(
20
- "black-forest-labs/FLUX.1-schnell", # The Space ID of the image generation tool
21
  name="image_generator",
22
- description="Generates an image from a textual prompt. Use this tool if the user asks to 'generate an image of X', 'draw X', 'create a picture of X', or similar requests for visual content based on a description.",
23
- # Ensure the HF_TOKEN is available to gradio-client if the space is private or requires auth
24
- token=ACCESS_TOKEN if ACCESS_TOKEN and ACCESS_TOKEN.strip() != "" else None
25
  )
26
  print("Image generation tool loaded successfully.")
27
-
28
- # Initialize a model for the CodeAgent. This can be a simpler/faster model
29
- # as it's mainly for orchestrating the tool call.
30
- # Using a default InferenceClientModel from smolagents
31
- smol_agent_model = SmolInferenceClientModel(token=ACCESS_TOKEN if ACCESS_TOKEN and ACCESS_TOKEN.strip() != "" else None)
32
- print(f"Smolagent model initialized with: {smol_agent_model.model_id if hasattr(smol_agent_model, 'model_id') else 'default'}")
33
-
34
- image_agent = CodeAgent(
35
- tools=[image_generation_tool],
36
- model=smol_agent_model,
37
- verbosity_level=1 # Set to 0 for less verbose agent logging, 1 for info, 2 for debug
38
- )
39
- print("Image generation agent initialized successfully.")
40
  except Exception as e:
41
- print(f"Error initializing smolagents components: {e}")
42
- image_agent = None
43
- # --- End Smolagents Setup ---
44
 
45
  # Function to encode image to base64
46
  def encode_image(image_path):
@@ -64,7 +47,7 @@ def encode_image(image_path):
64
 
65
  # Encode to base64
66
  buffered = io.BytesIO()
67
- image.save(buffered, format="JPEG")
68
  img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
69
  print("Image encoded successfully")
70
  return img_str
@@ -73,9 +56,9 @@ def encode_image(image_path):
73
  return None
74
 
75
  def respond(
76
- message,
77
- image_files, # Changed parameter name and structure
78
- history: list[tuple[str, str]],
79
  system_message,
80
  max_tokens,
81
  temperature,
@@ -88,9 +71,9 @@ def respond(
88
  model_search_term,
89
  selected_model
90
  ):
91
- print(f"Received message: {message}")
92
- print(f"Received {len(image_files) if image_files else 0} images")
93
- print(f"History: {history}")
94
  print(f"System message: {system_message}")
95
  print(f"Max tokens: {max_tokens}, Temperature: {temperature}, Top-P: {top_p}")
96
  print(f"Frequency Penalty: {frequency_penalty}, Seed: {seed}")
@@ -100,136 +83,106 @@ def respond(
100
  print(f"Model search term: {model_search_term}")
101
  print(f"Selected model from radio: {selected_model}")
102
 
103
- # --- Agent-based Image Generation ---
104
- if message.startswith("/generate_image"):
105
- if image_agent is None:
106
- yield "Image generation agent is not initialized. Please check server logs."
107
- return
108
 
109
- prompt_for_agent = message.replace("/generate_image", "").strip()
110
- if not prompt_for_agent:
111
- yield "Please provide a prompt for image generation. Usage: /generate_image <your prompt>"
112
  return
113
 
114
- print(f"Image generation requested with prompt: {prompt_for_agent}")
115
  try:
116
- # Agent run is blocking and returns the final result
117
- # Ensure the image_agent's model also has a token if needed for its operations (though it's for orchestration)
118
- agent_response = image_agent.run(prompt_for_agent)
119
-
120
- if isinstance(agent_response, str) and agent_response.lower().startswith("error"):
121
- yield f"Agent error: {agent_response}"
122
- elif hasattr(agent_response, 'to_string'): # Check if it's an AgentImage or similar
123
- image_path = agent_response.to_string() # This is a local path to the generated image
124
- print(f"Agent returned image path: {image_path}")
125
- # Gradio's chatbot can display images if the content is a file path string
126
- # or a tuple (filepath, alt_text)
127
- yield image_path
128
- else:
129
- yield f"Agent returned an unexpected response: {str(agent_response)}"
130
  return
131
  except Exception as e:
132
- print(f"Error running image agent: {e}")
133
- yield f"Error generating image: {str(e)}"
134
  return
135
- # --- End Agent-based Image Generation ---
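A minimal sketch of the agent path being removed here, assuming a CodeAgent wired to the Space tool as set up above: run() blocks until the tool call finishes, and an AgentImage-like result exposes to_string(), which yields a local file path (the helper name is hypothetical).

def generate_via_agent(agent, prompt):
    # Blocking call; the agent orchestrates the image_generator tool.
    result = agent.run(prompt)
    if hasattr(result, "to_string"):
        return result.to_string()  # local path to the generated image
    return str(result)             # fall back to the raw agent answer
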
136
 
137
- # Determine which token to use for text generation
138
- token_to_use = custom_api_key if custom_api_key.strip() != "" else ACCESS_TOKEN
139
-
140
- if custom_api_key.strip() != "":
141
- print("USING CUSTOM API KEY: BYOK token provided by user is being used for authentication")
142
- else:
143
- print("USING DEFAULT API KEY: Environment variable HF_TOKEN is being used for authentication")
144
-
145
- # Initialize the Inference Client with the provider and appropriate token
146
  client = InferenceClient(token=token_to_use, provider=provider)
147
- print(f"Hugging Face Inference Client initialized with {provider} provider for text generation.")
148
 
149
- # Convert seed to None if -1 (meaning random)
150
  if seed == -1:
151
  seed = None
152
 
153
- # Create multimodal content if images are present
154
- if image_files and len(image_files) > 0:
155
- user_content = []
156
- if message and message.strip():
157
- user_content.append({
158
- "type": "text",
159
- "text": message
160
- })
161
- for img_path in image_files: # Assuming image_files contains paths from MultimodalTextbox
162
- if img_path is not None:
163
  try:
164
- encoded_image = encode_image(img_path) # img_path is already a path from MultimodalTextbox
165
  if encoded_image:
166
- user_content.append({
167
  "type": "image_url",
168
- "image_url": {
169
- "url": f"data:image/jpeg;base64,{encoded_image}"
170
- }
171
  })
172
  except Exception as e:
173
- print(f"Error encoding image: {e}")
174
- else:
175
- # Text-only message
176
- user_content = message
177
 
178
- # Prepare messages in the format expected by the API
179
- messages = [{"role": "system", "content": system_message}]
180
- print("Initial messages array constructed.")
181
 
182
- # Add conversation history to the context
183
- for val in history:
184
- user_part = val[0]
185
- assistant_part = val[1]
186
-
187
- # Handle user messages (could be text or image markdown)
188
- if user_part:
189
- if isinstance(user_part, str) and user_part.startswith("![Image]("):
190
- # This is an image path from a previous agent generation
191
- # or a user upload represented as markdown
192
- history_image_path = user_part.replace("![Image](", "").replace(")", "")
193
- encoded_history_image = encode_image(history_image_path)
194
- if encoded_history_image:
195
- messages.append({"role": "user", "content": [{
196
- "type": "image_url",
197
- "image_url": {"url": f"data:image/jpeg;base64,{encoded_history_image}"}
198
- }]})
199
- elif isinstance(user_part, tuple) and len(user_part) == 2: # Multimodal input from user
200
- history_content_list = []
201
- if user_part[0]: # Text part
202
- history_content_list.append({"type": "text", "text": user_part[0]})
203
- for img_hist_path in user_part[1]: # List of image paths
204
- encoded_img_hist = encode_image(img_hist_path)
205
- if encoded_img_hist:
206
- history_content_list.append({
207
- "type": "image_url",
208
- "image_url": {"url": f"data:image/jpeg;base64,{encoded_img_hist}"}
209
- })
210
- if history_content_list:
211
- messages.append({"role": "user", "content": history_content_list})
212
- else: # Regular text message
213
- messages.append({"role": "user", "content": user_part})
214
- print(f"Added user message to context (type: {type(user_part)})")
215
 
216
- if assistant_part:
217
- messages.append({"role": "assistant", "content": assistant_part})
218
- print(f"Added assistant message to context: {assistant_part}")
219
 
220
- # Append the latest user message
221
- messages.append({"role": "user", "content": user_content})
222
- print(f"Latest user message appended (content type: {type(user_content)})")
223
 
224
- # Determine which model to use, prioritizing custom_model if provided
225
  model_to_use = custom_model.strip() if custom_model.strip() != "" else selected_model
226
- print(f"Model selected for inference: {model_to_use}")
227
 
228
- # Start with an empty string to build the response as tokens stream in
229
- response = ""
230
- print(f"Sending request to {provider} provider.")
231
 
232
- # Prepare parameters for the chat completion request
233
  parameters = {
234
  "max_tokens": max_tokens,
235
  "temperature": temperature,
@@ -240,58 +193,50 @@ def respond(
240
  if seed is not None:
241
  parameters["seed"] = seed
242
 
243
- # Use the InferenceClient for making the request
244
  try:
245
- # Create a generator for the streaming response
246
  stream = client.chat_completion(
247
  model=model_to_use,
248
- messages=messages,
249
  stream=True,
250
  **parameters
251
  )
252
 
253
- print("Received tokens: ", end="", flush=True)
254
 
255
- # Process the streaming response
256
  for chunk in stream:
257
  if hasattr(chunk, 'choices') and len(chunk.choices) > 0:
258
- # Extract the content from the response
259
  if hasattr(chunk.choices[0], 'delta') and hasattr(chunk.choices[0].delta, 'content'):
260
  token_text = chunk.choices[0].delta.content
261
  if token_text:
262
  print(token_text, end="", flush=True)
263
- response += token_text
264
- yield response
265
 
266
  print()
267
  except Exception as e:
268
- print(f"Error during inference: {e}")
269
- response += f"\nError: {str(e)}"
270
- yield response
271
 
272
- print("Completed response generation.")
273
 
274
- # Function to validate provider selection based on BYOK
275
  def validate_provider(api_key, provider):
276
  if not api_key.strip() and provider != "hf-inference":
277
  return gr.update(value="hf-inference")
278
  return gr.update(value=provider)
279
 
280
- # GRADIO UI
281
  with gr.Blocks(theme="Nymbo/Nymbo_Theme") as demo:
282
- # Create the chatbot component
283
  chatbot = gr.Chatbot(
284
  height=600,
285
  show_copy_button=True,
286
- placeholder="Select a model and begin chatting. Use '/generate_image your prompt' to create images.",
287
  layout="panel",
288
- show_share_button=True # Added for ease of sharing if deployed
289
  )
290
  print("Chatbot interface created.")
291
 
292
- # Multimodal textbox for messages (combines text and file uploads)
293
  msg = gr.MultimodalTextbox(
294
- placeholder="Type a message or upload images... (e.g., /generate_image a cat wearing a hat)",
295
  show_label=False,
296
  container=False,
297
  scale=12,
@@ -300,226 +245,138 @@ with gr.Blocks(theme="Nymbo/Nymbo_Theme") as demo:
300
  sources=["upload"]
301
  )
302
 
303
- # Create accordion for settings
304
  with gr.Accordion("Settings", open=False):
305
- # System message
306
  system_message_box = gr.Textbox(
307
- value="You are a helpful AI assistant that can understand images and text. If asked to generate an image, use the image_generator tool.",
308
  placeholder="You are a helpful assistant.",
309
  label="System Prompt"
310
  )
311
 
312
- # Generation parameters
313
  with gr.Row():
314
  with gr.Column():
315
- max_tokens_slider = gr.Slider(
316
- minimum=1,
317
- maximum=4096,
318
- value=512,
319
- step=1,
320
- label="Max tokens"
321
- )
322
-
323
- temperature_slider = gr.Slider(
324
- minimum=0.1,
325
- maximum=4.0,
326
- value=0.7,
327
- step=0.1,
328
- label="Temperature"
329
- )
330
-
331
- top_p_slider = gr.Slider(
332
- minimum=0.1,
333
- maximum=1.0,
334
- value=0.95,
335
- step=0.05,
336
- label="Top-P"
337
- )
338
-
339
  with gr.Column():
340
- frequency_penalty_slider = gr.Slider(
341
- minimum=-2.0,
342
- maximum=2.0,
343
- value=0.0,
344
- step=0.1,
345
- label="Frequency Penalty"
346
- )
347
-
348
- seed_slider = gr.Slider(
349
- minimum=-1,
350
- maximum=65535,
351
- value=-1,
352
- step=1,
353
- label="Seed (-1 for random)"
354
- )
355
 
356
- # Provider selection
357
- providers_list = [
358
- "hf-inference", # Default Hugging Face Inference
359
- "cerebras", # Cerebras provider
360
- "together", # Together AI
361
- "sambanova", # SambaNova
362
- "novita", # Novita AI
363
- "cohere", # Cohere
364
- "fireworks-ai", # Fireworks AI
365
- "hyperbolic", # Hyperbolic
366
- "nebius", # Nebius
367
- ]
368
-
369
- provider_radio = gr.Radio(
370
- choices=providers_list,
371
- value="hf-inference",
372
- label="Inference Provider",
373
- )
374
-
375
- # New BYOK textbox
376
- byok_textbox = gr.Textbox(
377
- value="",
378
- label="BYOK (Bring Your Own Key)",
379
- info="Enter a custom Hugging Face API key here. When empty, only 'hf-inference' provider can be used.",
380
- placeholder="Enter your Hugging Face API token",
381
- type="password" # Hide the API key for security
382
- )
383
-
384
- # Custom model box
385
- custom_model_box = gr.Textbox(
386
- value="",
387
- label="Custom Model",
388
- info="(Optional) Provide a custom Hugging Face model path. Overrides any selected featured model.",
389
- placeholder="meta-llama/Llama-3.3-70B-Instruct"
390
- )
391
 
392
- # Model search
393
- model_search_box = gr.Textbox(
394
- label="Filter Models",
395
- placeholder="Search for a featured model...",
396
- lines=1
397
- )
398
-
399
- # Featured models list
400
- # Updated to include multimodal models
401
  models_list = [
402
- "meta-llama/Llama-3.2-11B-Vision-Instruct",
403
- "meta-llama/Llama-3.3-70B-Instruct",
404
- "meta-llama/Llama-3.1-70B-Instruct",
405
- "meta-llama/Llama-3.0-70B-Instruct",
406
- "meta-llama/Llama-3.2-3B-Instruct",
407
- "meta-llama/Llama-3.2-1B-Instruct",
408
- "meta-llama/Llama-3.1-8B-Instruct",
409
- "NousResearch/Hermes-3-Llama-3.1-8B",
410
- "NousResearch/Nous-Hermes-2-Mixtral-8x7B-DPO",
411
- "mistralai/Mistral-Nemo-Instruct-2407",
412
- "mistralai/Mixtral-8x7B-Instruct-v0.1",
413
- "mistralai/Mistral-7B-Instruct-v0.3",
414
- "mistralai/Mistral-7B-Instruct-v0.2",
415
- "Qwen/Qwen3-235B-A22B",
416
- "Qwen/Qwen3-32B",
417
- "Qwen/Qwen2.5-72B-Instruct",
418
- "Qwen/Qwen2.5-3B-Instruct",
419
- "Qwen/Qwen2.5-0.5B-Instruct",
420
- "Qwen/QwQ-32B",
421
- "Qwen/Qwen2.5-Coder-32B-Instruct",
422
- "microsoft/Phi-3.5-mini-instruct",
423
- "microsoft/Phi-3-mini-128k-instruct",
424
- "microsoft/Phi-3-mini-4k-instruct",
425
  ]
426
-
427
- featured_model_radio = gr.Radio(
428
- label="Select a model below",
429
- choices=models_list,
430
- value="meta-llama/Llama-3.2-11B-Vision-Instruct", # Default to a multimodal model
431
- interactive=True
432
- )
433
 
434
  gr.Markdown("[View all Text-to-Text models](https://huggingface.co/models?inference_provider=all&pipeline_tag=text-generation&sort=trending) | [View all multimodal models](https://huggingface.co/models?inference_provider=all&pipeline_tag=image-text-to-text&sort=trending)")
435
 
436
- # Chat history state
437
  chat_history = gr.State([])
438
 
439
- # Function to filter models
440
  def filter_models(search_term):
441
  print(f"Filtering models with search term: {search_term}")
442
  filtered = [m for m in models_list if search_term.lower() in m.lower()]
443
  print(f"Filtered models: {filtered}")
444
- return gr.update(choices=filtered if filtered else models_list, value=filtered[0] if filtered else models_list[0])
445
 
446
-
447
- # Function to set custom model from radio
448
  def set_custom_model_from_radio(selected):
449
  print(f"Featured model selected: {selected}")
450
  return selected
451
 
452
- # Function for the chat interface
453
- def user(user_message_obj, history):
454
- print(f"User message object received: {user_message_obj}")
455
-
456
- text_content = user_message_obj.get("text", "").strip()
457
- files = user_message_obj.get("files", []) # files is a list of temp file paths
458
 
459
- if not text_content and not files:
460
- print("Empty message (no text, no files), skipping history update.")
461
- return history # Or raise gr.Error("Please enter a message or upload an image.")
462
 
463
- # Represent uploaded images in history using markdown syntax for local paths
464
- # For multimodal models, the actual file path from 'files' will be used in 'respond'
465
- display_message_parts = []
466
  if text_content:
467
- display_message_parts.append(text_content)
468
 
469
- processed_files_for_history = []
470
- if files:
471
- for file_path_obj in files:
472
- # Gradio's MultimodalTextbox provides file objects with a .name attribute for the path
473
- file_path = file_path_obj.name if hasattr(file_path_obj, 'name') else str(file_path_obj)
474
- display_message_parts.append(f"![Uploaded Image]({file_path})")
475
- processed_files_for_history.append(file_path) # Store the actual path for 'respond'
476
-
477
- # For history, we store the text and a list of file paths
478
- # The 'respond' function will then re-encode these for the API
479
- history_entry_user = (text_content, processed_files_for_history)
480
- history.append([history_entry_user, None])
481
- print(f"History updated with user input: {history_entry_user}")
482
  return history
483
 
484
- # Define bot response function
485
  def bot(history, system_msg, max_tokens, temperature, top_p, freq_penalty, seed, provider, api_key, custom_model, search_term, selected_model):
486
- if not history or len(history) == 0 or history[-1][0] is None:
487
- print("No user message in history to process for bot.")
488
- yield history
489
  return
490
 
491
- user_input_tuple = history[-1][0] # This is now (text, [file_paths])
492
- text_message_from_history = user_input_tuple[0]
493
- image_files_from_history = user_input_tuple[1]
494
-
495
- print(f"Bot processing: text='{text_message_from_history}', images={image_files_from_history}")
496
 
497
- history[-1][1] = ""
498
 
499
- # Pass text and image file paths to respond function
500
  for response_chunk in respond(
501
- message=text_message_from_history,
502
- image_files=image_files_from_history,
503
- history=history[:-1], # Pass history excluding the current user turn
504
- system_message=system_msg,
505
- max_tokens=max_tokens,
506
- temperature=temperature,
507
- top_p=top_p,
508
- frequency_penalty=freq_penalty,
509
- seed=seed,
510
- provider=provider,
511
- custom_api_key=api_key,
512
- custom_model=custom_model,
513
- model_search_term=search_term,
514
- selected_model=selected_model
515
  ):
516
- history[-1][1] = response_chunk
517
  yield history
518
-
519
- # Event handlers
520
  msg.submit(
521
  user,
522
- [msg, chatbot], # msg is MultimodalTextboxOutput(text=str, files=List[FileData])
523
  [chatbot],
524
  queue=False
525
  ).then(
@@ -529,45 +386,25 @@ with gr.Blocks(theme="Nymbo/Nymbo_Theme") as demo:
529
  model_search_box, featured_model_radio],
530
  [chatbot]
531
  ).then(
532
- lambda: gr.update(value={"text": "", "files": []}), # Clear MultimodalTextbox
533
  None,
534
  [msg]
535
  )
536
 
537
- # Connect the model filter to update the radio choices
538
- model_search_box.change(
539
- fn=filter_models,
540
- inputs=model_search_box,
541
- outputs=featured_model_radio
542
- )
543
  print("Model search box change event linked.")
544
 
545
- # Connect the featured model radio to update the custom model box
546
- featured_model_radio.change(
547
- fn=set_custom_model_from_radio,
548
- inputs=featured_model_radio,
549
- outputs=custom_model_box
550
- )
551
  print("Featured model radio button change event linked.")
552
 
553
- # Connect the BYOK textbox to validate provider selection
554
- byok_textbox.change(
555
- fn=validate_provider,
556
- inputs=[byok_textbox, provider_radio],
557
- outputs=provider_radio
558
- )
559
  print("BYOK textbox change event linked.")
560
 
561
- # Also validate provider when the radio changes to ensure consistency
562
- provider_radio.change(
563
- fn=validate_provider,
564
- inputs=[byok_textbox, provider_radio],
565
- outputs=provider_radio
566
- )
567
  print("Provider radio button change event linked.")
568
 
569
  print("Gradio interface initialized.")
570
 
571
  if __name__ == "__main__":
572
  print("Launching the demo application.")
573
- demo.launch(show_api=False) # show_api=False for cleaner public interface, True for debugging
 
6
  from PIL import Image
7
  import io
8
 
9
+ # Import smolagents Tool
10
+ from smolagents import Tool
 
11
 
12
  ACCESS_TOKEN = os.getenv("HF_TOKEN")
13
  print("Access token loaded.")
14
 
15
+ # Initialize the image generation tool
16
+ # This can be defined globally as it doesn't change per request
17
  try:
18
  image_generation_tool = Tool.from_space(
19
+ "black-forest-labs/FLUX.1-schnell",
20
  name="image_generator",
21
+ description="Generates an image from a text prompt. Use it when the user asks to 'generate an image of ...' or 'draw a picture of ...'. The input should be the descriptive prompt for the image."
22
  )
23
  print("Image generation tool loaded successfully.")
24
  except Exception as e:
25
+ print(f"Error loading image generation tool: {e}")
26
+ image_generation_tool = None
 
27
 
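A hedged sketch of the new direct-call pattern, assuming smolagents' Tool.from_space as used above: the loaded tool behaves like a callable and returns a local path to the generated image (the prompt string and helper name are illustrative).

from smolagents import Tool

def load_image_tool():
    try:
        return Tool.from_space(
            "black-forest-labs/FLUX.1-schnell",
            name="image_generator",
            description="Generates an image from a text prompt.",
        )
    except Exception as err:
        print(f"Image generation tool unavailable: {err}")
        return None

tool = load_image_tool()
if tool is not None:
    image_path = tool(prompt="a cat playing chess")  # returns a local file path
    print(image_path)
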
28
  # Function to encode image to base64
29
  def encode_image(image_path):
 
47
 
48
  # Encode to base64
49
  buffered = io.BytesIO()
50
+ image.save(buffered, format="JPEG")
51
  img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
52
  print("Image encoded successfully")
53
  return img_str
 
56
  return None
57
 
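A small sketch of how the encoded string is consumed later in respond(): it is wrapped into an OpenAI-style image_url content part as a JPEG data URL (matching the save format above); the helper name is hypothetical.

def to_image_content_part(image_path):
    encoded = encode_image(image_path)
    if encoded is None:
        return None
    return {
        "type": "image_url",
        "image_url": {"url": f"data:image/jpeg;base64,{encoded}"},
    }
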
58
  def respond(
59
+ message_text, # Changed from 'message' to be explicit about text part
60
+ image_files, # This will be a list of paths from gr.MultimodalTextbox
61
+ history: list[list[Any, str | None]], # History can now contain complex user messages
62
  system_message,
63
  max_tokens,
64
  temperature,
 
71
  model_search_term,
72
  selected_model
73
  ):
74
+ print(f"Received message text: {message_text}")
75
+ print(f"Received {len(image_files) if image_files else 0} image files: {image_files}")
76
+ # print(f"History: {history}") # Can be very verbose
77
  print(f"System message: {system_message}")
78
  print(f"Max tokens: {max_tokens}, Temperature: {temperature}, Top-P: {top_p}")
79
  print(f"Frequency Penalty: {frequency_penalty}, Seed: {seed}")
 
83
  print(f"Model search term: {model_search_term}")
84
  print(f"Selected model from radio: {selected_model}")
85
 
86
+ # Determine which token to use
87
+ token_to_use = custom_api_key if custom_api_key.strip() != "" else ACCESS_TOKEN
88
+
89
+ if custom_api_key.strip() != "":
90
+ print("USING CUSTOM API KEY: BYOK token provided by user is being used for authentication")
91
+ else:
92
+ print("USING DEFAULT API KEY: Environment variable HF_TOKEN is being used for authentication")
93
+
94
+ user_text_message_lower = message_text.lower() if message_text else ""
95
 
96
+ image_keywords = ["generate image", "draw a picture of", "create an image of", "make an image of"]
97
+ is_image_generation_request = any(keyword in user_text_message_lower for keyword in image_keywords)
98
+
99
+ if is_image_generation_request and image_generation_tool:
100
+ print("Image generation request detected.")
101
+ image_prompt = message_text
102
+ for keyword in image_keywords:
103
+ if keyword in user_text_message_lower:
104
+ # Find the keyword in the original case-sensitive message text to split
105
+ keyword_start_index = user_text_message_lower.find(keyword)
106
+ image_prompt = message_text[keyword_start_index + len(keyword):].strip()
107
+ break
108
+
109
+ print(f"Extracted image prompt: {image_prompt}")
110
+ if not image_prompt:
111
+ yield {"type": "text", "content": "Please provide a description for the image you want to generate."}
112
  return
113
 
 
114
  try:
115
+ generated_image_path = image_generation_tool(prompt=image_prompt)
116
+ print(f"Image generated by tool, path: {generated_image_path}")
117
+ yield {"type": "image", "path": str(generated_image_path)} # Ensure path is string
 
118
  return
119
  except Exception as e:
120
+ print(f"Error during image generation tool call: {e}")
121
+ yield {"type": "text", "content": f"Sorry, I couldn't generate the image. Error: {str(e)}"}
122
  return
123
+ elif is_image_generation_request and not image_generation_tool:
124
+ yield {"type": "text", "content": "Image generation tool is not available or failed to load."}
125
+ return
126
 
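The keyword matching above can be read as a standalone helper; this equivalent, hedged sketch (hypothetical function name) shows the extraction on a sample message.

def extract_image_prompt(message_text):
    keywords = ["generate image", "draw a picture of", "create an image of", "make an image of"]
    lowered = message_text.lower()
    for keyword in keywords:
        idx = lowered.find(keyword)
        if idx != -1:
            # Slice the original (case-preserving) text after the matched keyword.
            return message_text[idx + len(keyword):].strip()
    return None

print(extract_image_prompt("Please draw a picture of a red fox"))  # -> "a red fox"
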
127
+ # If not an image generation request, proceed with text/multimodal LLM call
128
+ print("Proceeding with LLM call (text or multimodal).")
 
129
  client = InferenceClient(token=token_to_use, provider=provider)
130
+ print(f"Hugging Face Inference Client initialized with {provider} provider.")
131
 
 
132
  if seed == -1:
133
  seed = None
134
 
135
+ # Prepare messages for LLM
136
+ llm_user_content = []
137
+ if message_text and message_text.strip():
138
+ llm_user_content.append({"type": "text", "text": message_text})
139
+
140
+ if image_files: # image_files is a list of paths from gr.MultimodalTextbox
141
+ for img_path in image_files:
142
+ if img_path:
143
  try:
144
+ encoded_image = encode_image(img_path) # img_path is already a path
145
  if encoded_image:
146
+ llm_user_content.append({
147
  "type": "image_url",
148
+ "image_url": {"url": f"data:image/jpeg;base64,{encoded_image}"}
149
  })
150
  except Exception as e:
151
+ print(f"Error encoding image for LLM: {e}")
152
+
153
+ if not llm_user_content: # Should not happen if user() function filters empty messages
154
+ print("No content for LLM, aborting.")
155
+ yield {"type": "text", "content": "Please provide some input."}
156
+ return
157
 
158
+ messages_for_llm = [{"role": "system", "content": system_message}]
159
+ print("Initial messages array constructed for LLM.")
 
160
 
161
+ for val in history: # history item is [user_content_list, assistant_response_str_or_dict]
162
+ user_content_list_hist = val[0]
163
+ assistant_response_hist = val[1]
164
+
165
+ if user_content_list_hist:
166
+ # user_content_list_hist is already in the correct format (list of dicts)
167
+ messages_for_llm.append({"role": "user", "content": user_content_list_hist})
168
 
169
+ if assistant_response_hist:
170
+ # Assistant response could be text or an image dict from a previous tool call
171
+ if isinstance(assistant_response_hist, dict) and assistant_response_hist.get("type") == "image":
172
+ messages_for_llm.append({"role": "assistant", "content": [{"type": "text", "text": f"Assistant previously displayed image: {assistant_response_hist.get('path')}"}]})
173
+ elif isinstance(assistant_response_hist, str):
174
+ messages_for_llm.append({"role": "assistant", "content": assistant_response_hist})
175
+ # Else, if it's a dict but not an image type we understand for history, we might skip or log an error
176
 
177
+ messages_for_llm.append({"role": "user", "content": llm_user_content})
178
+ # print(f"Full messages_for_llm: {messages_for_llm}") # Can be very verbose
 
179
 
 
180
  model_to_use = custom_model.strip() if custom_model.strip() != "" else selected_model
181
+ print(f"Model selected for LLM inference: {model_to_use}")
182
 
183
+ response_text = ""
184
+ print(f"Sending request to {provider} provider for LLM.")
 
185
 
 
186
  parameters = {
187
  "max_tokens": max_tokens,
188
  "temperature": temperature,
 
193
  if seed is not None:
194
  parameters["seed"] = seed
195
 
 
196
  try:
 
197
  stream = client.chat_completion(
198
  model=model_to_use,
199
+ messages=messages_for_llm,
200
  stream=True,
201
  **parameters
202
  )
203
 
204
+ print("Received LLM tokens: ", end="", flush=True)
205
 
 
206
  for chunk in stream:
207
  if hasattr(chunk, 'choices') and len(chunk.choices) > 0:
 
208
  if hasattr(chunk.choices[0], 'delta') and hasattr(chunk.choices[0].delta, 'content'):
209
  token_text = chunk.choices[0].delta.content
210
  if token_text:
211
  print(token_text, end="", flush=True)
212
+ response_text += token_text
213
+ yield {"type": "text", "content": response_text}
214
 
215
  print()
216
  except Exception as e:
217
+ print(f"Error during LLM inference: {e}")
218
+ response_text += f"\nError: {str(e)}"
219
+ yield {"type": "text", "content": response_text}
220
 
221
+ print("Completed LLM response generation.")
222
 
 
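A hedged recap of the protocol this generator now yields: each chunk is a dict whose "type" is either "text" (carrying the cumulative response so far) or "image" (carrying a local file path). A consumer such as bot() below can branch on the type; the helper here is only illustrative.

def render_chunk(chunk):
    if chunk["type"] == "text":
        return chunk["content"]  # cumulative text streamed so far
    if chunk["type"] == "image":
        return {"path": chunk["path"], "mime_type": "image/jpeg"}
    raise ValueError(f"Unexpected chunk type: {chunk['type']}")
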
223
  def validate_provider(api_key, provider):
224
  if not api_key.strip() and provider != "hf-inference":
225
  return gr.update(value="hf-inference")
226
  return gr.update(value=provider)
227
 
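Two illustrative calls against the guard above (the token value is a placeholder; both return gr.update payloads):

no_key_update = validate_provider("", "cohere")          # forces the radio back to "hf-inference"
with_key_update = validate_provider("hf_xxx", "cohere")  # keeps "cohere"
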
 
228
  with gr.Blocks(theme="Nymbo/Nymbo_Theme") as demo:
 
229
  chatbot = gr.Chatbot(
230
  height=600,
231
  show_copy_button=True,
232
+ placeholder="Select a model and begin chatting. Now supports multiple inference providers and multimodal inputs. Try 'generate image of a cat playing chess'.",
233
  layout="panel",
234
+ bubble_full_width=False
235
  )
236
  print("Chatbot interface created.")
237
 
 
238
  msg = gr.MultimodalTextbox(
239
+ placeholder="Type a message or upload images...",
240
  show_label=False,
241
  container=False,
242
  scale=12,
 
245
  sources=["upload"]
246
  )
247
 
 
248
  with gr.Accordion("Settings", open=False):
 
249
  system_message_box = gr.Textbox(
250
+ value="You are a helpful AI assistant that can understand images and text. If asked to generate an image, respond by saying you will call the image_generator tool.",
251
  placeholder="You are a helpful assistant.",
252
  label="System Prompt"
253
  )
254
 
 
255
  with gr.Row():
256
  with gr.Column():
257
+ max_tokens_slider = gr.Slider(minimum=1, maximum=4096, value=512, step=1, label="Max tokens")
258
+ temperature_slider = gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature")
259
+ top_p_slider = gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-P")
 
260
  with gr.Column():
261
+ frequency_penalty_slider = gr.Slider(minimum=-2.0, maximum=2.0, value=0.0, step=0.1, label="Frequency Penalty")
262
+ seed_slider = gr.Slider(minimum=-1, maximum=65535, value=-1, step=1, label="Seed (-1 for random)")
263
 
264
+ providers_list = ["hf-inference", "cerebras", "together", "sambanova", "novita", "cohere", "fireworks-ai", "hyperbolic", "nebius"]
265
+ provider_radio = gr.Radio(choices=providers_list, value="hf-inference", label="Inference Provider")
266
+ byok_textbox = gr.Textbox(value="", label="BYOK (Bring Your Own Key)", info="Enter a custom Hugging Face API key here. When empty, only 'hf-inference' provider can be used.", placeholder="Enter your Hugging Face API token", type="password")
267
+ custom_model_box = gr.Textbox(value="", label="Custom Model", info="(Optional) Provide a custom Hugging Face model path. Overrides any selected featured model.", placeholder="meta-llama/Llama-3.3-70B-Instruct")
268
+ model_search_box = gr.Textbox(label="Filter Models", placeholder="Search for a featured model...", lines=1)
269
 
270
  models_list = [
271
+ "meta-llama/Llama-3.2-11B-Vision-Instruct", "meta-llama/Llama-3.3-70B-Instruct", "meta-llama/Llama-3.1-70B-Instruct",
272
+ "meta-llama/Llama-3.0-70B-Instruct", "meta-llama/Llama-3.2-3B-Instruct", "meta-llama/Llama-3.2-1B-Instruct",
273
+ "meta-llama/Llama-3.1-8B-Instruct", "NousResearch/Hermes-3-Llama-3.1-8B", "NousResearch/Nous-Hermes-2-Mixtral-8x7B-DPO",
274
+ "mistralai/Mistral-Nemo-Instruct-2407", "mistralai/Mixtral-8x7B-Instruct-v0.1", "mistralai/Mistral-7B-Instruct-v0.3",
275
+ "mistralai/Mistral-7B-Instruct-v0.2", "Qwen/Qwen3-235B-A22B", "Qwen/Qwen3-32B", "Qwen/Qwen2.5-72B-Instruct",
276
+ "Qwen/Qwen2.5-3B-Instruct", "Qwen/Qwen2.5-0.5B-Instruct", "Qwen/QwQ-32B", "Qwen/Qwen2.5-Coder-32B-Instruct",
277
+ "microsoft/Phi-3.5-mini-instruct", "microsoft/Phi-3-mini-128k-instruct", "microsoft/Phi-3-mini-4k-instruct",
 
278
  ]
279
+ featured_model_radio = gr.Radio(label="Select a model below", choices=models_list, value="meta-llama/Llama-3.2-11B-Vision-Instruct", interactive=True)
280
 
281
  gr.Markdown("[View all Text-to-Text models](https://huggingface.co/models?inference_provider=all&pipeline_tag=text-generation&sort=trending) | [View all multimodal models](https://huggingface.co/models?inference_provider=all&pipeline_tag=image-text-to-text&sort=trending)")
282
 
 
283
  chat_history = gr.State([])
284
 
 
285
  def filter_models(search_term):
286
  print(f"Filtering models with search term: {search_term}")
287
  filtered = [m for m in models_list if search_term.lower() in m.lower()]
288
  print(f"Filtered models: {filtered}")
289
+ return gr.update(choices=filtered)
290
 
291
  def set_custom_model_from_radio(selected):
292
  print(f"Featured model selected: {selected}")
293
  return selected
294
 
295
+ def user(user_multimodal_input, history):
296
+ print(f"User input (raw from gr.MultimodalTextbox): {user_multimodal_input}")
297
 
298
+ text_content = user_multimodal_input.get("text", "").strip()
299
+ files = user_multimodal_input.get("files", []) # These are temp file paths from Gradio
 
300
 
301
+ if not text_content and not files:
302
+ print("Empty input, skipping history append.")
303
+ # Optionally, could raise gr.Error("Please enter a message or upload an image.")
304
+ # For now, let's allow the bot to respond if history is not empty,
305
+ # or do nothing if history is also empty.
306
+ return history
307
+
308
+ # Prepare content for history: a list of dicts for multimodal display
309
+ history_user_entry_content = []
310
  if text_content:
311
+ history_user_entry_content.append({"type": "text", "text": text_content})
312
+
313
+ for file_path_obj in files: # file_path_obj is a FileData object from Gradio
314
+ if file_path_obj and hasattr(file_path_obj, 'name') and file_path_obj.name:
315
+ # Gradio's Chatbot can display images directly from file paths
316
+ # We store it in a format that `respond` can also understand
317
+ # The path is temporary, Gradio handles making it accessible for display
318
+ history_user_entry_content.append({"type": "image_url", "image_url": {"url": file_path_obj.name}})
319
+ print(f"Adding image to history entry: {file_path_obj.name}")
320
+
321
+ if history_user_entry_content:
322
+ history.append([history_user_entry_content, None]) # User part, Bot part (initially None)
323
 
324
  return history
325
 
 
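A hedged example of a single history entry produced by user(): the user turn is a list of content dicts (text plus zero or more image_url parts) and the assistant slot starts as None; the file path shown is a placeholder for a Gradio temp file.

example_history_entry = [
    [
        {"type": "text", "text": "What is in this photo?"},
        {"type": "image_url", "image_url": {"url": "/tmp/gradio/abc123/photo.jpg"}},
    ],
    None,
]
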
326
  def bot(history, system_msg, max_tokens, temperature, top_p, freq_penalty, seed, provider, api_key, custom_model, search_term, selected_model):
327
+ if not history or not history[-1][0]: # If no user message or empty user message content
328
+ print("No user message to process in bot function or user message content is empty.")
329
+ yield history # Return current history without processing
330
  return
331
 
332
+ user_content_list = history[-1][0] # This is now a list of content dicts
 
333
 
334
+ # Extract text and image file paths from the user_content_list for the `respond` function
335
+ text_for_respond = ""
336
+ image_files_for_respond = []
337
+
338
+ for item in user_content_list:
339
+ if item["type"] == "text":
340
+ text_for_respond = item["text"]
341
+ elif item["type"] == "image_url":
342
+ image_files_for_respond.append(item["image_url"]["url"])
343
+
344
+ history[-1][1] = "" # Clear placeholder for bot response / Initialize bot response
345
 
346
+ # Call the respond function which is now a generator
347
  for response_chunk in respond(
348
+ text_for_respond,
349
+ image_files_for_respond,
350
+ history[:-1], # Pass previous history
351
+ system_msg, max_tokens, temperature, top_p, freq_penalty, seed,
352
+ provider, api_key, custom_model, search_term, selected_model
 
353
  ):
354
+ current_bot_response = history[-1][1]
355
+ if isinstance(response_chunk, dict):
356
+ if response_chunk["type"] == "text":
357
+ # If current bot response is already an image dict, we can't append text.
358
+ # This indicates a new text response after an image, or just text.
359
+ if isinstance(current_bot_response, dict) and current_bot_response.get("type") == "image":
360
+ # This case should ideally not happen if an image is the final response from a tool.
361
+ # If it does, we might need to start a new bot message in history.
362
+ # For now, we'll overwrite if the new chunk is text.
363
+ history[-1][1] = response_chunk["content"]
364
+ elif isinstance(current_bot_response, str):
365
+ history[-1][1] = response_chunk["content"] # Accumulate text
366
+ else: # current_bot_response is likely "" or None
367
+ history[-1][1] = response_chunk["content"]
368
+
369
+ elif response_chunk["type"] == "image":
370
+ # Image response from tool. Gradio Chatbot displays this as an image.
371
+ # The path should be accessible by Gradio.
372
+ # If there was prior text content for this turn, it's now overwritten by the image.
373
+ # This means a tool call that produces an image is considered the primary response for that turn.
374
+ history[-1][1] = {"path": response_chunk["path"], "mime_type": "image/jpeg"} # Assuming JPEG, could be PNG
375
  yield history
376
+
 
377
  msg.submit(
378
  user,
379
+ [msg, chatbot],
380
  [chatbot],
381
  queue=False
382
  ).then(
 
386
  model_search_box, featured_model_radio],
387
  [chatbot]
388
  ).then(
389
+ lambda: {"text": "", "files": []}, # Clear MultimodalTextbox
390
  None,
391
  [msg]
392
  )
393
 
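The submit chain above can be read as three steps; this comment-style summary is a hedged recap (the full input list for the bot step is abridged in this hunk):

# 1. user()  -> appends the multimodal user turn ([content dicts, None]) to the chatbot history
# 2. bot()   -> streams the assistant reply (text chunks or an image path) into history[-1][1]
# 3. lambda  -> resets the MultimodalTextbox to {"text": "", "files": []}
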
394
+ model_search_box.change(fn=filter_models, inputs=model_search_box, outputs=featured_model_radio)
395
  print("Model search box change event linked.")
396
 
397
+ featured_model_radio.change(fn=set_custom_model_from_radio, inputs=featured_model_radio, outputs=custom_model_box)
398
  print("Featured model radio button change event linked.")
399
 
400
+ byok_textbox.change(fn=validate_provider, inputs=[byok_textbox, provider_radio], outputs=provider_radio)
401
  print("BYOK textbox change event linked.")
402
 
403
+ provider_radio.change(fn=validate_provider, inputs=[byok_textbox, provider_radio], outputs=provider_radio)
404
  print("Provider radio button change event linked.")
405
 
406
  print("Gradio interface initialized.")
407
 
408
  if __name__ == "__main__":
409
  print("Launching the demo application.")
410
+ demo.launch(show_api=True)