# Scraped from the Hugging Face Space file page (author: Nymbo).
# Commit dc27384 "Update app.py" (verified); file size 34 kB.
# Page links: raw | history | blame
import gradio as gr
from huggingface_hub import InferenceClient
import os
import json
import base64
from PIL import Image
import io
import requests # Retained, though not directly used in the core logic shown for modification
from smolagents.mcp_client import MCPClient
# Default Hugging Face token from the environment (None when HF_TOKEN is unset).
# respond() uses it whenever the user leaves the BYOK box empty.
ACCESS_TOKEN = os.environ.get("HF_TOKEN")
print("Access token loaded.")
# Function to encode image to base64
def encode_image(image_path):
    """Encode an image as a base64 JPEG string.

    Args:
        image_path: A path (anything ``PIL.Image.open`` accepts) or an
            already-loaded ``PIL.Image.Image``.

    Returns:
        The base64-encoded JPEG bytes as a ``str``, or ``None`` when
        *image_path* is falsy or the image cannot be opened/encoded (the error
        is printed rather than raised).
    """
    if not image_path:
        print("No image path provided")
        return None
    try:
        print(f"Encoding image from path: {image_path}")
        if isinstance(image_path, Image.Image):
            img_str = _image_to_jpeg_b64(image_path)
        else:
            # Context manager closes the file handle PIL keeps for lazily-loaded images.
            with Image.open(image_path) as image:
                img_str = _image_to_jpeg_b64(image)
        print("Image encoded successfully")
        return img_str
    except Exception as e:
        print(f"Error encoding image: {e}")
        return None


def _image_to_jpeg_b64(image):
    """Save *image* as JPEG in memory and return it base64-encoded as a str."""
    # JPEG cannot store RGBA, LA, P (palette), I;16, ... - converting only RGBA
    # made those inputs fail to encode, so normalise everything except L/RGB.
    if image.mode not in ("RGB", "L"):
        image = image.convert("RGB")
    buffered = io.BytesIO()
    image.save(buffered, format="JPEG")
    return base64.b64encode(buffered.getvalue()).decode("utf-8")
# Registry of live MCP connections, keyed by server name. Each value is a dict
# {"client": <MCPClient>, "tools": <list of tools>, "url": <server URL>}
# as written by connect_to_mcp_server().
mcp_connections = dict()
def connect_to_mcp_server(server_url, server_name=None):
    """Connect to an MCP server and register it in ``mcp_connections``.

    Args:
        server_url: MCP endpoint URL, passed to ``MCPClient`` as ``{"url": ...}``.
            Surrounding whitespace is ignored.
        server_name: Optional name to register the connection under. When blank,
            a unique ``Server_<n>_<random>`` name is generated. Reusing an
            existing name replaces (and disconnects) the previous connection.

    Returns:
        ``(name, status_message)``; ``name`` is ``None`` when connecting failed.
    """
    server_url = (server_url or "").strip()
    if not server_url:
        return None, "No server URL provided"
    # Random suffix keeps auto-generated names unique even after removals.
    name = (server_name or "").strip() or (
        f"Server_{len(mcp_connections)}_{base64.urlsafe_b64encode(os.urandom(3)).decode()}"
    )
    client = None
    try:
        client = MCPClient({"url": server_url})
        tools = client.get_tools()
    except Exception as e:
        print(f"Error connecting to MCP server: {e}")
        # Don't leak a half-opened connection when listing tools fails.
        _disconnect_mcp_client_quietly(client)
        return None, f"Error connecting to MCP server: {str(e)}"
    previous = mcp_connections.get(name)
    if previous is not None:
        # Overwriting the entry would otherwise orphan the old client's connection.
        _disconnect_mcp_client_quietly(previous.get("client"))
    mcp_connections[name] = {"client": client, "tools": tools, "url": server_url}
    return name, f"Successfully connected to {name} with {len(tools)} available tools"


def _disconnect_mcp_client_quietly(client):
    """Best-effort ``client.disconnect()``; errors are printed, never raised."""
    disconnect = getattr(client, "disconnect", None)
    if not callable(disconnect):
        return
    try:
        disconnect()
    except Exception as e:
        print(f"Error disconnecting MCP client: {e}")
