import gradio as gr
from huggingface_hub import InferenceClient # Keep for direct use if needed, though agent will use its own model
import os
import json
import base64
from PIL import Image
import io
# Smolagents imports
from smolagents import CodeAgent, Tool
from smolagents.models import InferenceClientModel as SmolInferenceClientModel
# We'll use PIL.Image directly for opening files; AgentImage is only needed if a tool requires the agent's internal image type
from smolagents.gradio_ui import pull_messages_from_step # For formatting agent steps
from smolagents.memory import ActionStep, FinalAnswerStep, PlanningStep, MemoryStep # For type checking steps
from smolagents.models import ChatMessageStreamDelta # For type checking stream deltas
ACCESS_TOKEN = os.getenv("HF_TOKEN")
print("Access token loaded." if ACCESS_TOKEN else "No HF_TOKEN found; a provider API key must be supplied via BYOK.")
# Function to encode image to base64 (remains useful if we ever need to pass base64 to a non-smolagent component)
def encode_image(image_path_or_pil):
if not image_path_or_pil:
print("No image path or PIL Image provided")
return None
try:
# print(f"Encoding image: {type(image_path_or_pil)}") # Debug
if isinstance(image_path_or_pil, Image.Image):
image = image_path_or_pil
else: # Assuming it's a path
image = Image.open(image_path_or_pil)
if image.mode == 'RGBA':
image = image.convert('RGB')
buffered = io.BytesIO()
image.save(buffered, format="JPEG") # JPEG is generally smaller for transfer
img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
# print("Image encoded successfully") # Debug
return img_str
except Exception as e:
print(f"Error encoding image: {e}")
return None
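# Illustration only (not used in this app): the base64 string can be embedded in an
# OpenAI-style "image_url" message part for APIs that expect data URIs. Names below are hypothetical.
#   b64 = encode_image("example.jpg")
#   image_part = {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{b64}"}}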
# This function will now set up and run the smolagent
def respond(
message_text, # Text from MultimodalTextbox
image_file_paths, # List of file paths from MultimodalTextbox
gradio_history: list[tuple[str, str]], # Gradio history (for context if needed, agent is stateless per call here)
system_message_for_agent, # System prompt for the main LLM agent
max_tokens,
temperature,
top_p,
frequency_penalty,
seed,
provider_for_agent_llm,
api_key_for_agent_llm,
model_id_for_agent_llm,
model_search_term, # Unused directly by agent logic
selected_model_for_agent_llm # Fallback model ID
):
print(f"Respond function called. Message: '{message_text}', Images: {image_file_paths}")
token_to_use = api_key_for_agent_llm if api_key_for_agent_llm.strip() != "" else ACCESS_TOKEN
model_to_use = model_id_for_agent_llm.strip() if model_id_for_agent_llm.strip() != "" else selected_model_for_agent_llm
# --- Initialize the LLM for the CodeAgent ---
agent_llm_params = {
"model_id": model_to_use,
"token": token_to_use,
# smolagents' InferenceClientModel maps max_tokens to the backend's max_new_tokens
"max_tokens": max_tokens,
"temperature": temperature if temperature > 0.01 else None, # Some models require temp > 0
"top_p": top_p if top_p < 1.0 else None, # Often 1.0 means no top_p
"seed": seed if seed != -1 else None,
}
if provider_for_agent_llm and provider_for_agent_llm != "hf-inference":
agent_llm_params["provider"] = provider_for_agent_llm
# Provider-specific extra: only forward frequency_penalty when it differs from the default
if frequency_penalty != 0.0:
agent_llm_params["frequency_penalty"] = frequency_penalty
agent_llm = SmolInferenceClientModel(**agent_llm_params)
print(f"Smolagents LLM for agent initialized: model='{model_to_use}', provider='{provider_for_agent_llm or 'default'}'")
# --- Define Tools for the Agent ---
agent_tools = []
try:
image_gen_tool = Tool.from_space(
space_id="black-forest-labs/FLUX.1-schnell",
name="image_generator",
description="Generates an image from a textual prompt. Input is a single string argument named 'prompt'. Output is an image file path.",
token=token_to_use
)
agent_tools.append(image_gen_tool)
print("Image generation tool loaded: black-forest-labs/FLUX.1-schnell")
except Exception as e:
print(f"Error loading image generation tool: {e}")
yield f"Error: Could not load image generation tool. {e}"
return
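# For reference, a loaded Space tool is also directly callable, e.g. (assuming the
# Space exposes a single 'prompt' argument, as described above):
#   generated_image_path = image_gen_tool(prompt="a watercolor fox")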
# --- Initialize the CodeAgent ---
# If system_message_for_agent is empty, CodeAgent will use its default.
# The default is usually good as it explains how to use tools.
agent = CodeAgent(
tools=agent_tools,
model=agent_llm,
system_prompt=system_message_for_agent if system_message_for_agent and system_message_for_agent.strip() else None,
# add_base_tools=True, # Consider adding Python interpreter, etc.
stream_outputs=True # Important for Gradio streaming
)
print("Smolagents CodeAgent initialized.")
# --- Prepare task and image inputs for the agent ---
agent_task_text = message_text
pil_images_for_agent = []
if image_file_paths:
for file_path in image_file_paths:
try:
pil_images_for_agent.append(Image.open(file_path))
except Exception as e:
print(f"Error opening image file {file_path} for agent: {e}")
print(f"Agent task: '{agent_task_text}'")
if pil_images_for_agent:
print(f"Passing {len(pil_images_for_agent)} image(s) to agent.")
# --- Run the agent and stream response ---
# Agent is reset each turn. For conversational memory, agent instance
# would need to be stored in session_state and agent.run(..., reset=False) used.
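# A rough sketch of that conversational variant (hypothetical wiring, not enabled here):
#   agent_state = gr.State(None)          # defined next to the Chatbot in the Blocks UI
#   agent = stored_agent or CodeAgent(tools=agent_tools, model=agent_llm)
#   for step in agent.run(task, images=images, stream=True, reset=False):
#       ...
#   # then return the agent back into gr.State so the next turn reuses its memory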
current_agent_response_text = ""
step_item = None  # track the last item so the post-loop flush below is safe even if the generator yields nothing
try:
    # agent.run returns a generator when stream=True
for step_item in agent.run(
task=agent_task_text,
images=pil_images_for_agent,
stream=True,
reset=True # Explicitly reset for stateless operation per call
):
if isinstance(step_item, ChatMessageStreamDelta):
if step_item.content:
current_agent_response_text += step_item.content
yield current_agent_response_text # Yield accumulated text
elif isinstance(step_item, (ActionStep, PlanningStep, FinalAnswerStep)):
# A structured step. Format it for Gradio.
# pull_messages_from_step yields gr.ChatMessage objects.
for gradio_chat_msg in pull_messages_from_step(step_item, skip_model_outputs=agent.stream_outputs):
# The 'bot' function will handle these gr.ChatMessage objects.
yield gradio_chat_msg # Yield the gr.ChatMessage object directly
current_agent_response_text = "" # Reset text buffer after a structured step
# else:
# print(f"Unhandled stream item type: {type(step_item)}") # Debug
# Flush any trailing text that arrived as pure deltas after the last structured step
# (e.g. a final textual answer), unless the last yielded item already carried that content.
if current_agent_response_text and not isinstance(step_item, FinalAnswerStep):
# Check if the last yielded item already contains this text
if not (isinstance(step_item, gr.ChatMessage) and step_item.content == current_agent_response_text):
yield current_agent_response_text
except Exception as e:
error_message = f"Error during agent execution: {str(e)}"
print(error_message)
yield error_message # Yield the error message to be displayed in UI
print("Agent run completed.")
# Function to validate provider selection based on BYOK
def validate_provider(api_key, provider):
if not api_key.strip() and provider != "hf-inference":
return gr.update(value="hf-inference")
return gr.update(value=provider)
# GRADIO UI
with gr.Blocks(theme="Nymbo/Nymbo_Theme") as demo:
chatbot = gr.Chatbot(
height=600,
show_copy_button=True,
placeholder="Select a model and begin chatting. Now uses smolagents with tools!",
layout="panel",
bubble_full_width=False # For better display of images/files
)
print("Chatbot interface created.")
msg = gr.MultimodalTextbox(
placeholder="Type a message or upload images...",
show_label=False,
container=False,
scale=12,
file_types=["image"],
file_count="multiple",
sources=["upload"]
)
with gr.Accordion("Settings", open=False):
system_message_box = gr.Textbox(
value="You are a helpful AI assistant. You can generate images if asked. Be precise with your prompts for image generation.",
placeholder="You are a helpful AI assistant.",
label="System Prompt for Agent"
)
with gr.Row():
with gr.Column():
max_tokens_slider = gr.Slider(minimum=1, maximum=4096, value=1024, step=1, label="Max New Tokens")
temperature_slider = gr.Slider(minimum=0.0, maximum=2.0, value=0.7, step=0.01, label="Temperature")
top_p_slider = gr.Slider(minimum=0.0, maximum=1.0, value=0.95, step=0.01, label="Top-P")
with gr.Column():
frequency_penalty_slider = gr.Slider(minimum=-2.0, maximum=2.0, value=0.0, step=0.1, label="Frequency Penalty")
seed_slider = gr.Slider(minimum=-1, maximum=65535, value=-1, step=1, label="Seed (-1 for random)")
providers_list = [
"hf-inference", "cerebras", "together", "sambanova", "novita",
"cohere", "fireworks-ai", "hyperbolic", "nebius",
]
provider_radio = gr.Radio(choices=providers_list, value="hf-inference", label="Inference Provider for Agent's LLM")
byok_textbox = gr.Textbox(value="", label="BYOK (Your HF Token or Provider API Key)", info="Enter API key for the selected provider. Uses HF_TOKEN if empty.", placeholder="Enter your API token", type="password")
custom_model_box = gr.Textbox(value="", label="Custom Model ID for Agent's LLM", info="(Optional) Provide a custom model ID. Overrides featured model.", placeholder="meta-llama/Llama-3.3-70B-Instruct")
model_search_box = gr.Textbox(label="Filter Featured Models", placeholder="Search for a featured model...", lines=1)
models_list = [
"meta-llama/Llama-3.3-70B-Instruct", "meta-llama/Llama-3.1-70B-Instruct", "meta-llama/Llama-3.0-70B-Instruct",
"meta-llama/Llama-3.2-11B-Vision-Instruct", "meta-llama/Llama-3.2-3B-Instruct", "meta-llama/Llama-3.2-1B-Instruct",
"meta-llama/Llama-3.1-8B-Instruct", "NousResearch/Hermes-3-Llama-3.1-8B", "NousResearch/Nous-Hermes-2-Mixtral-8x7B-DPO",
"mistralai/Mistral-Nemo-Instruct-2407", "mistralai/Mixtral-8x7B-Instruct-v0.1", "mistralai/Mistral-7B-Instruct-v0.3",
"Qwen/Qwen3-235B-A22B", "Qwen/Qwen3-32B", "Qwen/Qwen2.5-72B-Instruct", "Qwen/Qwen2.5-3B-Instruct",
"Qwen/Qwen2.5-Coder-32B-Instruct", "microsoft/Phi-3.5-mini-instruct", "microsoft/Phi-3-mini-128k-instruct",
]
featured_model_radio = gr.Radio(label="Select a Featured Model for Agent's LLM", choices=models_list, value="meta-llama/Llama-3.3-70B-Instruct", interactive=True)
gr.Markdown("[View all Text-to-Text models](https://huggingface.co/models?inference_provider=all&pipeline_tag=text-generation&sort=trending) | [View all multimodal models](https://huggingface.co/models?inference_provider=all&pipeline_tag=image-text-to-text&sort=trending)")
# Chat history: the chatbot's value itself serves as the displayed history.
# A separate gr.State would only be needed if the agent were made conversational
# across turns; for now the agent is stateless per turn.
# Function for the chat interface
def user(user_multimodal_input_dict, history):
print(f"User input: {user_multimodal_input_dict}")
text_content = user_multimodal_input_dict.get("text", "")
files = user_multimodal_input_dict.get("files", [])
user_display_parts = []
if text_content and text_content.strip():
user_display_parts.append(text_content)
for file_obj in files:  # Gradio may hand back plain path strings or file objects with a .name attribute
    path = file_obj.name if hasattr(file_obj, "name") else str(file_obj)
    user_display_parts.append((path, os.path.basename(path)))
if not user_display_parts:
return history
# Append the user's multimodal message to history for display
# The actual data (dict) is passed to `bot` function separately.
history.append([user_display_parts if len(user_display_parts) > 1 else user_display_parts[0], None])
return history
def bot(history, system_msg, max_tokens, temperature, top_p, freq_penalty, seed, provider, api_key, custom_model, search_term, selected_model):
if not history or not history[-1][0]: # If no user input
yield history
return
# The user's latest turn lives in history[-1][0]. Depending on how `user`
# stored it, this may be the raw MultimodalTextbox dict, a plain string, or a
# list mixing text and (file_path, alt_text) tuples, so normalize all cases
# into text plus a list of image file paths before calling the agent.
last_user_turn = history[-1][0]
if isinstance(last_user_turn, dict):
    text_input_for_agent = last_user_turn.get("text", "")
    image_file_paths_for_agent = [
        f.name if hasattr(f, "name") else str(f)
        for f in last_user_turn.get("files", [])
    ]
else:
    parts = last_user_turn if isinstance(last_user_turn, list) else [last_user_turn]
    text_input_for_agent = " ".join(p for p in parts if isinstance(p, str))
    image_file_paths_for_agent = [p[0] for p in parts if isinstance(p, tuple)]
history[-1][1] = "" # Initialize assistant's part for streaming
# Buffer for current text stream from agent
# Handles both pure text deltas and text content from gr.ChatMessage
current_text_for_turn = ""
for item in respond(
message_text=text_input_for_agent,
image_file_paths=image_file_paths_for_agent,
gradio_history=history[:-1], # Pass previous turns for context if agent uses it
system_message_for_agent=system_msg,
max_tokens=max_tokens, temperature=temperature, top_p=top_p,
frequency_penalty=freq_penalty, seed=seed,
provider_for_agent_llm=provider, api_key_for_agent_llm=api_key,
model_id_for_agent_llm=custom_model,
model_search_term=search_term, # unused
selected_model_for_agent_llm=selected_model
):
if isinstance(item, str): # LLM text delta from agent's thought or textual answer
current_text_for_turn = item
history[-1][1] = current_text_for_turn
elif isinstance(item, gr.ChatMessage):
# A structured step (thought, tool log, image, etc.). String content replaces the
# current text buffer; file content is appended to the assistant turn as a
# (file_path, alt_text) tuple so the Chatbot renders it inline.
if isinstance(item.content, str):
current_text_for_turn = item.content # Replace if it's a full message
elif isinstance(item.content, dict) and "path" in item.content:
# Typically an image or audio file produced by a tool.
file_path = item.content["path"]
# Gradio needs a local path (or URL) it is allowed to serve; tool outputs are usually
# temp files Gradio can serve directly. If a tool ever writes somewhere Gradio cannot
# serve, see the commented sketch below for copying the file into a served folder.
# Displaying it as a (path, alt_text) tuple requires history[-1][1] to be a list.
# If current_text_for_turn is not empty, make history[-1][1] a list
if current_text_for_turn and not isinstance(history[-1][1], list):
history[-1][1] = [current_text_for_turn]
elif not current_text_for_turn and not isinstance(history[-1][1], list):
history[-1][1] = []
alt_text = item.metadata.get("title", os.path.basename(file_path)) if item.metadata else os.path.basename(file_path)
# Add as new component to the list for current assistant message
if isinstance(history[-1][1], list):
history[-1][1].append((file_path, alt_text))
else: # Should have been made a list above
history[-1][1] = [(file_path, alt_text)]
current_text_for_turn = "" # Reset text buffer after a file
# If it's not a delta, but a full message, replace the current text
if not isinstance(history[-1][1], list): # if it hasn't become a list due to file
history[-1][1] = current_text_for_turn
yield history
# Event handlers
# On submit: 1) `user` appends the user's turn to the chat history shown in `chatbot`,
# 2) `bot` streams the agent's response into that turn, 3) the MultimodalTextbox is cleared.
msg.submit(
user,
[msg, chatbot],
[chatbot], # `user` returns the new history, updating the chatbot display
queue=False
).then(
bot,
[chatbot, system_message_box, max_tokens_slider, temperature_slider, top_p_slider,
frequency_penalty_slider, seed_slider, provider_radio, byok_textbox, custom_model_box,
model_search_box, featured_model_radio],
[chatbot] # `bot` yields history updates, streaming to chatbot
).then(
lambda: {"text": "", "files": []}, # Clear MultimodalTextbox
None,
[msg]
)
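# The two handlers below are wired up by the .change() events that follow but were
# not defined anywhere in this file; a minimal implementation is provided here.
def filter_models(search_term):
    # Narrow the featured-model radio choices to those matching the search term
    filtered = [m for m in models_list if search_term.lower() in m.lower()]
    return gr.update(choices=filtered)

def set_custom_model_from_radio(selected_model):
    # Mirror the selected featured model into the custom model textbox
    return selected_model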
model_search_box.change(fn=filter_models, inputs=model_search_box, outputs=featured_model_radio)
featured_model_radio.change(fn=set_custom_model_from_radio, inputs=featured_model_radio, outputs=custom_model_box)
byok_textbox.change(fn=validate_provider, inputs=[byok_textbox, provider_radio], outputs=provider_radio)
provider_radio.change(fn=validate_provider, inputs=[byok_textbox, provider_radio], outputs=provider_radio)
print("Gradio interface initialized.")
if __name__ == "__main__":
print("Launching the demo application.")
demo.launch(show_api=False) # show_api=False for cleaner launch, True for API docs