Nymbo committed
Commit 7ab8722 · verified · 1 Parent(s): 32ca026

Update app.py

Files changed (1): app.py +456 -292
app.py CHANGED
@@ -1,180 +1,215 @@
  import gradio as gr
- from huggingface_hub import InferenceClient  # Keep for direct use if needed, though agent will use its own model
  import os
  import json
  import base64
  from PIL import Image
  import io

- # Smolagents imports
- from smolagents import CodeAgent, Tool
- from smolagents.models import InferenceClientModel as SmolInferenceClientModel
- # We'll use PIL.Image directly for opening, AgentImage is for agent's internal typing if needed by a tool
- from smolagents.gradio_ui import pull_messages_from_step  # For formatting agent steps
- from smolagents.memory import ActionStep, FinalAnswerStep, PlanningStep, MemoryStep  # For type checking steps
- from smolagents.models import ChatMessageStreamDelta  # For type checking stream deltas
-
-
  ACCESS_TOKEN = os.getenv("HF_TOKEN")
  print("Access token loaded.")

- # Function to encode image to base64 (remains useful if we ever need to pass base64 to a non-smolagent component)
- def encode_image(image_path_or_pil):
-     if not image_path_or_pil:
-         print("No image path or PIL Image provided")
          return None

      try:
-         # print(f"Encoding image: {type(image_path_or_pil)}")  # Debug

-         if isinstance(image_path_or_pil, Image.Image):
-             image = image_path_or_pil
-         else:  # Assuming it's a path
-             image = Image.open(image_path_or_pil)

          if image.mode == 'RGBA':
              image = image.convert('RGB')

          buffered = io.BytesIO()
-         image.save(buffered, format="JPEG")  # JPEG is generally smaller for transfer
          img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
-         # print("Image encoded successfully")  # Debug
          return img_str
      except Exception as e:
          print(f"Error encoding image: {e}")
          return None

- # This function will now set up and run the smolagent
  def respond(
-     message_text,  # Text from MultimodalTextbox
-     image_file_paths,  # List of file paths from MultimodalTextbox
-     gradio_history: list[tuple[str, str]],  # Gradio history (for context if needed, agent is stateless per call here)
-     system_message_for_agent,  # System prompt for the main LLM agent
      max_tokens,
      temperature,
      top_p,
      frequency_penalty,
      seed,
-     provider_for_agent_llm,
-     api_key_for_agent_llm,
-     model_id_for_agent_llm,
-     model_search_term,  # Unused directly by agent logic
-     selected_model_for_agent_llm  # Fallback model ID
  ):
-     print(f"Respond function called. Message: '{message_text}', Images: {image_file_paths}")

-     token_to_use = api_key_for_agent_llm if api_key_for_agent_llm.strip() != "" else ACCESS_TOKEN
-     model_to_use = model_id_for_agent_llm.strip() if model_id_for_agent_llm.strip() != "" else selected_model_for_agent_llm

-     # --- Initialize the LLM for the CodeAgent ---
-     agent_llm_params = {
-         "model_id": model_to_use,
-         "token": token_to_use,
-         # smolagents's InferenceClientModel uses max_tokens for max_new_tokens
          "max_tokens": max_tokens,
-         "temperature": temperature if temperature > 0.01 else None,  # Some models require temp > 0
-         "top_p": top_p if top_p < 1.0 else None,  # Often 1.0 means no top_p
-         "seed": seed if seed != -1 else None,
      }
-     if provider_for_agent_llm and provider_for_agent_llm != "hf-inference":
-         agent_llm_params["provider"] = provider_for_agent_llm

-     # HFIC specific params, add if not default and supported
-     if frequency_penalty != 0.0:
-         agent_llm_params["frequency_penalty"] = frequency_penalty
-
-     agent_llm = SmolInferenceClientModel(**agent_llm_params)
-     print(f"Smolagents LLM for agent initialized: model='{model_to_use}', provider='{provider_for_agent_llm or 'default'}'")

-     # --- Define Tools for the Agent ---
-     agent_tools = []
      try:
-         image_gen_tool = Tool.from_space(
-             space_id="black-forest-labs/FLUX.1-schnell",
-             name="image_generator",
-             description="Generates an image from a textual prompt. Input is a single string argument named 'prompt'. Output is an image file path.",
-             token=token_to_use
          )
-         agent_tools.append(image_gen_tool)
-         print("Image generation tool loaded: black-forest-labs/FLUX.1-schnell")
-     except Exception as e:
-         print(f"Error loading image generation tool: {e}")
-         yield f"Error: Could not load image generation tool. {e}"
-         return
-
-     # --- Initialize the CodeAgent ---
-     # If system_message_for_agent is empty, CodeAgent will use its default.
-     # The default is usually good as it explains how to use tools.
-     agent = CodeAgent(
-         tools=agent_tools,
-         model=agent_llm,
-         system_prompt=system_message_for_agent if system_message_for_agent and system_message_for_agent.strip() else None,
-         # add_base_tools=True,  # Consider adding Python interpreter, etc.
-         stream_outputs=True  # Important for Gradio streaming
-     )
-     print("Smolagents CodeAgent initialized.")
-
-     # --- Prepare task and image inputs for the agent ---
-     agent_task_text = message_text
-
-     pil_images_for_agent = []
-     if image_file_paths:
-         for file_path in image_file_paths:
-             try:
-                 pil_images_for_agent.append(Image.open(file_path))
-             except Exception as e:
-                 print(f"Error opening image file {file_path} for agent: {e}")
-
-     print(f"Agent task: '{agent_task_text}'")
-     if pil_images_for_agent:
-         print(f"Passing {len(pil_images_for_agent)} image(s) to agent.")
-
-     # --- Run the agent and stream response ---
-     # Agent is reset each turn. For conversational memory, agent instance
-     # would need to be stored in session_state and agent.run(..., reset=False) used.
-
-     current_agent_response_text = ""
-     try:
-         # The agent.run method returns a generator when stream=True
-         for step_item in agent.run(
-             task=agent_task_text,
-             images=pil_images_for_agent,
-             stream=True,
-             reset=True  # Explicitly reset for stateless operation per call
-         ):
-             if isinstance(step_item, ChatMessageStreamDelta):
-                 if step_item.content:
-                     current_agent_response_text += step_item.content
-                     yield current_agent_response_text  # Yield accumulated text
-
-             elif isinstance(step_item, (ActionStep, PlanningStep, FinalAnswerStep)):
-                 # A structured step. Format it for Gradio.
-                 # pull_messages_from_step yields gr.ChatMessage objects.
-                 for gradio_chat_msg in pull_messages_from_step(step_item, skip_model_outputs=agent.stream_outputs):
-                     # The 'bot' function will handle these gr.ChatMessage objects.
-                     yield gradio_chat_msg  # Yield the gr.ChatMessage object directly
-                 current_agent_response_text = ""  # Reset text buffer after a structured step
-
-             # else:
-             #     print(f"Unhandled stream item type: {type(step_item)}")  # Debug
-
-         # If there's any remaining text not part of a gr.ChatMessage, yield it.
-         # This usually shouldn't happen if stream_to_gradio logic is followed,
-         # as text deltas should be part of the last gr.ChatMessage or yielded before it.
-         # However, if the agent's final textual answer comes as pure deltas after all steps.
-         if current_agent_response_text and not isinstance(step_item, FinalAnswerStep):
-             # Check if the last yielded item already contains this text
-             if not (isinstance(step_item, gr.ChatMessage) and step_item.content == current_agent_response_text):
-                 yield current_agent_response_text
-
-
      except Exception as e:
-         error_message = f"Error during agent execution: {str(e)}"
-         print(error_message)
-         yield error_message  # Yield the error message to be displayed in UI
-
-     print("Agent run completed.")


  # Function to validate provider selection based on BYOK
  def validate_provider(api_key, provider):
@@ -184,15 +219,16 @@ def validate_provider(api_key, provider):

  # GRADIO UI
  with gr.Blocks(theme="Nymbo/Nymbo_Theme") as demo:
      chatbot = gr.Chatbot(
          height=600,
          show_copy_button=True,
-         placeholder="Select a model and begin chatting. Now uses smolagents with tools!",
-         layout="panel",
-         bubble_full_width=False  # For better display of images/files
      )
      print("Chatbot interface created.")

      msg = gr.MultimodalTextbox(
          placeholder="Type a message or upload images...",
          show_label=False,
@@ -203,201 +239,329 @@ with gr.Blocks(theme="Nymbo/Nymbo_Theme") as demo:
          sources=["upload"]
      )

      with gr.Accordion("Settings", open=False):
          system_message_box = gr.Textbox(
-             value="You are a helpful AI assistant. You can generate images if asked. Be precise with your prompts for image generation.",
-             placeholder="You are a helpful AI assistant.",
-             label="System Prompt for Agent"
          )

          with gr.Row():
              with gr.Column():
-                 max_tokens_slider = gr.Slider(minimum=1, maximum=4096, value=1024, step=1, label="Max New Tokens")
-                 temperature_slider = gr.Slider(minimum=0.0, maximum=2.0, value=0.7, step=0.01, label="Temperature")
-                 top_p_slider = gr.Slider(minimum=0.0, maximum=1.0, value=0.95, step=0.01, label="Top-P")
              with gr.Column():
-                 frequency_penalty_slider = gr.Slider(minimum=-2.0, maximum=2.0, value=0.0, step=0.1, label="Frequency Penalty")
-                 seed_slider = gr.Slider(minimum=-1, maximum=65535, value=-1, step=1, label="Seed (-1 for random)")

          providers_list = [
-             "hf-inference", "cerebras", "together", "sambanova", "novita",
-             "cohere", "fireworks-ai", "hyperbolic", "nebius",
          ]
-         provider_radio = gr.Radio(choices=providers_list, value="hf-inference", label="Inference Provider for Agent's LLM")
-         byok_textbox = gr.Textbox(value="", label="BYOK (Your HF Token or Provider API Key)", info="Enter API key for the selected provider. Uses HF_TOKEN if empty.", placeholder="Enter your API token", type="password")
-         custom_model_box = gr.Textbox(value="", label="Custom Model ID for Agent's LLM", info="(Optional) Provide a custom model ID. Overrides featured model.", placeholder="meta-llama/Llama-3.3-70B-Instruct")
-         model_search_box = gr.Textbox(label="Filter Featured Models", placeholder="Search for a featured model...", lines=1)

          models_list = [
-             "meta-llama/Llama-3.3-70B-Instruct", "meta-llama/Llama-3.1-70B-Instruct", "meta-llama/Llama-3.0-70B-Instruct",
-             "meta-llama/Llama-3.2-11B-Vision-Instruct", "meta-llama/Llama-3.2-3B-Instruct", "meta-llama/Llama-3.2-1B-Instruct",
-             "meta-llama/Llama-3.1-8B-Instruct", "NousResearch/Hermes-3-Llama-3.1-8B", "NousResearch/Nous-Hermes-2-Mixtral-8x7B-DPO",
-             "mistralai/Mistral-Nemo-Instruct-2407", "mistralai/Mixtral-8x7B-Instruct-v0.1", "mistralai/Mistral-7B-Instruct-v0.3",
-             "Qwen/Qwen3-235B-A22B", "Qwen/Qwen3-32B", "Qwen/Qwen2.5-72B-Instruct", "Qwen/Qwen2.5-3B-Instruct",
-             "Qwen/Qwen2.5-Coder-32B-Instruct", "microsoft/Phi-3.5-mini-instruct", "microsoft/Phi-3-mini-128k-instruct",
          ]
-         featured_model_radio = gr.Radio(label="Select a Featured Model for Agent's LLM", choices=models_list, value="meta-llama/Llama-3.3-70B-Instruct", interactive=True)

          gr.Markdown("[View all Text-to-Text models](https://huggingface.co/models?inference_provider=all&pipeline_tag=text-generation&sort=trending) | [View all multimodal models](https://huggingface.co/models?inference_provider=all&pipeline_tag=image-text-to-text&sort=trending)")

-     # Chat history state (using gr.State to manage it properly)
-     # The chatbot's value itself will be the history display.
-     # We might need a separate gr.State if agent needs to be conversational across turns.
-     # For now, agent is stateless per turn.

      # Function for the chat interface
-     def user(user_multimodal_input_dict, history):
-         print(f"User input: {user_multimodal_input_dict}")
-         text_content = user_multimodal_input_dict.get("text", "")
-         files = user_multimodal_input_dict.get("files", [])

-         user_display_parts = []
-         if text_content and text_content.strip():
-             user_display_parts.append(text_content)
-         for file_path_obj in files:  # file_path_obj is a tempfile._TemporaryFileWrapper
-             user_display_parts.append((file_path_obj.name, os.path.basename(file_path_obj.name)))
-
-         if not user_display_parts:
              return history

-         # Append the user's multimodal message to history for display
-         # The actual data (dict) is passed to `bot` function separately.
-         history.append([user_display_parts if len(user_display_parts) > 1 else user_display_parts[0], None])
-         return history
-
      def bot(history, system_msg, max_tokens, temperature, top_p, freq_penalty, seed, provider, api_key, custom_model, search_term, selected_model):
-         if not history or not history[-1][0]:  # If no user input
-             yield history
-             return
-
-         # The user's input (text and list of file paths) is in history[-1][0]
-         # If `user` function stores the dict:
-         raw_user_input_dict = history[-1][0] if isinstance(history[-1][0], dict) else {"text": str(history[-1][0]), "files": []}

-         # If `user` function stores formatted display parts:
-         # We need to reconstruct or rely on msg input to bot.
-         # For now, assuming msg.submit passes the raw dict.
-         # Let's adjust the Gradio flow to pass `msg` directly to `bot` as well.
-
-         # The `msg` variable in `msg.submit` holds the raw MultimodalTextbox output.
-         # We need to pass this raw dict to `respond`.
-         # The `history` is for display.

-         # This part is tricky as `bot` gets `history` which is already formatted for display.
-         # A common pattern is to pass `msg` (raw input) also to `bot`.
-         # Let's assume `history[-1][0]` contains enough info or we adjust `user` fn.
-         # For simplicity, let's assume `user` stores the raw dict if needed,
-         # or `bot` can parse `history[-1][0]` if it's a string/list of tuples.
-
-         # Let's assume `history[-1][0]` is the raw `user_multimodal_input_dict`
-         # This means the `user` function must append it like: `history.append([user_multimodal_input_dict, None])`
-         # And the chatbot will display `str(user_multimodal_input_dict)`.
-         # This is what the current `user` function does.
-
-         user_input_data = history[-1][0]  # This should be the dict from MultimodalTextbox
-         text_input_for_agent = user_input_data.get("text", "")
-         # Files from MultimodalTextbox are temp file paths
-         image_file_paths_for_agent = [f.name for f in user_input_data.get("files", []) if hasattr(f, 'name')]
-
-
-         history[-1][1] = ""  # Initialize assistant's part for streaming

-         # Buffer for current text stream from agent
-         # Handles both pure text deltas and text content from gr.ChatMessage
-         current_text_for_turn = ""
-
-         for item in respond(
-             message_text=text_input_for_agent,
-             image_file_paths=image_file_paths_for_agent,
-             gradio_history=history[:-1],  # Pass previous turns for context if agent uses it
-             system_message_for_agent=system_msg,
-             max_tokens=max_tokens, temperature=temperature, top_p=top_p,
-             frequency_penalty=freq_penalty, seed=seed,
-             provider_for_agent_llm=provider, api_key_for_agent_llm=api_key,
-             model_id_for_agent_llm=custom_model,
-             model_search_term=search_term,  # unused
-             selected_model_for_agent_llm=selected_model
-         ):
-             if isinstance(item, str):  # LLM text delta from agent's thought or textual answer
-                 current_text_for_turn = item
-                 history[-1][1] = current_text_for_turn
-             elif isinstance(item, gr.ChatMessage):
-                 # This is a structured step (thought, tool output, image, etc.)
-                 # We need to append this to the history as a new message or part of current message.
-                 # For simplicity, let's append its string content to the current turn's assistant message.
-                 # If it's an image/file, we'll represent it as a markdown link.
-                 if isinstance(item.content, str):
-                     current_text_for_turn = item.content  # Replace if it's a full message
-                 elif isinstance(item.content, dict) and "path" in item.content:
-                     # This is typically an image or audio file
-                     file_path = item.content["path"]
-                     # We need to make this file accessible to Gradio if it's temporary from agent
-                     # For now, just put a placeholder.
-                     # If it's an output from a tool, the path might be relative to where smolagents saves it.
-                     # Gradio needs an absolute path or a URL.
-                     # A common pattern is to copy temp files to a static dir served by Gradio or use gr.File.
-                     # For now, let's assume Gradio can handle local paths if they are in a folder it knows.
-                     # We'll display it as a tuple for Gradio Chatbot.
-                     # This means history[-1][1] needs to become a list.
-
-                     # If current_text_for_turn is not empty, make history[-1][1] a list
-                     if current_text_for_turn and not isinstance(history[-1][1], list):
-                         history[-1][1] = [current_text_for_turn]
-                     elif not current_text_for_turn and not isinstance(history[-1][1], list):
-                         history[-1][1] = []
-
-
-                     alt_text = item.metadata.get("title", os.path.basename(file_path)) if item.metadata else os.path.basename(file_path)
-
-                     # Add as new component to the list for current assistant message
-                     if isinstance(history[-1][1], list):
-                         history[-1][1].append((file_path, alt_text))
-                     else:  # Should have been made a list above
-                         history[-1][1] = [(file_path, alt_text)]
-
-                     current_text_for_turn = ""  # Reset text buffer after a file
-
-             # If it's not a delta, but a full message, replace the current text
-             if not isinstance(history[-1][1], list):  # if it hasn't become a list due to file
-                 history[-1][1] = current_text_for_turn
-
-             yield history
-
-     # Event handlers
-     # `msg.submit`'s first argument is the function to call.
-     # Its `inputs` are the Gradio components whose values are passed to the function.
-     # Its `outputs` are the Gradio components that are updated by the function's return value.
-     # The `user` function now appends the raw dict from MultimodalTextbox to history.
-     # The `bot` function takes this history.
-
-     # When msg is submitted:
-     # 1. Call `user` to update history with user's input. Output is `chatbot`.
-     # 2. Then call `bot` with the updated history. Output is `chatbot`.
-     # 3. Then clear `msg`
      msg.submit(
          user,
          [msg, chatbot],
-         [chatbot],  # `user` returns the new history, updating the chatbot display
          queue=False
      ).then(
          bot,
          [chatbot, system_message_box, max_tokens_slider, temperature_slider, top_p_slider,
           frequency_penalty_slider, seed_slider, provider_radio, byok_textbox, custom_model_box,
           model_search_box, featured_model_radio],
-         [chatbot]  # `bot` yields history updates, streaming to chatbot
      ).then(
-         lambda: {"text": "", "files": []},  # Clear MultimodalTextbox
          None,
          [msg]
      )

-     model_search_box.change(fn=filter_models, inputs=model_search_box, outputs=featured_model_radio)
-     featured_model_radio.change(fn=set_custom_model_from_radio, inputs=featured_model_radio, outputs=custom_model_box)
-     byok_textbox.change(fn=validate_provider, inputs=[byok_textbox, provider_radio], outputs=provider_radio)
-     provider_radio.change(fn=validate_provider, inputs=[byok_textbox, provider_radio], outputs=provider_radio)

  print("Gradio interface initialized.")

  if __name__ == "__main__":
      print("Launching the demo application.")
-     demo.launch(show_api=False)  # show_api=False for cleaner launch, True for API docs

  import gradio as gr
+ from huggingface_hub import InferenceClient
  import os
  import json
  import base64
  from PIL import Image
  import io

  ACCESS_TOKEN = os.getenv("HF_TOKEN")
  print("Access token loaded.")

+ # Function to encode image to base64
+ def encode_image(image_path):
+     if not image_path:
+         print("No image path provided")
          return None

      try:
+         print(f"Encoding image from path: {image_path}")

+         # If it's already a PIL Image
+         if isinstance(image_path, Image.Image):
+             image = image_path
+         else:
+             # Try to open the image file
+             image = Image.open(image_path)

+         # Convert to RGB if image has an alpha channel (RGBA)
          if image.mode == 'RGBA':
              image = image.convert('RGB')

+         # Encode to base64
          buffered = io.BytesIO()
+         image.save(buffered, format="JPEG")
          img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
+         print("Image encoded successfully")
          return img_str
      except Exception as e:
          print(f"Error encoding image: {e}")
          return None

  def respond(
+     message,
+     image_files,  # Changed parameter name and structure
+     history: list[tuple[str, str]],
+     system_message,
      max_tokens,
      temperature,
      top_p,
      frequency_penalty,
      seed,
+     provider,
+     custom_api_key,
+     custom_model,
+     model_search_term,
+     selected_model
  ):
+     print(f"Received message: {message}")
+     print(f"Received {len(image_files) if image_files else 0} images")
+     print(f"History: {history}")
+     print(f"System message: {system_message}")
+     print(f"Max tokens: {max_tokens}, Temperature: {temperature}, Top-P: {top_p}")
+     print(f"Frequency Penalty: {frequency_penalty}, Seed: {seed}")
+     print(f"Selected provider: {provider}")
+     print(f"Custom API Key provided: {bool(custom_api_key.strip())}")
+     print(f"Selected model (custom_model): {custom_model}")
+     print(f"Model search term: {model_search_term}")
+     print(f"Selected model from radio: {selected_model}")
+
+     # Determine which token to use
+     token_to_use = custom_api_key if custom_api_key.strip() != "" else ACCESS_TOKEN
+
+     if custom_api_key.strip() != "":
+         print("USING CUSTOM API KEY: BYOK token provided by user is being used for authentication")
+     else:
+         print("USING DEFAULT API KEY: Environment variable HF_TOKEN is being used for authentication")
+
+     # Initialize the Inference Client with the provider and appropriate token
+     client = InferenceClient(token=token_to_use, provider=provider)
+     print(f"Hugging Face Inference Client initialized with {provider} provider.")
+
+     # Convert seed to None if -1 (meaning random)
+     if seed == -1:
+         seed = None
+
+     # Create multimodal content if images are present
+     if image_files and len(image_files) > 0:
+         # Process the user message to include images
+         user_content = []
+
+         # Add text part if there is any
+         if message and message.strip():
+             user_content.append({
+                 "type": "text",
+                 "text": message
+             })
+
+         # Add image parts
+         for img in image_files:
+             if img is not None:
+                 # Get raw image data from path
+                 try:
+                     encoded_image = encode_image(img)
+                     if encoded_image:
+                         user_content.append({
+                             "type": "image_url",
+                             "image_url": {
+                                 "url": f"data:image/jpeg;base64,{encoded_image}"
+                             }
+                         })
+                 except Exception as e:
+                     print(f"Error encoding image: {e}")
+     else:
+         # Text-only message
+         user_content = message
+
+     # Prepare messages in the format expected by the API
+     messages = [{"role": "system", "content": system_message}]
+     print("Initial messages array constructed.")
+
+     # Add conversation history to the context
+     for val in history:
+         user_part = val[0]
+         assistant_part = val[1]
+         if user_part:
+             # Handle both text-only and multimodal messages in history
+             if isinstance(user_part, tuple) and len(user_part) == 2:
+                 # This is a multimodal message with text and images
+                 history_content = []
+                 if user_part[0]:  # Text
+                     history_content.append({
+                         "type": "text",
+                         "text": user_part[0]
+                     })
+
+                 for img in user_part[1]:  # Images
+                     if img:
+                         try:
+                             encoded_img = encode_image(img)
+                             if encoded_img:
+                                 history_content.append({
+                                     "type": "image_url",
+                                     "image_url": {
+                                         "url": f"data:image/jpeg;base64,{encoded_img}"
+                                     }
+                                 })
+                         except Exception as e:
+                             print(f"Error encoding history image: {e}")
+
+                 messages.append({"role": "user", "content": history_content})
+             else:
+                 # Regular text message
+                 messages.append({"role": "user", "content": user_part})
+                 print(f"Added user message to context (type: {type(user_part)})")
+
+         if assistant_part:
+             messages.append({"role": "assistant", "content": assistant_part})
+             print(f"Added assistant message to context: {assistant_part}")
+
+     # Append the latest user message
+     messages.append({"role": "user", "content": user_content})
+     print(f"Latest user message appended (content type: {type(user_content)})")

+     # Determine which model to use, prioritizing custom_model if provided
+     model_to_use = custom_model.strip() if custom_model.strip() != "" else selected_model
+     print(f"Model selected for inference: {model_to_use}")

+     # Start with an empty string to build the response as tokens stream in
+     response = ""
+     print(f"Sending request to {provider} provider.")
+
+     # Prepare parameters for the chat completion request
+     parameters = {
          "max_tokens": max_tokens,
+         "temperature": temperature,
+         "top_p": top_p,
+         "frequency_penalty": frequency_penalty,
      }

+     if seed is not None:
+         parameters["seed"] = seed

+     # Use the InferenceClient for making the request
      try:
+         # Create a generator for the streaming response
+         stream = client.chat_completion(
+             model=model_to_use,
+             messages=messages,
+             stream=True,
+             **parameters
          )
+
+         print("Received tokens: ", end="", flush=True)
+
+         # Process the streaming response
+         for chunk in stream:
+             if hasattr(chunk, 'choices') and len(chunk.choices) > 0:
+                 # Extract the content from the response
+                 if hasattr(chunk.choices[0], 'delta') and hasattr(chunk.choices[0].delta, 'content'):
+                     token_text = chunk.choices[0].delta.content
+                     if token_text:
+                         print(token_text, end="", flush=True)
+                         response += token_text
+                         yield response
+
+         print()
      except Exception as e:
+         print(f"Error during inference: {e}")
+         response += f"\nError: {str(e)}"
+         yield response

+     print("Completed response generation.")

  # Function to validate provider selection based on BYOK
  def validate_provider(api_key, provider):

  # GRADIO UI
  with gr.Blocks(theme="Nymbo/Nymbo_Theme") as demo:
+     # Create the chatbot component
      chatbot = gr.Chatbot(
          height=600,
          show_copy_button=True,
+         placeholder="Select a model and begin chatting. Now supports multiple inference providers and multimodal inputs",
+         layout="panel"
      )
      print("Chatbot interface created.")

+     # Multimodal textbox for messages (combines text and file uploads)
      msg = gr.MultimodalTextbox(
          placeholder="Type a message or upload images...",
          show_label=False,
          sources=["upload"]
      )

+     # Note: We're removing the separate submit button since MultimodalTextbox has its own
+
+     # Create accordion for settings
      with gr.Accordion("Settings", open=False):
+         # System message
          system_message_box = gr.Textbox(
+             value="You are a helpful AI assistant that can understand images and text.",
+             placeholder="You are a helpful assistant.",
+             label="System Prompt"
          )

+         # Generation parameters
          with gr.Row():
              with gr.Column():
+                 max_tokens_slider = gr.Slider(
+                     minimum=1,
+                     maximum=4096,
+                     value=512,
+                     step=1,
+                     label="Max tokens"
+                 )
+
+                 temperature_slider = gr.Slider(
+                     minimum=0.1,
+                     maximum=4.0,
+                     value=0.7,
+                     step=0.1,
+                     label="Temperature"
+                 )
+
+                 top_p_slider = gr.Slider(
+                     minimum=0.1,
+                     maximum=1.0,
+                     value=0.95,
+                     step=0.05,
+                     label="Top-P"
+                 )
+
              with gr.Column():
+                 frequency_penalty_slider = gr.Slider(
+                     minimum=-2.0,
+                     maximum=2.0,
+                     value=0.0,
+                     step=0.1,
+                     label="Frequency Penalty"
+                 )
+
+                 seed_slider = gr.Slider(
+                     minimum=-1,
+                     maximum=65535,
+                     value=-1,
+                     step=1,
+                     label="Seed (-1 for random)"
+                 )

+         # Provider selection
          providers_list = [
+             "hf-inference",  # Default Hugging Face Inference
+             "cerebras",      # Cerebras provider
+             "together",      # Together AI
+             "sambanova",     # SambaNova
+             "novita",        # Novita AI
+             "cohere",        # Cohere
+             "fireworks-ai",  # Fireworks AI
+             "hyperbolic",    # Hyperbolic
+             "nebius",        # Nebius
          ]

+         provider_radio = gr.Radio(
+             choices=providers_list,
+             value="hf-inference",
+             label="Inference Provider",
+         )
+
+         # New BYOK textbox
+         byok_textbox = gr.Textbox(
+             value="",
+             label="BYOK (Bring Your Own Key)",
+             info="Enter a custom Hugging Face API key here. When empty, only 'hf-inference' provider can be used.",
+             placeholder="Enter your Hugging Face API token",
+             type="password"  # Hide the API key for security
+         )
+
+         # Custom model box
+         custom_model_box = gr.Textbox(
+             value="",
+             label="Custom Model",
+             info="(Optional) Provide a custom Hugging Face model path. Overrides any selected featured model.",
+             placeholder="meta-llama/Llama-3.3-70B-Instruct"
+         )
+
+         # Model search
+         model_search_box = gr.Textbox(
+             label="Filter Models",
+             placeholder="Search for a featured model...",
+             lines=1
+         )
+
+         # Featured models list
+         # Updated to include multimodal models
          models_list = [
+             "meta-llama/Llama-3.2-11B-Vision-Instruct",
+             "meta-llama/Llama-3.3-70B-Instruct",
+             "meta-llama/Llama-3.1-70B-Instruct",
+             "meta-llama/Llama-3.0-70B-Instruct",
+             "meta-llama/Llama-3.2-3B-Instruct",
+             "meta-llama/Llama-3.2-1B-Instruct",
+             "meta-llama/Llama-3.1-8B-Instruct",
+             "NousResearch/Hermes-3-Llama-3.1-8B",
+             "NousResearch/Nous-Hermes-2-Mixtral-8x7B-DPO",
+             "mistralai/Mistral-Nemo-Instruct-2407",
+             "mistralai/Mixtral-8x7B-Instruct-v0.1",
+             "mistralai/Mistral-7B-Instruct-v0.3",
+             "mistralai/Mistral-7B-Instruct-v0.2",
+             "Qwen/Qwen3-235B-A22B",
+             "Qwen/Qwen3-32B",
+             "Qwen/Qwen2.5-72B-Instruct",
+             "Qwen/Qwen2.5-3B-Instruct",
+             "Qwen/Qwen2.5-0.5B-Instruct",
+             "Qwen/QwQ-32B",
+             "Qwen/Qwen2.5-Coder-32B-Instruct",
+             "microsoft/Phi-3.5-mini-instruct",
+             "microsoft/Phi-3-mini-128k-instruct",
+             "microsoft/Phi-3-mini-4k-instruct",
          ]
+
+         featured_model_radio = gr.Radio(
+             label="Select a model below",
+             choices=models_list,
+             value="meta-llama/Llama-3.2-11B-Vision-Instruct",  # Default to a multimodal model
+             interactive=True
+         )

          gr.Markdown("[View all Text-to-Text models](https://huggingface.co/models?inference_provider=all&pipeline_tag=text-generation&sort=trending) | [View all multimodal models](https://huggingface.co/models?inference_provider=all&pipeline_tag=image-text-to-text&sort=trending)")

+     # Chat history state
+     chat_history = gr.State([])
+
+     # Function to filter models
+     def filter_models(search_term):
+         print(f"Filtering models with search term: {search_term}")
+         filtered = [m for m in models_list if search_term.lower() in m.lower()]
+         print(f"Filtered models: {filtered}")
+         return gr.update(choices=filtered)
+
+     # Function to set custom model from radio
+     def set_custom_model_from_radio(selected):
+         print(f"Featured model selected: {selected}")
+         return selected

      # Function for the chat interface
+     def user(user_message, history):
+         # Debug logging for troubleshooting
+         print(f"User message received: {user_message}")

+         # Skip if message is empty (no text and no files)
+         if not user_message or (not user_message.get("text") and not user_message.get("files")):
+             print("Empty message, skipping")
+             return history
+
+         # Prepare multimodal message format
+         text_content = user_message.get("text", "").strip()
+         files = user_message.get("files", [])
+
+         print(f"Text content: {text_content}")
+         print(f"Files: {files}")
+
+         # If both text and files are empty, skip
+         if not text_content and not files:
+             print("No content to display")
              return history
+
+         # Add message with images to history
+         if files and len(files) > 0:
+             # Add text message first if it exists
+             if text_content:
+                 # Add a separate text message
+                 print(f"Adding text message: {text_content}")
+                 history.append([text_content, None])

+             # Then add each image file separately
+             for file_path in files:
+                 if file_path and isinstance(file_path, str):
+                     print(f"Adding image: {file_path}")
+                     # Add image as a separate message with no text
+                     history.append([f"![Image]({file_path})", None])
+
+             return history
+         else:
+             # For text-only messages
+             print(f"Adding text-only message: {text_content}")
+             history.append([text_content, None])
+             return history
+
+     # Define bot response function
      def bot(history, system_msg, max_tokens, temperature, top_p, freq_penalty, seed, provider, api_key, custom_model, search_term, selected_model):
+         # Check if history is valid
+         if not history or len(history) == 0:
+             print("No history to process")
+             return history

+         # Get the most recent message and detect if it's an image
+         user_message = history[-1][0]
+         print(f"Processing user message: {user_message}")

+         is_image = False
+         image_path = None
+         text_content = user_message

+         # Check if this is an image message (marked with ![Image])
+         if isinstance(user_message, str) and user_message.startswith("![Image]("):
+             is_image = True
+             # Extract image path from markdown format ![Image](path)
+             image_path = user_message.replace("![Image](", "").replace(")", "")
+             print(f"Image detected: {image_path}")
+             text_content = ""  # No text for image-only messages
+
+         # Look back for text context if this is an image
+         text_context = ""
+         if is_image and len(history) > 1:
+             # Use the previous message as context if it's text
+             prev_message = history[-2][0]
+             if isinstance(prev_message, str) and not prev_message.startswith("![Image]("):
+                 text_context = prev_message
+                 print(f"Using text context from previous message: {text_context}")
+
+         # Process message through respond function
+         history[-1][1] = ""
+
+         # Use either the image or text for the API
+         if is_image:
+             # For image messages
+             for response in respond(
+                 text_context,   # Text context from previous message if any
+                 [image_path],   # Current image
+                 history[:-1],   # Previous history
+                 system_msg,
+                 max_tokens,
+                 temperature,
+                 top_p,
+                 freq_penalty,
+                 seed,
+                 provider,
+                 api_key,
+                 custom_model,
+                 search_term,
+                 selected_model
+             ):
+                 history[-1][1] = response
+                 yield history
+         else:
+             # For text-only messages
+             for response in respond(
+                 text_content,   # Text message
+                 None,           # No image
+                 history[:-1],   # Previous history
+                 system_msg,
+                 max_tokens,
+                 temperature,
+                 top_p,
+                 freq_penalty,
+                 seed,
+                 provider,
+                 api_key,
+                 custom_model,
+                 search_term,
+                 selected_model
+             ):
+                 history[-1][1] = response
+                 yield history
+
+     # Event handlers - only using the MultimodalTextbox's built-in submit functionality
      msg.submit(
          user,
          [msg, chatbot],
+         [chatbot],
          queue=False
      ).then(
          bot,
          [chatbot, system_message_box, max_tokens_slider, temperature_slider, top_p_slider,
           frequency_penalty_slider, seed_slider, provider_radio, byok_textbox, custom_model_box,
           model_search_box, featured_model_radio],
+         [chatbot]
      ).then(
+         lambda: {"text": "", "files": []},  # Clear inputs after submission
          None,
          [msg]
      )

+     # Connect the model filter to update the radio choices
+     model_search_box.change(
+         fn=filter_models,
+         inputs=model_search_box,
+         outputs=featured_model_radio
+     )
+     print("Model search box change event linked.")
+
+     # Connect the featured model radio to update the custom model box
+     featured_model_radio.change(
+         fn=set_custom_model_from_radio,
+         inputs=featured_model_radio,
+         outputs=custom_model_box
+     )
+     print("Featured model radio button change event linked.")
+
+     # Connect the BYOK textbox to validate provider selection
+     byok_textbox.change(
+         fn=validate_provider,
+         inputs=[byok_textbox, provider_radio],
+         outputs=provider_radio
+     )
+     print("BYOK textbox change event linked.")
+
+     # Also validate provider when the radio changes to ensure consistency
+     provider_radio.change(
+         fn=validate_provider,
+         inputs=[byok_textbox, provider_radio],
+         outputs=provider_radio
+     )
+     print("Provider radio button change event linked.")

  print("Gradio interface initialized.")

  if __name__ == "__main__":
      print("Launching the demo application.")
+     demo.launch(show_api=True)
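
For reference, the streaming pattern the new respond() relies on can be exercised on its own. The snippet below is a minimal sketch, not part of the commit: it assumes HF_TOKEN is set in the environment, a huggingface_hub version recent enough to expose InferenceClient.chat_completion streaming, and uses a placeholder model ID.

import os
from huggingface_hub import InferenceClient

# Minimal sketch (assumptions: HF_TOKEN is set; model ID is a placeholder).
client = InferenceClient(token=os.getenv("HF_TOKEN"))  # defaults to the hf-inference provider

response = ""
for chunk in client.chat_completion(
    model="meta-llama/Llama-3.2-11B-Vision-Instruct",  # placeholder; any chat model works
    messages=[
        {"role": "system", "content": "You are a helpful AI assistant."},
        {"role": "user", "content": "Say hello in one sentence."},
    ],
    stream=True,
    max_tokens=128,
):
    delta = chunk.choices[0].delta.content
    if delta:
        response += delta  # accumulate partial text, as respond() does before each yield
print(response)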