def list_mcp_tools(server_name):
    """Describe the tools of a connected MCP server, one ``- name: description`` line each.

    Returns the status string "Server not connected" for unknown servers and
    "No tools available for this server" when the tool list is empty.
    """
    if server_name not in mcp_connections:
        return "Server not connected"
    lines = [f"- {tool.name}: {tool.description}" for tool in mcp_connections[server_name]["tools"]]
    return "\n".join(lines) if lines else "No tools available for this server"
def call_mcp_tool(server_name, tool_name, **kwargs):
    """Invoke *tool_name* on the connected MCP server *server_name*.

    Args:
        server_name: Key in ``mcp_connections``.
        tool_name: Name of one of that connection's tools.
        **kwargs: Arguments forwarded to the tool.

    Returns:
        The tool's raw result (a string such as base64 audio, a dict, ...), or
        a human-readable error string when the server or tool is unknown or the
        call raises. Errors are reported through the return value, not raised.
    """
    if server_name not in mcp_connections:
        return f"Server '{server_name}' not connected"
    connection = mcp_connections[server_name]
    tool = next((t for t in connection["tools"] if t.name == tool_name), None)
    if tool is None:
        return f"Tool '{tool_name}' not found on server '{server_name}'"
    try:
        # NOTE(review): smolagents' MCPClient documents connect/disconnect/get_tools
        # but no call_tool(); the tools returned by get_tools() are callable
        # themselves. Prefer a client-level call_tool() only if one exists.
        client_call_tool = getattr(connection["client"], "call_tool", None)
        if callable(client_call_tool):
            return client_call_tool(tool_name, kwargs)
        return tool(**kwargs)
    except Exception as e:
        print(f"Error calling MCP tool: {e}")
        return f"Error calling MCP tool: {str(e)}"
def analyze_message_for_tool_call(message, active_mcp_servers, client_for_llm, model_to_use, system_message_for_llm):
    """Ask the LLM whether *message* should trigger one of the active MCP tools.

    Args:
        message: The user's chat message; blank messages are never analysed.
        active_mcp_servers: Server names selected for this chat; names missing
            from ``mcp_connections`` are ignored.
        client_for_llm: Object exposing ``chat_completion`` (e.g. ``InferenceClient``).
        model_to_use: Model id passed to ``chat_completion``.
        system_message_for_llm: Currently unused; the analysis uses its own
            dedicated system prompt.

    Returns:
        ``(server_name, {"tool_name": ..., "parameters": {...}})`` when the LLM
        picked a tool that is really available on an active server, otherwise
        ``(None, None)``.
    """
    if not message or not message.strip():
        return None, None

    tool_info = []
    for server_name in active_mcp_servers or []:
        if server_name in mcp_connections:
            for tool in mcp_connections[server_name]["tools"]:
                tool_info.append({
                    "server_name": server_name,
                    "tool_name": tool.name,
                    "description": tool.description,
                })
    if not tool_info:
        return None, None

    # The only (server, tool) pairs the LLM may pick. Anything else - e.g. the
    # "kokoroTTS" names the prompt's examples use - is rejected further down.
    available = {(info["server_name"], info["tool_name"]) for info in tool_info}
    tools_string = "\n".join(
        f"{info['server_name']}.{info['tool_name']}: {info['description']}" for info in tool_info
    )
    # Updated prompt to guide LLM for TTS tool that returns base64
    analysis_system_prompt = f"""You are an assistant that helps determine if a user message requires using an external tool.
Available tools:
{tools_string}
Your job is to:
1. Analyze the user's message.
2. Determine if they're asking to use one of the tools.
3. If yes, respond ONLY with a JSON object with "server_name", "tool_name", and "parameters".
4. If no, respond ONLY with the exact string "NO_TOOL_NEEDED".
Example 1 (for TTS that returns base64 audio):
User: "Please turn this text into speech: Hello world"
Response: {{"server_name": "kokoroTTS", "tool_name": "text_to_audio_b64", "parameters": {{"text": "Hello world", "speed": 1.0}}}}
Example 2 (for TTS with different speed):
User: "Read 'This is faster' at speed 1.5"
Response: {{"server_name": "kokoroTTS", "tool_name": "text_to_audio_b64", "parameters": {{"text": "This is faster", "speed": 1.5}}}}
Example 3 (general, non-tool):
User: "What is the capital of France?"
Response: NO_TOOL_NEEDED"""
    try:
        response = client_for_llm.chat_completion(
            model=model_to_use,
            messages=[
                {"role": "system", "content": analysis_system_prompt},
                {"role": "user", "content": message},
            ],
            temperature=0.1,  # Low temperature for deterministic tool selection
            max_tokens=300,
        )
        # content can be None for some providers; treat that as "no answer".
        analysis = (response.choices[0].message.content or "").strip()
    except Exception as e:
        print(f"Error analyzing message for tool calls: {str(e)}")
        return None, None
    print(f"Tool analysis raw response: '{analysis}'")
    if analysis == "NO_TOOL_NEEDED":
        return None, None

    tool_call = _extract_tool_call_json(analysis)
    if tool_call is None:
        print(f"No JSON object found in analysis: {analysis}")
        return None, None
    server_name = tool_call.get("server_name")
    tool_name = tool_call.get("tool_name")
    if not (isinstance(server_name, str) and isinstance(tool_name, str)) or (server_name, tool_name) not in available:
        print(f"LLM selected an unavailable tool: {server_name}.{tool_name}")
        return None, None
    # A JSON null ("parameters": null) means "no arguments".
    parameters = tool_call.get("parameters") or {}
    if not isinstance(parameters, dict):
        print(f"Ignoring tool call with non-object parameters: {parameters!r}")
        return None, None
    return server_name, {"tool_name": tool_name, "parameters": parameters}


