import gradio as gr
from huggingface_hub import InferenceClient # Keep for direct use if needed, though agent will use its own model
import os
import json
import base64
from PIL import Image
import io
# Smolagents imports
from smolagents import CodeAgent, Tool
from smolagents.models import InferenceClientModel as SmolInferenceClientModel
# We'll use PIL.Image directly for opening, AgentImage is for agent's internal typing if needed by a tool
from smolagents.gradio_ui import pull_messages_from_step # For formatting agent steps
from smolagents.memory import ActionStep, FinalAnswerStep, PlanningStep, MemoryStep # For type checking steps
from smolagents.models import ChatMessageStreamDelta # For type checking stream deltas
ACCESS_TOKEN = os.getenv("HF_TOKEN")
print("Access token loaded." if ACCESS_TOKEN else "HF_TOKEN not set; a user-provided API key will be required.")
# Function to encode image to base64 (remains useful if we ever need to pass base64 to a non-smolagent component)
def encode_image(image_path_or_pil):
if not image_path_or_pil:
print("No image path or PIL Image provided")
return None
try:
# print(f"Encoding image: {type(image_path_or_pil)}") # Debug
if isinstance(image_path_or_pil, Image.Image):
image = image_path_or_pil
else: # Assuming it's a path
image = Image.open(image_path_or_pil)
if image.mode == 'RGBA':
image = image.convert('RGB')
buffered = io.BytesIO()
image.save(buffered, format="JPEG") # JPEG is generally smaller for transfer
img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
# print("Image encoded successfully") # Debug
return img_str
except Exception as e:
print(f"Error encoding image: {e}")
return None
# This function will now set up and run the smolagent
def respond(
message_text, # Text from MultimodalTextbox
image_file_paths, # List of file paths from MultimodalTextbox
gradio_history: list[tuple[str, str]], # Gradio history (for context if needed, agent is stateless per call here)
system_message_for_agent, # System prompt for the main LLM agent
max_tokens,
temperature,
top_p,
frequency_penalty,
seed,
provider_for_agent_llm,
api_key_for_agent_llm,
model_id_for_agent_llm,
model_search_term, # Unused directly by agent logic
selected_model_for_agent_llm # Fallback model ID
):
print(f"Respond function called. Message: '{message_text}', Images: {image_file_paths}")
token_to_use = api_key_for_agent_llm if api_key_for_agent_llm.strip() != "" else ACCESS_TOKEN
model_to_use = model_id_for_agent_llm.strip() if model_id_for_agent_llm.strip() != "" else selected_model_for_agent_llm
# --- Initialize the LLM for the CodeAgent ---
agent_llm_params = {
"model_id": model_to_use,
"token": token_to_use,
# smolagents's InferenceClientModel uses max_tokens for max_new_tokens
"max_tokens": max_tokens,
"temperature": temperature if temperature > 0.01 else None, # Some models require temp > 0
"top_p": top_p if top_p < 1.0 else None, # Often 1.0 means no top_p
"seed": seed if seed != -1 else None,
}
if provider_for_agent_llm and provider_for_agent_llm != "hf-inference":
agent_llm_params["provider"] = provider_for_agent_llm
    # InferenceClient-specific params; only pass them when they differ from the defaults and the provider supports them
if frequency_penalty != 0.0:
agent_llm_params["frequency_penalty"] = frequency_penalty
agent_llm = SmolInferenceClientModel(**agent_llm_params)
print(f"Smolagents LLM for agent initialized: model='{model_to_use}', provider='{provider_for_agent_llm or 'default'}'")
# --- Define Tools for the Agent ---
agent_tools = []
try:
image_gen_tool = Tool.from_space(
space_id="black-forest-labs/FLUX.1-schnell",
name="image_generator",
description="Generates an image from a textual prompt. Input is a single string argument named 'prompt'. Output is an image file path.",
token=token_to_use
)
agent_tools.append(image_gen_tool)
print("Image generation tool loaded: black-forest-labs/FLUX.1-schnell")
except Exception as e:
print(f"Error loading image generation tool: {e}")
yield f"Error: Could not load image generation tool. {e}"
return
# --- Initialize the CodeAgent ---
# If system_message_for_agent is empty, CodeAgent will use its default.
# The default is usually good as it explains how to use tools.
agent = CodeAgent(
tools=agent_tools,
model=agent_llm,
system_prompt=system_message_for_agent if system_message_for_agent and system_message_for_agent.strip() else None,
# add_base_tools=True, # Consider adding Python interpreter, etc.
stream_outputs=True # Important for Gradio streaming
)
print("Smolagents CodeAgent initialized.")
# --- Prepare task and image inputs for the agent ---
agent_task_text = message_text
pil_images_for_agent = []
if image_file_paths:
for file_path in image_file_paths:
try:
pil_images_for_agent.append(Image.open(file_path))
except Exception as e:
print(f"Error opening image file {file_path} for agent: {e}")
print(f"Agent task: '{agent_task_text}'")
if pil_images_for_agent:
print(f"Passing {len(pil_images_for_agent)} image(s) to agent.")
# --- Run the agent and stream response ---
# Agent is reset each turn. For conversational memory, agent instance
# would need to be stored in session_state and agent.run(..., reset=False) used.
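    # A minimal sketch of that alternative (assumes an `agent_state = gr.State(None)`
    # component wired through the event handlers; not implemented in this app):
    #
    #   agent = agent_state or CodeAgent(tools=agent_tools, model=agent_llm)
    #   agent.run(agent_task_text, images=pil_images_for_agent, reset=False)
    #   agent_state = agent  # persist the agent (and its memory) across turns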
current_agent_response_text = ""
try:
# The agent.run method returns a generator when stream=True
for step_item in agent.run(
task=agent_task_text,
images=pil_images_for_agent,
stream=True,
reset=True # Explicitly reset for stateless operation per call
):
if isinstance(step_item, ChatMessageStreamDelta):
if step_item.content:
current_agent_response_text += step_item.content
yield current_agent_response_text # Yield accumulated text
elif isinstance(step_item, (ActionStep, PlanningStep, FinalAnswerStep)):
# A structured step. Format it for Gradio.
# pull_messages_from_step yields gr.ChatMessage objects.
for gradio_chat_msg in pull_messages_from_step(step_item, skip_model_outputs=agent.stream_outputs):
# The 'bot' function will handle these gr.ChatMessage objects.
yield gradio_chat_msg # Yield the gr.ChatMessage object directly
current_agent_response_text = "" # Reset text buffer after a structured step
# else:
# print(f"Unhandled stream item type: {type(step_item)}") # Debug
        # If any streamed text remains after the last structured step (e.g. trailing
        # deltas that were never wrapped in a gr.ChatMessage), yield it so it is not lost.
        # The buffer is cleared after every structured step, so this cannot duplicate
        # a final answer that was already delivered.
        if current_agent_response_text:
            yield current_agent_response_text
except Exception as e:
error_message = f"Error during agent execution: {str(e)}"
print(error_message)
yield error_message # Yield the error message to be displayed in UI
print("Agent run completed.")
# Function to validate provider selection based on BYOK
def validate_provider(api_key, provider):
if not api_key.strip() and provider != "hf-inference":
return gr.update(value="hf-inference")
return gr.update(value=provider)
# GRADIO UI
with gr.Blocks(theme="Nymbo/Nymbo_Theme") as demo:
chatbot = gr.Chatbot(
height=600,
show_copy_button=True,
placeholder="Select a model and begin chatting. Now uses smolagents with tools!",
layout="panel",
bubble_full_width=False # For better display of images/files
)
print("Chatbot interface created.")
msg = gr.MultimodalTextbox(
placeholder="Type a message or upload images...",
show_label=False,
container=False,
scale=12,
file_types=["image"],
file_count="multiple",
sources=["upload"]
)
with gr.Accordion("Settings", open=False):
system_message_box = gr.Textbox(
value="You are a helpful AI assistant. You can generate images if asked. Be precise with your prompts for image generation.",
placeholder="You are a helpful AI assistant.",
label="System Prompt for Agent"
)
with gr.Row():
with gr.Column():
max_tokens_slider = gr.Slider(minimum=1, maximum=4096, value=1024, step=1, label="Max New Tokens")
temperature_slider = gr.Slider(minimum=0.0, maximum=2.0, value=0.7, step=0.01, label="Temperature")
top_p_slider = gr.Slider(minimum=0.0, maximum=1.0, value=0.95, step=0.01, label="Top-P")
with gr.Column():
frequency_penalty_slider = gr.Slider(minimum=-2.0, maximum=2.0, value=0.0, step=0.1, label="Frequency Penalty")
seed_slider = gr.Slider(minimum=-1, maximum=65535, value=-1, step=1, label="Seed (-1 for random)")
providers_list = [
"hf-inference", "cerebras", "together", "sambanova", "novita",
"cohere", "fireworks-ai", "hyperbolic", "nebius",
]
provider_radio = gr.Radio(choices=providers_list, value="hf-inference", label="Inference Provider for Agent's LLM")
byok_textbox = gr.Textbox(value="", label="BYOK (Your HF Token or Provider API Key)", info="Enter API key for the selected provider. Uses HF_TOKEN if empty.", placeholder="Enter your API token", type="password")
custom_model_box = gr.Textbox(value="", label="Custom Model ID for Agent's LLM", info="(Optional) Provide a custom model ID. Overrides featured model.", placeholder="meta-llama/Llama-3.3-70B-Instruct")
model_search_box = gr.Textbox(label="Filter Featured Models", placeholder="Search for a featured model...", lines=1)
models_list = [
"meta-llama/Llama-3.3-70B-Instruct", "meta-llama/Llama-3.1-70B-Instruct", "meta-llama/Llama-3.0-70B-Instruct",
"meta-llama/Llama-3.2-11B-Vision-Instruct", "meta-llama/Llama-3.2-3B-Instruct", "meta-llama/Llama-3.2-1B-Instruct",
"meta-llama/Llama-3.1-8B-Instruct", "NousResearch/Hermes-3-Llama-3.1-8B", "NousResearch/Nous-Hermes-2-Mixtral-8x7B-DPO",
"mistralai/Mistral-Nemo-Instruct-2407", "mistralai/Mixtral-8x7B-Instruct-v0.1", "mistralai/Mistral-7B-Instruct-v0.3",
"Qwen/Qwen3-235B-A22B", "Qwen/Qwen3-32B", "Qwen/Qwen2.5-72B-Instruct", "Qwen/Qwen2.5-3B-Instruct",
"Qwen/Qwen2.5-Coder-32B-Instruct", "microsoft/Phi-3.5-mini-instruct", "microsoft/Phi-3-mini-128k-instruct",
]
featured_model_radio = gr.Radio(label="Select a Featured Model for Agent's LLM", choices=models_list, value="meta-llama/Llama-3.3-70B-Instruct", interactive=True)
gr.Markdown("[View all Text-to-Text models](https://huggingface.co/models?inference_provider=all&pipeline_tag=text-generation&sort=trending) | [View all multimodal models](https://huggingface.co/models?inference_provider=all&pipeline_tag=image-text-to-text&sort=trending)")
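    # Helper callbacks for the model picker. They are referenced by the event handlers
    # below but were missing from this file; the implementations here are minimal,
    # assumed versions: filter the featured-model list by a search term, and copy the
    # selected featured model into the custom-model textbox.
    def filter_models(search_term):
        print(f"Filtering models by: {search_term}")
        filtered = [m for m in models_list if search_term.lower() in m.lower()]
        return gr.update(choices=filtered)

    def set_custom_model_from_radio(selected_model):
        print(f"Featured model selected: {selected_model}")
        return selected_model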
# Chat history state (using gr.State to manage it properly)
# The chatbot's value itself will be the history display.
# We might need a separate gr.State if agent needs to be conversational across turns.
# For now, agent is stateless per turn.
# Function for the chat interface
def user(user_multimodal_input_dict, history):
print(f"User input: {user_multimodal_input_dict}")
text_content = user_multimodal_input_dict.get("text", "")
files = user_multimodal_input_dict.get("files", [])
user_display_parts = []
if text_content and text_content.strip():
user_display_parts.append(text_content)
        for file_path in files:
            # Files may be plain path strings or objects with a .name attribute,
            # depending on the Gradio version.
            path = file_path if isinstance(file_path, str) else getattr(file_path, "name", str(file_path))
            user_display_parts.append((path, os.path.basename(path)))
        if not user_display_parts:
            return history
        # Append the user's multimodal message to history for display; `bot` later
        # reconstructs the agent's text and image inputs from this entry.
        history.append([user_display_parts if len(user_display_parts) > 1 else user_display_parts[0], None])
return history
def bot(history, system_msg, max_tokens, temperature, top_p, freq_penalty, seed, provider, api_key, custom_model, search_term, selected_model):
if not history or not history[-1][0]: # If no user input
yield history
return
        # history[-1][0] holds the user's latest turn. Depending on how `user` stored it,
        # it is either the raw MultimodalTextbox dict ({"text": ..., "files": [...]}) or
        # the display-formatted version (a string, a (path, name) tuple, or a list of both).
        # Handle both shapes so the agent always receives the text and the image paths.
        user_turn = history[-1][0]
        text_input_for_agent = ""
        image_file_paths_for_agent = []
        if isinstance(user_turn, dict):
            text_input_for_agent = user_turn.get("text", "") or ""
            for f in user_turn.get("files", []):
                # Files may be plain path strings or objects with a .name attribute.
                image_file_paths_for_agent.append(f if isinstance(f, str) else getattr(f, "name", str(f)))
        else:
            parts = user_turn if isinstance(user_turn, list) else [user_turn]
            for part in parts:
                if isinstance(part, str):
                    text_input_for_agent = part
                elif isinstance(part, tuple) and part:
                    image_file_paths_for_agent.append(part[0])
history[-1][1] = "" # Initialize assistant's part for streaming
# Buffer for current text stream from agent
# Handles both pure text deltas and text content from gr.ChatMessage
current_text_for_turn = ""
for item in respond(
message_text=text_input_for_agent,
image_file_paths=image_file_paths_for_agent,
gradio_history=history[:-1], # Pass previous turns for context if agent uses it
system_message_for_agent=system_msg,
max_tokens=max_tokens, temperature=temperature, top_p=top_p,
frequency_penalty=freq_penalty, seed=seed,
provider_for_agent_llm=provider, api_key_for_agent_llm=api_key,
model_id_for_agent_llm=custom_model,
model_search_term=search_term, # unused
selected_model_for_agent_llm=selected_model
):
if isinstance(item, str): # LLM text delta from agent's thought or textual answer
current_text_for_turn = item
history[-1][1] = current_text_for_turn
elif isinstance(item, gr.ChatMessage):
                # A structured step (thought, tool call/output, image, etc.).
                # String content replaces the current turn's text; file content is
                # appended to the turn as a (path, alt_text) tuple.
if isinstance(item.content, str):
current_text_for_turn = item.content # Replace if it's a full message
elif isinstance(item.content, dict) and "path" in item.content:
# This is typically an image or audio file
file_path = item.content["path"]
                    # Tool outputs arrive as local file paths. Gradio's Chatbot can display
                    # them when given as (path, alt_text) tuples, so the assistant turn
                    # (history[-1][1]) becomes a list holding text plus file tuples.
if current_text_for_turn and not isinstance(history[-1][1], list):
history[-1][1] = [current_text_for_turn]
elif not current_text_for_turn and not isinstance(history[-1][1], list):
history[-1][1] = []
alt_text = item.metadata.get("title", os.path.basename(file_path)) if item.metadata else os.path.basename(file_path)
# Add as new component to the list for current assistant message
if isinstance(history[-1][1], list):
history[-1][1].append((file_path, alt_text))
else: # Should have been made a list above
history[-1][1] = [(file_path, alt_text)]
current_text_for_turn = "" # Reset text buffer after a file
# If it's not a delta, but a full message, replace the current text
if not isinstance(history[-1][1], list): # if it hasn't become a list due to file
history[-1][1] = current_text_for_turn
yield history
# Event handlers
# `msg.submit`'s first argument is the function to call.
# Its `inputs` are the Gradio components whose values are passed to the function.
# Its `outputs` are the Gradio components that are updated by the function's return value.
    # The `user` function appends the user's turn to the history for display.
    # The `bot` function then reconstructs the agent's inputs from that history entry.
# When msg is submitted:
# 1. Call `user` to update history with user's input. Output is `chatbot`.
# 2. Then call `bot` with the updated history. Output is `chatbot`.
# 3. Then clear `msg`
msg.submit(
user,
[msg, chatbot],
[chatbot], # `user` returns the new history, updating the chatbot display
queue=False
).then(
bot,
[chatbot, system_message_box, max_tokens_slider, temperature_slider, top_p_slider,
frequency_penalty_slider, seed_slider, provider_radio, byok_textbox, custom_model_box,
model_search_box, featured_model_radio],
[chatbot] # `bot` yields history updates, streaming to chatbot
).then(
lambda: {"text": "", "files": []}, # Clear MultimodalTextbox
None,
[msg]
)
model_search_box.change(fn=filter_models, inputs=model_search_box, outputs=featured_model_radio)
featured_model_radio.change(fn=set_custom_model_from_radio, inputs=featured_model_radio, outputs=custom_model_box)
byok_textbox.change(fn=validate_provider, inputs=[byok_textbox, provider_radio], outputs=provider_radio)
provider_radio.change(fn=validate_provider, inputs=[byok_textbox, provider_radio], outputs=provider_radio)
print("Gradio interface initialized.")
if __name__ == "__main__":
print("Launching the demo application.")
    demo.launch(show_api=False)  # show_api=False for a cleaner launch; set True to expose API docs