import gradio as gr
from huggingface_hub import InferenceClient  # Kept for direct use if needed; the agent uses its own model wrapper
import os
import json
import base64
from PIL import Image
import io

# Smolagents imports
from smolagents import CodeAgent, Tool
from smolagents.models import InferenceClientModel as SmolInferenceClientModel
# We use PIL.Image directly for opening files; AgentImage is only needed for a tool's internal typing
from smolagents.gradio_ui import pull_messages_from_step  # For formatting agent steps
from smolagents.memory import ActionStep, FinalAnswerStep, PlanningStep, MemoryStep  # For type checking steps
from smolagents.models import ChatMessageStreamDelta  # For type checking stream deltas

ACCESS_TOKEN = os.getenv("HF_TOKEN")
print("Access token loaded.")
# Function to encode an image to base64 (still useful if we ever need to pass base64 to a non-smolagents component)
def encode_image(image_path_or_pil):
    if not image_path_or_pil:
        print("No image path or PIL Image provided")
        return None
    try:
        # print(f"Encoding image: {type(image_path_or_pil)}")  # Debug
        if isinstance(image_path_or_pil, Image.Image):
            image = image_path_or_pil
        else:  # Assume it's a path
            image = Image.open(image_path_or_pil)
        if image.mode == 'RGBA':
            image = image.convert('RGB')
        buffered = io.BytesIO()
        image.save(buffered, format="JPEG")  # JPEG is generally smaller for transfer
        img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
        # print("Image encoded successfully")  # Debug
        return img_str
    except Exception as e:
        print(f"Error encoding image: {e}")
        return None

# This function sets up and runs the smolagents agent
def respond(
    message_text,                       # Text from the MultimodalTextbox
    image_file_paths,                   # List of image file paths from the MultimodalTextbox
    gradio_history: list[tuple[str, str]],  # Gradio history (for context if needed; the agent is stateless per call here)
    system_message_for_agent,           # System prompt for the main LLM agent
    max_tokens,
    temperature,
    top_p,
    frequency_penalty,
    seed,
    provider_for_agent_llm,
    api_key_for_agent_llm,
    model_id_for_agent_llm,
    model_search_term,                  # Unused directly by the agent logic
    selected_model_for_agent_llm,       # Fallback model ID
):
print(f"Respond function called. Message: '{message_text}', Images: {image_file_paths}") | |
token_to_use = api_key_for_agent_llm if api_key_for_agent_llm.strip() != "" else ACCESS_TOKEN | |
model_to_use = model_id_for_agent_llm.strip() if model_id_for_agent_llm.strip() != "" else selected_model_for_agent_llm | |
# --- Initialize the LLM for the CodeAgent --- | |
agent_llm_params = { | |
"model_id": model_to_use, | |
"token": token_to_use, | |
# smolagents's InferenceClientModel uses max_tokens for max_new_tokens | |
"max_tokens": max_tokens, | |
"temperature": temperature if temperature > 0.01 else None, # Some models require temp > 0 | |
"top_p": top_p if top_p < 1.0 else None, # Often 1.0 means no top_p | |
"seed": seed if seed != -1 else None, | |
} | |
if provider_for_agent_llm and provider_for_agent_llm != "hf-inference": | |
agent_llm_params["provider"] = provider_for_agent_llm | |
# HFIC specific params, add if not default and supported | |
if frequency_penalty != 0.0: | |
agent_llm_params["frequency_penalty"] = frequency_penalty | |
agent_llm = SmolInferenceClientModel(**agent_llm_params) | |
print(f"Smolagents LLM for agent initialized: model='{model_to_use}', provider='{provider_for_agent_llm or 'default'}'") | |
    # --- Define Tools for the Agent ---
    agent_tools = []
    try:
        image_gen_tool = Tool.from_space(
            space_id="black-forest-labs/FLUX.1-schnell",
            name="image_generator",
            description="Generates an image from a textual prompt. Input is a single string argument named 'prompt'. Output is an image file path.",
            token=token_to_use
        )
        agent_tools.append(image_gen_tool)
        print("Image generation tool loaded: black-forest-labs/FLUX.1-schnell")
    except Exception as e:
        print(f"Error loading image generation tool: {e}")
        yield f"Error: Could not load image generation tool. {e}"
        return
    # --- Initialize the CodeAgent ---
    # If system_message_for_agent is empty, CodeAgent falls back to its default
    # system prompt, which already explains how to use the tools.
    agent = CodeAgent(
        tools=agent_tools,
        model=agent_llm,
        system_prompt=system_message_for_agent if system_message_for_agent and system_message_for_agent.strip() else None,
        # add_base_tools=True,  # Consider adding the Python interpreter, web search, etc.
        stream_outputs=True  # Important for Gradio streaming
    )
    print("Smolagents CodeAgent initialized.")
    # --- Prepare task and image inputs for the agent ---
    agent_task_text = message_text
    pil_images_for_agent = []
    if image_file_paths:
        for file_path in image_file_paths:
            try:
                pil_images_for_agent.append(Image.open(file_path))
            except Exception as e:
                print(f"Error opening image file {file_path} for agent: {e}")

    print(f"Agent task: '{agent_task_text}'")
    if pil_images_for_agent:
        print(f"Passing {len(pil_images_for_agent)} image(s) to agent.")
    # --- Run the agent and stream the response ---
    # The agent is reset on each turn. For conversational memory, the agent instance
    # would need to be stored in session state and agent.run(..., reset=False) used.
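    # A minimal sketch of that conversational alternative (assumes the agent lives in a
    # gr.State called `agent_state`; the name is illustrative, not part of this app):
    #
    #     if agent_state is None:
    #         agent_state = CodeAgent(tools=agent_tools, model=agent_llm, stream_outputs=True)
    #     for step in agent_state.run(task=agent_task_text, images=pil_images_for_agent,
    #                                 stream=True, reset=False):
    #         ...  # same streaming handling as below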
current_agent_response_text = "" | |
try: | |
# The agent.run method returns a generator when stream=True | |
for step_item in agent.run( | |
task=agent_task_text, | |
images=pil_images_for_agent, | |
stream=True, | |
reset=True # Explicitly reset for stateless operation per call | |
): | |
if isinstance(step_item, ChatMessageStreamDelta): | |
if step_item.content: | |
current_agent_response_text += step_item.content | |
yield current_agent_response_text # Yield accumulated text | |
elif isinstance(step_item, (ActionStep, PlanningStep, FinalAnswerStep)): | |
# A structured step. Format it for Gradio. | |
# pull_messages_from_step yields gr.ChatMessage objects. | |
for gradio_chat_msg in pull_messages_from_step(step_item, skip_model_outputs=agent.stream_outputs): | |
# The 'bot' function will handle these gr.ChatMessage objects. | |
yield gradio_chat_msg # Yield the gr.ChatMessage object directly | |
current_agent_response_text = "" # Reset text buffer after a structured step | |
            # else:
            #     print(f"Unhandled stream item type: {type(step_item)}")  # Debug

        # If any streamed text remains that was not flushed as part of a structured step
        # (e.g. the agent's final textual answer arrived as pure deltas after the last
        # step), yield it one last time. step_item is never a gr.ChatMessage, so checking
        # the text buffer itself is sufficient here.
        if current_agent_response_text:
            yield current_agent_response_text
    except Exception as e:
        error_message = f"Error during agent execution: {str(e)}"
        print(error_message)
        yield error_message  # Yield the error message so it is displayed in the UI

    print("Agent run completed.")

# Function to validate the provider selection based on BYOK
def validate_provider(api_key, provider):
    if not api_key.strip() and provider != "hf-inference":
        return gr.update(value="hf-inference")
    return gr.update(value=provider)

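# The UI wiring below references filter_models and set_custom_model_from_radio, which are
# not defined elsewhere in this snippet. Minimal sketches are provided here under the
# assumption that they only filter the featured-model list and mirror the selected
# featured model into the custom-model textbox; adjust as needed.
def filter_models(search_term):
    # models_list is defined inside the gr.Blocks context below, at module level,
    # so it is available by the time this callback runs.
    filtered = [m for m in models_list if search_term.lower() in m.lower()]
    return gr.update(choices=filtered)

def set_custom_model_from_radio(selected_model):
    # Copy the selected featured model ID into the custom-model textbox.
    return selected_model
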
# GRADIO UI
with gr.Blocks(theme="Nymbo/Nymbo_Theme") as demo:
    chatbot = gr.Chatbot(
        height=600,
        show_copy_button=True,
        placeholder="Select a model and begin chatting. Now uses smolagents with tools!",
        layout="panel",
        bubble_full_width=False  # For better display of images/files
    )
    print("Chatbot interface created.")

    msg = gr.MultimodalTextbox(
        placeholder="Type a message or upload images...",
        show_label=False,
        container=False,
        scale=12,
        file_types=["image"],
        file_count="multiple",
        sources=["upload"]
    )
with gr.Accordion("Settings", open=False): | |
system_message_box = gr.Textbox( | |
value="You are a helpful AI assistant. You can generate images if asked. Be precise with your prompts for image generation.", | |
placeholder="You are a helpful AI assistant.", | |
label="System Prompt for Agent" | |
) | |
with gr.Row(): | |
with gr.Column(): | |
max_tokens_slider = gr.Slider(minimum=1, maximum=4096, value=1024, step=1, label="Max New Tokens") | |
temperature_slider = gr.Slider(minimum=0.0, maximum=2.0, value=0.7, step=0.01, label="Temperature") | |
top_p_slider = gr.Slider(minimum=0.0, maximum=1.0, value=0.95, step=0.01, label="Top-P") | |
with gr.Column(): | |
frequency_penalty_slider = gr.Slider(minimum=-2.0, maximum=2.0, value=0.0, step=0.1, label="Frequency Penalty") | |
seed_slider = gr.Slider(minimum=-1, maximum=65535, value=-1, step=1, label="Seed (-1 for random)") | |
providers_list = [ | |
"hf-inference", "cerebras", "together", "sambanova", "novita", | |
"cohere", "fireworks-ai", "hyperbolic", "nebius", | |
] | |
provider_radio = gr.Radio(choices=providers_list, value="hf-inference", label="Inference Provider for Agent's LLM") | |
byok_textbox = gr.Textbox(value="", label="BYOK (Your HF Token or Provider API Key)", info="Enter API key for the selected provider. Uses HF_TOKEN if empty.", placeholder="Enter your API token", type="password") | |
custom_model_box = gr.Textbox(value="", label="Custom Model ID for Agent's LLM", info="(Optional) Provide a custom model ID. Overrides featured model.", placeholder="meta-llama/Llama-3.3-70B-Instruct") | |
model_search_box = gr.Textbox(label="Filter Featured Models", placeholder="Search for a featured model...", lines=1) | |
models_list = [ | |
"meta-llama/Llama-3.3-70B-Instruct", "meta-llama/Llama-3.1-70B-Instruct", "meta-llama/Llama-3.0-70B-Instruct", | |
"meta-llama/Llama-3.2-11B-Vision-Instruct", "meta-llama/Llama-3.2-3B-Instruct", "meta-llama/Llama-3.2-1B-Instruct", | |
"meta-llama/Llama-3.1-8B-Instruct", "NousResearch/Hermes-3-Llama-3.1-8B", "NousResearch/Nous-Hermes-2-Mixtral-8x7B-DPO", | |
"mistralai/Mistral-Nemo-Instruct-2407", "mistralai/Mixtral-8x7B-Instruct-v0.1", "mistralai/Mistral-7B-Instruct-v0.3", | |
"Qwen/Qwen3-235B-A22B", "Qwen/Qwen3-32B", "Qwen/Qwen2.5-72B-Instruct", "Qwen/Qwen2.5-3B-Instruct", | |
"Qwen/Qwen2.5-Coder-32B-Instruct", "microsoft/Phi-3.5-mini-instruct", "microsoft/Phi-3-mini-128k-instruct", | |
] | |
featured_model_radio = gr.Radio(label="Select a Featured Model for Agent's LLM", choices=models_list, value="meta-llama/Llama-3.3-70B-Instruct", interactive=True) | |
gr.Markdown("[View all Text-to-Text models](https://huggingface.co/models?inference_provider=all&pipeline_tag=text-generation&sort=trending) | [View all multimodal models](https://huggingface.co/models?inference_provider=all&pipeline_tag=image-text-to-text&sort=trending)") | |
    # Chat history: the chatbot's value itself serves as the displayed history.
    # A separate gr.State would only be needed if the agent had to stay
    # conversational across turns; for now the agent is stateless per turn.
    # Function for the chat interface: append the user's turn to the history
    def user(user_multimodal_input_dict, history):
        print(f"User input: {user_multimodal_input_dict}")
        text_content = user_multimodal_input_dict.get("text", "")
        files = user_multimodal_input_dict.get("files", [])

        user_display_parts = []
        if text_content and text_content.strip():
            user_display_parts.append(text_content)
        for f in files:
            # MultimodalTextbox provides temp file paths (str) or file objects, depending on the Gradio version
            file_path = f.name if hasattr(f, "name") else f
            user_display_parts.append((file_path, os.path.basename(file_path)))

        if not user_display_parts:
            return history

        # Append the user's multimodal message to the history for display.
        # The `bot` function parses this entry back out of the history.
        history.append([user_display_parts if len(user_display_parts) > 1 else user_display_parts[0], None])
        return history

    def bot(history, system_msg, max_tokens, temperature, top_p, freq_penalty, seed, provider, api_key, custom_model, search_term, selected_model):
        if not history or not history[-1][0]:  # No user input
            yield history
            return

        # The user's turn is whatever `user` appended as history[-1][0]: a plain text
        # string, a single (file_path, alt_text) tuple, a list mixing both, or (if
        # `user` is ever changed to store it) the raw MultimodalTextbox dict.
        # Parse all of these back into the text and image paths the agent needs.
        user_turn = history[-1][0]
        if isinstance(user_turn, dict):
            text_input_for_agent = user_turn.get("text", "")
            raw_files = user_turn.get("files", [])
        else:
            parts = user_turn if isinstance(user_turn, list) else [user_turn]
            text_input_for_agent = " ".join(p for p in parts if isinstance(p, str))
            raw_files = [p[0] for p in parts if isinstance(p, (tuple, list))]
        # Entries may be temp file paths (str) or file objects with a .name attribute
        image_file_paths_for_agent = [f.name if hasattr(f, "name") else f for f in raw_files]

        history[-1][1] = ""  # Initialize the assistant's side of the turn for streaming

        # Buffer for the current text stream from the agent.
        # Handles both pure text deltas and text content from gr.ChatMessage objects.
        current_text_for_turn = ""
        for item in respond(
            message_text=text_input_for_agent,
            image_file_paths=image_file_paths_for_agent,
            gradio_history=history[:-1],  # Previous turns, for context if the agent ever uses it
            system_message_for_agent=system_msg,
            max_tokens=max_tokens, temperature=temperature, top_p=top_p,
            frequency_penalty=freq_penalty, seed=seed,
            provider_for_agent_llm=provider, api_key_for_agent_llm=api_key,
            model_id_for_agent_llm=custom_model,
            model_search_term=search_term,  # unused
            selected_model_for_agent_llm=selected_model
        ):
            if isinstance(item, str):
                # LLM text delta from the agent's thoughts or its textual answer
                current_text_for_turn = item
                history[-1][1] = current_text_for_turn
            elif isinstance(item, gr.ChatMessage):
                # A structured step (thought, tool output, image, etc.)
                if isinstance(item.content, str):
                    current_text_for_turn = item.content  # A full message: replace the buffer
                elif isinstance(item.content, dict) and "path" in item.content:
                    # Typically an image or audio file produced by a tool. Gradio's Chatbot
                    # displays it when given as a (path, alt_text) tuple, so the assistant
                    # message for this turn becomes a list of parts.
                    file_path = item.content["path"]
                    if current_text_for_turn and not isinstance(history[-1][1], list):
                        history[-1][1] = [current_text_for_turn]
                    elif not current_text_for_turn and not isinstance(history[-1][1], list):
                        history[-1][1] = []
                    alt_text = item.metadata.get("title", os.path.basename(file_path)) if item.metadata else os.path.basename(file_path)
                    # Add the file as a new component of the current assistant message
                    if isinstance(history[-1][1], list):
                        history[-1][1].append((file_path, alt_text))
                    else:  # Should have been converted to a list above
                        history[-1][1] = [(file_path, alt_text)]
                    current_text_for_turn = ""  # Reset the text buffer after a file
                # If the message was a full text message (not a file), replace the current text
                if not isinstance(history[-1][1], list):  # i.e. no file has been added yet
                    history[-1][1] = current_text_for_turn
            yield history
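        # Note on tool-generated files (sketch, stated as an assumption): smolagents tools
        # return temp-file paths, which Gradio will only serve if it considers the location
        # safe (its own temp directory). If a tool writes elsewhere, one option is to allow
        # that directory explicitly, e.g.
        #
        #     demo.launch(allowed_paths=["/path/to/tool/output/dir"])
        #
        # or to copy the file into Gradio's temp dir with shutil.copy before displaying it.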
    # Event handlers
    # When msg is submitted:
    #   1. `user` appends the user's input to the history (updating the chatbot display).
    #   2. `bot` is then called with the updated history and streams the agent's reply.
    #   3. Finally, the MultimodalTextbox is cleared.
    msg.submit(
        user,
        [msg, chatbot],
        [chatbot],  # `user` returns the new history, updating the chatbot display
        queue=False
    ).then(
        bot,
        [chatbot, system_message_box, max_tokens_slider, temperature_slider, top_p_slider,
         frequency_penalty_slider, seed_slider, provider_radio, byok_textbox, custom_model_box,
         model_search_box, featured_model_radio],
        [chatbot]  # `bot` yields history updates, streaming into the chatbot
    ).then(
        lambda: {"text": "", "files": []},  # Clear the MultimodalTextbox
        None,
        [msg]
    )
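    # An alternative wiring (sketch, not used here): pass the raw MultimodalTextbox value
    # straight into `bot` so it does not have to re-parse the display history. `bot` would
    # then take an extra raw-input argument ahead of `history`, e.g.
    #
    #     msg.submit(user, [msg, chatbot], [chatbot], queue=False).then(
    #         bot,
    #         [msg, chatbot, system_message_box, ...],  # note the extra `msg` input
    #         [chatbot]
    #     ).then(lambda: {"text": "", "files": []}, None, [msg])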
    model_search_box.change(fn=filter_models, inputs=model_search_box, outputs=featured_model_radio)
    featured_model_radio.change(fn=set_custom_model_from_radio, inputs=featured_model_radio, outputs=custom_model_box)
    byok_textbox.change(fn=validate_provider, inputs=[byok_textbox, provider_radio], outputs=provider_radio)
    provider_radio.change(fn=validate_provider, inputs=[byok_textbox, provider_radio], outputs=provider_radio)

print("Gradio interface initialized.")
if __name__ == "__main__": | |
print("Launching the demo application.") | |
demo.launch(show_api=False) # show_api=False for cleaner launch, True for API docs |