def _extract_tool_call_json(text):
    """Return the JSON object encoded in *text* as a dict, or None if there is none.

    The whole string is tried first, then the span from the first "{" to the
    last "}" (covers replies wrapped in prose or Markdown code fences).
    """
    candidates = [text]
    start, end = text.find("{"), text.rfind("}")
    if start != -1 and end > start:
        candidates.append(text[start:end + 1])
    for candidate in candidates:
        try:
            parsed = json.loads(candidate)
        except json.JSONDecodeError:
            print(f"Failed to parse tool call JSON from: {candidate}")
            continue
        if isinstance(parsed, dict):
            return parsed
    return None
def respond(
    message,
    image_files,
    history: list[tuple[str, str]],
    system_message,
    max_tokens,
    temperature,
    top_p,
    frequency_penalty,
    seed,
    provider,
    custom_api_key,
    custom_model,
    model_search_term,
    selected_model,
    mcp_enabled=False,
    active_mcp_servers=None,
    mcp_interaction_mode="Natural Language"
):
    """Produce the assistant reply for one chat turn as a stream of growing strings.

    Every yielded value is the complete reply so far; callers display the most
    recent one. Depending on the MCP settings the turn is answered by:

    * a ``/mcp <server_name> <tool_name> [arguments_json]`` command (MCP enabled),
    * an MCP tool chosen by ``analyze_message_for_tool_call`` (Natural Language mode), or
    * a streamed ``chat_completion`` from the selected model.

    Args:
        message: User text for this turn (``None`` is treated as empty).
        image_files: Image paths attached to this turn; each is base64-encoded.
        history: Earlier ``(user, assistant)`` pairs. A user entry may be a
            MultimodalTextbox dict (``{"text": ..., "files": [...]}``) or a string.
        system_message: Base system prompt; MCP tool descriptions are appended
            when MCP is enabled and servers are active.
        max_tokens, temperature, top_p, frequency_penalty: Sampling parameters.
        seed: Sampling seed; ``-1`` means "not set".
        provider: Inference provider passed to ``InferenceClient``.
        custom_api_key: BYOK token; ``ACCESS_TOKEN`` is used when it is blank.
        custom_model: Model id overriding *selected_model* when non-blank.
        model_search_term: Only logged.
        selected_model: Featured model id used when *custom_model* is blank.
        mcp_enabled, active_mcp_servers, mcp_interaction_mode: MCP settings.
    """
    message = message or ""
    custom_api_key = custom_api_key or ""
    custom_model = custom_model or ""
    print(f"Received message: {message}")
    print(f"Received {len(image_files) if image_files else 0} images")
    print(f"System message: {system_message}")
    print(f"Max tokens: {max_tokens}, Temperature: {temperature}, Top-P: {top_p}")
    print(f"Frequency Penalty: {frequency_penalty}, Seed: {seed}")
    print(f"Selected provider: {provider}")
    print(f"Custom API Key provided: {bool(custom_api_key.strip())}")
    print(f"Selected model (custom_model): {custom_model}")
    print(f"Model search term: {model_search_term}")
    print(f"Selected model from radio: {selected_model}")
    print(f"MCP enabled: {mcp_enabled}")
    print(f"Active MCP servers: {active_mcp_servers}")
    print(f"MCP interaction mode: {mcp_interaction_mode}")

    use_custom_key = custom_api_key.strip() != ""
    token_to_use = custom_api_key if use_custom_key else ACCESS_TOKEN
    if use_custom_key:
        print("USING CUSTOM API KEY: BYOK token provided by user is being used for authentication")
    else:
        print("USING DEFAULT API KEY: Environment variable HF_TOKEN is being used for authentication")
    client_for_llm = InferenceClient(token=token_to_use, provider=provider)
    print(f"Hugging Face Inference Client initialized with {provider} provider.")
    if seed == -1:
        seed = None
    model_to_use = custom_model.strip() if custom_model.strip() != "" else selected_model
    print(f"Model selected for inference: {model_to_use}")

    tool_notice = ""
    if mcp_enabled and message:
        if _is_mcp_command(message):
            yield _run_mcp_command(message)
            return
        if mcp_interaction_mode == "Natural Language" and active_mcp_servers:
            server_name, tool_info = analyze_message_for_tool_call(
                message,
                active_mcp_servers,
                client_for_llm,
                model_to_use,
                system_message,  # Unused by the analyzer, which has its own prompt
            )
            if server_name and tool_info and tool_info.get("tool_name"):
                try:
                    tool_reply = _run_natural_language_tool(server_name, tool_info)
                except Exception as e:
                    print(f"Error executing MCP tool via natural language: {str(e)}")
                    tool_notice = f"I tried to use a tool but encountered an error: {str(e)}. I will try to respond without it."
                    yield tool_notice
                    # Fall through to a normal LLM response.
                else:
                    yield tool_reply
                    return

    user_content = _build_user_content(message, image_files)
    if not user_content:
        # Blank message and no usable images: nothing to send to the model.
        yield ""
        return

    augmented_system_message = system_message
    if mcp_enabled and active_mcp_servers:
        augmented_system_message = _augment_system_message(system_message, active_mcp_servers, mcp_interaction_mode)
    messages_for_llm = [{"role": "system", "content": augmented_system_message}]
    print("Initial messages array constructed.")
    messages_for_llm.extend(_history_to_llm_messages(history or []))
    messages_for_llm.append({"role": "user", "content": user_content})
    print(f"Latest user message appended (content type: {type(user_content)})")

    # Each yield replaces the previous one in the UI, so keep any tool-failure
    # notice as a prefix instead of letting the first token overwrite it.
    response_text = f"{tool_notice}\n\n" if tool_notice else ""
    print(f"Sending request to {provider} provider for general response.")
    parameters = {
        "max_tokens": max_tokens,
        "temperature": temperature,
        "top_p": top_p,
        "frequency_penalty": frequency_penalty,
    }
    if seed is not None:
        parameters["seed"] = seed
    try:
        stream = client_for_llm.chat_completion(
            model=model_to_use,
            messages=messages_for_llm,
            stream=True,
            **parameters
        )
        print("Streaming LLM response: ", end="", flush=True)
        for chunk in stream:
            token_text = _stream_delta_text(chunk)
            if token_text:
                print(token_text, end="", flush=True)
                response_text += token_text
                yield response_text
        print()  # Newline after streaming
    except Exception as e:
        print(f"Error during LLM inference: {e}")
        response_text += f"\nError during LLM response generation: {str(e)}"
        yield response_text
    print("Completed LLM response generation.")


def _mcp_audio_html(audio_b64):
    """Return an HTML ``<audio>`` player for a base64-encoded WAV payload."""
    return f'<audio controls src="data:audio/wav;base64,{audio_b64}"></audio>'


def _is_mcp_command(message):
    """True when *message* is ``/mcp`` alone or ``/mcp`` followed by whitespace."""
    return message.startswith("/mcp") and (len(message) == 4 or message[4].isspace())


def _run_mcp_command(message):
    """Execute a ``/mcp <server_name> <tool_name> [arguments_json]`` command; return the reply text."""
    # maxsplit keeps the JSON argument intact while tolerating runs of whitespace.
    command_parts = message.split(maxsplit=3)
    if len(command_parts) < 3:
        return "Invalid MCP command. Format: /mcp <server_name> <tool_name> [arguments_json]"
    _, server_name, tool_name = command_parts[:3]
    args_json_str = command_parts[3] if len(command_parts) > 3 else "{}"
    try:
        args_dict = json.loads(args_json_str)
    except json.JSONDecodeError:
        return f"Invalid JSON arguments: {args_json_str}"
    try:
        result = call_mcp_tool(server_name, tool_name, **args_dict)
        if "audio" in tool_name.lower() and "b64" in tool_name.lower() and isinstance(result, str):
            return f"Executed {tool_name} from {server_name}.\n\nResult:\n{_mcp_audio_html(result)}"
        if isinstance(result, dict):
            return json.dumps(result, indent=2)
        return str(result)
    except Exception as e:
        return f"Error executing MCP command: {str(e)}"


def _run_natural_language_tool(server_name, tool_info):
    """Call the tool picked by the analyzer and format its result as a chat reply."""
    tool_name = tool_info["tool_name"]
    tool_params = tool_info.get("parameters") or {}
    print(f"Calling tool via natural language: {server_name}.{tool_name} with parameters: {tool_params}")
    result = call_mcp_tool(server_name, tool_name, **tool_params)
    prefix = f"I used the {tool_name} tool from {server_name} with your request.\n\nResult:\n"
    # Heuristic for base64 audio: an audio/b64 tool returning a long string.
    if "audio" in tool_name.lower() and "b64" in tool_name.lower() and isinstance(result, str) and len(result) > 100:
        return prefix + _mcp_audio_html(result)
    if isinstance(result, dict):
        return prefix + json.dumps(result, indent=2)
    return prefix + str(result)


def _build_user_content(message, image_files):
    """Build the multimodal ``content`` list (text part + base64 image parts) for this turn."""
    content = []
    if message and message.strip():
        content.append({"type": "text", "text": message})
    for img_path in image_files or []:
        if img_path is None:
            continue
        try:
            encoded_image = encode_image(img_path)
        except Exception as e:
            print(f"Error encoding image for user content: {e}")
            continue
        if encoded_image:
            content.append({
                "type": "image_url",
                "image_url": {"url": f"data:image/jpeg;base64,{encoded_image}"},
            })
    return content


def _augment_system_message(system_message, active_mcp_servers, mcp_interaction_mode):
    """Append descriptions of the active servers' MCP tools to *system_message*."""
    tool_desc_list = []
    for server_name_active in active_mcp_servers:
        if server_name_active not in mcp_connections:
            continue
        # list_mcp_tools returns "- tool: desc" lines or one of two status strings.
        server_tools_str = list_mcp_tools(server_name_active)
        if server_tools_str in ("Server not connected", "No tools available for this server"):
            continue
        tool_desc_list.extend(
            f"{server_name_active}.{line[2:]}"
            for line in server_tools_str.split("\n")
            if line.startswith("- ")
        )
    if not tool_desc_list:
        return system_message
    mcp_tools_description_for_llm = "\n".join(tool_desc_list)
    # Informs the main LLM (which never calls tools itself) so it can guide the user.
    if mcp_interaction_mode == "Command Mode":
        return system_message + f"\n\nYou have access to the following MCP tools which the user can invoke:\n{mcp_tools_description_for_llm}\n\nTo use these tools, the user can type a command in the format: /mcp <server_name> <tool_name> <arguments_json>"
    return system_message + f"\n\nYou have access to the following MCP tools. The system will try to use them automatically if the user's request matches their capability:\n{mcp_tools_description_for_llm}\n\nIf the user asks to do something a tool can do, the system will attempt to use it. For example, if a 'text_to_audio_b64' tool is available, and the user says 'read this text aloud', the system will try to use that tool."


def _history_to_llm_messages(history):
    """Convert ``(user, assistant)`` history pairs into chat-completion messages."""
    messages = []
    for hist_user, hist_assistant in history:
        user_parts = []
        if isinstance(hist_user, dict) and "text" in hist_user and "files" in hist_user:  # From MultimodalTextbox
            if hist_user["text"] and hist_user["text"].strip():
                user_parts.append({"type": "text", "text": hist_user["text"]})
            for img_file_path in hist_user["files"] or []:
                encoded_img = encode_image(img_file_path)
                if encoded_img:
                    user_parts.append({
                        "type": "image_url",
                        "image_url": {"url": f"data:image/jpeg;base64,{encoded_img}"},
                    })
        elif isinstance(hist_user, str):  # Simple text history
            user_parts.append({"type": "text", "text": hist_user})
        if user_parts:
            messages.append({"role": "user", "content": user_parts})
        if hist_assistant:
            # An audio player is meaningless to the model; send a placeholder instead.
            if "<audio controls src=" in hist_assistant:
                messages.append({"role": "assistant", "content": "[Audio was played in response to the previous message]"})
            else:
                messages.append({"role": "assistant", "content": hist_assistant})
    return messages


def _stream_delta_text(chunk):
    """Return the text delta of one streaming chunk, or None when it carries none."""
    choices = getattr(chunk, "choices", None)
    if not choices:
        return None
    delta = getattr(choices[0], "delta", None)
    return getattr(delta, "content", None)
# GRADIO UI
with gr.Blocks(theme="Nymbo/Nymbo_Theme") as demo:
    chatbot = gr.Chatbot(
        height=600,
        show_copy_button=True,
        placeholder="Select a model and begin chatting. Now supports multiple inference providers, multimodal inputs, and MCP tools",
        layout="panel",
        show_label=False,
    )
    print("Chatbot interface created.")
    # The message box gets its own row directly under the chat window.
    with gr.Row():
        msg = gr.MultimodalTextbox(
            placeholder="Type a message or upload images...",
            show_label=False,
            container=True,  # Ensure it's a container for proper layout
            scale=12,
            file_types=["image"],
            file_count="multiple",
            sources=["upload"],
        )

    with gr.Accordion("Settings", open=False):
        system_message_box = gr.Textbox(
            value="You are a helpful AI assistant that can understand images and text.",
            placeholder="You are a helpful assistant.",
            label="System Prompt",
        )
        with gr.Row():
            with gr.Column():
                max_tokens_slider = gr.Slider(minimum=1, maximum=4096, value=512, step=1, label="Max tokens")
                temperature_slider = gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature")
                top_p_slider = gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-P")
            with gr.Column():
                frequency_penalty_slider = gr.Slider(minimum=-2.0, maximum=2.0, value=0.0, step=0.1, label="Frequency Penalty")
                seed_slider = gr.Slider(minimum=-1, maximum=65535, value=-1, step=1, label="Seed (-1 for random)")
        providers_list = ["hf-inference", "cerebras", "together", "sambanova", "novita", "cohere", "fireworks-ai", "hyperbolic", "nebius"]
        provider_radio = gr.Radio(choices=providers_list, value="hf-inference", label="Inference Provider")
        byok_textbox = gr.Textbox(value="", label="BYOK (Bring Your Own Key)", info="Enter a custom Hugging Face API key here. If empty, only 'hf-inference' provider can be used with the default token.", placeholder="Enter your Hugging Face API token", type="password")
        custom_model_box = gr.Textbox(value="", label="Custom Model", info="(Optional) Provide a Hugging Face model path. Overrides selected featured model.", placeholder="meta-llama/Llama-3.1-70B-Instruct")
        model_search_box = gr.Textbox(label="Filter Models", placeholder="Search for a featured model...", lines=1)
        models_list = [
            "meta-llama/Llama-3.2-11B-Vision-Instruct", "meta-llama/Llama-3.1-405B-Instruct", "meta-llama/Llama-3.1-70B-Instruct", "meta-llama/Llama-3.1-8B-Instruct",
            "meta-llama/Llama-3-70B-Instruct", "meta-llama/Llama-3-8B-Instruct",
            "NousResearch/Hermes-3-Llama-3.1-8B", "NousResearch/Nous-Hermes-2-Mixtral-8x7B-DPO",
            "mistralai/Mistral-Nemo-Instruct-2407", "mistralai/Mixtral-8x7B-Instruct-v0.1", "mistralai/Mistral-7B-Instruct-v0.3", "mistralai/Mistral-7B-Instruct-v0.2",
            "Qwen/Qwen2.5-72B-Instruct", "Qwen/Qwen2-72B-Instruct", "Qwen/Qwen2-57B-A14B-Instruct", "Qwen/Qwen1.5-110B-Chat",
            "microsoft/Phi-3-medium-128k-instruct", "microsoft/Phi-3-mini-128k-instruct", "microsoft/Phi-3-small-128k-instruct",
            "google/gemma-2-27b-it", "google/gemma-2-9b-it",
            "CohereForAI/c4ai-command-r-plus",
            "deepseek-ai/DeepSeek-V2-Chat",
            "Snowflake/snowflake-arctic-instruct",
        ]
        featured_model_radio = gr.Radio(label="Select a featured model", choices=models_list, value="meta-llama/Llama-3.2-11B-Vision-Instruct", interactive=True)
        gr.Markdown("[View all Text-to-Text models](https://huggingface.co/models?pipeline_tag=text-generation&sort=trending) | [View all multimodal models](https://huggingface.co/models?pipeline_tag=image-text-to-text&sort=trending)")

    with gr.Accordion("MCP Settings", open=False):
        mcp_enabled_checkbox = gr.Checkbox(label="Enable MCP Support", value=False, info="Enable Model Context Protocol support for external tools")
        with gr.Row():
            mcp_server_url = gr.Textbox(label="MCP Server URL", placeholder="https://your-mcp-server.hf.space/gradio_api/mcp/sse")
            mcp_server_name = gr.Textbox(label="Server Name (Optional)", placeholder="e.g., kokoroTTS")
        mcp_connect_button = gr.Button("Connect to MCP Server")
        mcp_status = gr.Textbox(label="MCP Connection Status", placeholder="No MCP servers connected", interactive=False)
        active_mcp_servers = gr.Dropdown(label="Active MCP Servers for Chat", choices=[], multiselect=True, info="Select which connected MCP servers to use")
        mcp_mode = gr.Radio(label="MCP Interaction Mode", choices=["Natural Language", "Command Mode"], value="Natural Language", info="How to trigger MCP tools")
        gr.Markdown("""
### MCP Interaction Modes
**Natural Language**: Describe what you want. E.g., "Convert 'Hello' to speech".
**Command Mode**: Use `/mcp <server_name> <tool_name> {"param": "value"}`. E.g., `/mcp kokoroTTS text_to_audio_b64 {"text": "Hello world"}`.
""")

    # Per-session list of [user_input_dict, bot_reply] pairs; bot_reply is None until answered.
    chat_history_state = gr.State([])

    def filter_models_choices(search_term, current_model=None):
        """Restrict the featured-model radio to ids containing *search_term* (case-insensitive).

        *current_model* is the radio's live selection (an event input - reading
        ``featured_model_radio.value`` here would only give the initial value).
        It is kept when it survives the filter; otherwise the first match is selected.
        """
        print(f"Filtering models with search term: {search_term}")
        if not search_term:
            return gr.update(choices=models_list)
        filtered = [m for m in models_list if search_term.lower() in m.lower()]
        print(f"Filtered models: {filtered}")
        if not filtered:
            # Nothing matches: show the full list and leave the selection alone.
            return gr.update(choices=models_list)
        return gr.update(choices=filtered, value=current_model if current_model in filtered else filtered[0])

    def update_custom_model_from_radio(selected_featured_model):
        """Return the selected featured model id (for an optional radio -> custom-model-box wiring)."""
        print(f"Featured model selected: {selected_featured_model}")
        return selected_featured_model

    def handle_connect_mcp_server(url, name_suggestion, current_selection=None):
        """Connect a server and refresh the active-server dropdown.

        *current_selection* is the dropdown's live value; selections that still
        exist are kept and a newly connected server is auto-selected.
        """
        actual_name, status_msg = connect_to_mcp_server(url, name_suggestion)
        all_server_names = list(mcp_connections.keys())
        new_selection = [s for s in (current_selection or []) if s in all_server_names]
        if actual_name and actual_name not in new_selection:
            new_selection.append(actual_name)
        return status_msg, gr.update(choices=all_server_names, value=new_selection)

    def _build_visual_history(history_state):
        """Render [user_input_dict, bot_reply] pairs as tuple-format Chatbot rows.

        Each user turn produces one row for its text and one per attached file;
        the bot reply is attached once, to the turn's last row.
        """
        rows = []
        for user_turn, bot_turn in history_state:
            if isinstance(user_turn, dict):
                text = (user_turn.get("text") or "").strip()
                files = user_turn.get("files") or []
            else:
                text = user_turn.strip() if isinstance(user_turn, str) else ""
                files = []
            turn_rows = [[text, None]] if text else []
            turn_rows.extend([(file_path,), None] for file_path in files)  # Chatbot expects a tuple for files
            if not turn_rows:
                turn_rows.append([None, None])
            turn_rows[-1][1] = bot_turn
            rows.extend(turn_rows)
        return rows

    def handle_user_message(user_input_dict, current_chat_history_state):
        """Record a MultimodalTextbox submission; return (chatbot rows, updated state).

        Empty submissions (no text, no files) are ignored.
        """
        user_input_dict = user_input_dict or {}
        if current_chat_history_state is None:
            current_chat_history_state = []
        text_content = (user_input_dict.get("text") or "").strip()
        files = user_input_dict.get("files") or []
        if text_content or files:
            # The raw dict is stored so respond() can rebuild text + images for the LLM.
            current_chat_history_state.append([user_input_dict, None])
        return _build_visual_history(current_chat_history_state), current_chat_history_state

    def handle_bot_response(
        current_chat_history_state,  # State including the latest user message
        sys_msg, max_tok, temp, top_p_val, freq_pen, seed_val, prov, api_key_val, cust_model,
        search, selected_feat_model, mcp_on, active_servs, mcp_interact_mode
    ):
        """Stream the reply for the newest unanswered user turn into (chatbot, state)."""
        if not current_chat_history_state:
            yield [], []
            return
        if current_chat_history_state[-1][1] is not None:
            # Nothing new to answer; both outputs must still be provided.
            yield _build_visual_history(current_chat_history_state), current_chat_history_state
            return
        user_message_dict = current_chat_history_state[-1][0]
        produced_output = False
        for partial_response in respond(
            message=user_message_dict.get("text", ""),
            image_files=user_message_dict.get("files", []),
            history=current_chat_history_state[:-1],  # Turns BEFORE the current one
            system_message=sys_msg,
            max_tokens=max_tok,
            temperature=temp,
            top_p=top_p_val,
            frequency_penalty=freq_pen,
            seed=seed_val,
            provider=prov,
            custom_api_key=api_key_val,
            custom_model=cust_model,
            model_search_term=search,
            selected_model=selected_feat_model,
            mcp_enabled=mcp_on,
            active_mcp_servers=active_servs,
            mcp_interaction_mode=mcp_interact_mode,
        ):
            produced_output = True
            current_chat_history_state[-1][1] = partial_response
            yield _build_visual_history(current_chat_history_state), current_chat_history_state
        if not produced_output:
            # Mark the turn answered so the next submission is not mistaken for it.
            current_chat_history_state[-1][1] = ""
            yield _build_visual_history(current_chat_history_state), current_chat_history_state

    # Event handlers
    msg.submit(
        handle_user_message,
        [msg, chat_history_state],
        [chatbot, chat_history_state],
        queue=True,
    ).then(
        lambda: gr.update(value={"text": "", "files": []}),  # Clear the textbox once the turn is recorded
        None,
        [msg],
        queue=False,
    ).then(
        handle_bot_response,
        [chat_history_state, system_message_box, max_tokens_slider, temperature_slider, top_p_slider,
         frequency_penalty_slider, seed_slider, provider_radio, byok_textbox, custom_model_box,
         model_search_box, featured_model_radio, mcp_enabled_checkbox, active_mcp_servers, mcp_mode],
        [chatbot, chat_history_state],
    )
    mcp_connect_button.click(
        handle_connect_mcp_server,
        [mcp_server_url, mcp_server_name, active_mcp_servers],
        [mcp_status, active_mcp_servers],
    )
    model_search_box.change(
        fn=filter_models_choices,
        inputs=[model_search_box, featured_model_radio],
        outputs=featured_model_radio,
    )
    # featured_model_radio feeds respond() directly as selected_model;
    # a non-blank custom_model_box overrides it.

    def validate_provider_choice(api_key_val, current_provider_val):
        """Switch back to 'hf-inference' when no BYOK key is entered (per the BYOK help text)."""
        if not (api_key_val or "").strip() and current_provider_val != "hf-inference":
            gr.Info("No custom API key provided. Only 'hf-inference' provider can be used. Switching to 'hf-inference'.")
            return gr.update(value="hf-inference")
        return gr.update()  # No change needed if valid or key provided

    byok_textbox.change(fn=validate_provider_choice, inputs=[byok_textbox, provider_radio], outputs=provider_radio)
    provider_radio.change(fn=validate_provider_choice, inputs=[byok_textbox, provider_radio], outputs=provider_radio)

print("Gradio interface initialized.")
if __name__ == "__main__":
    # GRADIO_SHARE=true (any letter case) asks Gradio for a public share link.
    share_link = os.environ.get("GRADIO_SHARE", "").lower() == "true"
    print("Launching the demo application.")
    demo.queue().launch(show_api=False, mcp_server=False, share=share_link)