Spaces:
Running
Running
import gradio as gr | |
from huggingface_hub import InferenceClient | |
import os | |
import json | |
import base64 | |
from PIL import Image | |
import io | |
import requests # Retained, though not directly used in the core logic shown for modification | |
from smolagents.mcp_client import MCPClient | |
# Default Hugging Face token, used whenever the user leaves the BYOK box empty.
# os.getenv returns None when HF_TOKEN is not set.
ACCESS_TOKEN = os.getenv("HF_TOKEN")
if ACCESS_TOKEN:
    print("Access token loaded.")
else:
    # Previously this always claimed success; make a missing token visible in logs.
    print("Warning: HF_TOKEN environment variable is not set.")
# Image -> base64 helpers used to build data: URLs for multimodal chat messages.
def _image_to_jpeg_b64(image):
    """Serialize a PIL image as JPEG and return it base64-encoded (UTF-8 str)."""
    # JPEG cannot store alpha or palette data. Anything other than greyscale
    # ("L") or RGB (e.g. "RGBA", "LA", "P") would make save() raise, so
    # normalise it to RGB first.
    if image.mode not in ("RGB", "L"):
        image = image.convert("RGB")
    buffered = io.BytesIO()
    image.save(buffered, format="JPEG")
    return base64.b64encode(buffered.getvalue()).decode("utf-8")


def encode_image(image_path):
    """Encode an image as a base64 JPEG string.

    Args:
        image_path: Path to an image file, or an already-open ``PIL.Image.Image``.

    Returns:
        The base64 text of the JPEG-encoded image, or ``None`` when no image
        was given or encoding failed (the error is printed, not raised).
    """
    if not image_path:
        print("No image path provided")
        return None
    try:
        print(f"Encoding image from path: {image_path}")
        if isinstance(image_path, Image.Image):
            img_str = _image_to_jpeg_b64(image_path)
        else:
            # Use a context manager so the file handle Image.open() keeps for
            # lazy loading is released once encoding is done.
            with Image.open(image_path) as image:
                img_str = _image_to_jpeg_b64(image)
        print("Image encoded successfully")
        return img_str
    except Exception as e:
        print(f"Error encoding image: {e}")
        return None
# Registry of live MCP connections, keyed by server name. Each entry is a dict
# with the connected client ("client"), the tools it reported at connect time
# ("tools") and the URL it was opened against ("url").
mcp_connections: dict[str, dict] = dict()
def _disconnect_quietly(client):
    """Best-effort shutdown of an MCP client; failures are printed, never raised."""
    disconnect = getattr(client, "disconnect", None)
    if not callable(disconnect):
        return
    try:
        disconnect()
    except Exception as e:
        print(f"Error disconnecting MCP client: {e}")


def connect_to_mcp_server(server_url, server_name=None):
    """Connect to an MCP server, register it in ``mcp_connections``, report status.

    Args:
        server_url: URL of the MCP endpoint, passed to ``MCPClient`` as
            ``{"url": server_url}``. Surrounding whitespace is ignored.
        server_name: Optional key to register the connection under. When empty
            or blank, a unique ``Server_<n>_<random>`` name is generated.
            Reusing an existing name replaces that connection and disconnects
            the old client.

    Returns:
        ``(name, status_message)``. ``name`` is ``None`` if the URL was missing
        or the connection failed.
    """
    server_url = (server_url or "").strip()
    if not server_url:
        return None, "No server URL provided"
    client = None
    try:
        # Create an MCP client and fetch the tools it exposes.
        client = MCPClient({"url": server_url})
        tools = client.get_tools()
    except Exception as e:
        print(f"Error connecting to MCP server: {e}")
        # Release a client that connected but whose tool listing failed.
        _disconnect_quietly(client)
        return None, f"Error connecting to MCP server: {str(e)}"
    name = (server_name or "").strip() or (
        f"Server_{len(mcp_connections)}_{base64.urlsafe_b64encode(os.urandom(3)).decode()}"
    )
    # Only drop the old connection once the new one has succeeded.
    previous = mcp_connections.get(name)
    if previous is not None:
        _disconnect_quietly(previous["client"])
    mcp_connections[name] = {"client": client, "tools": tools, "url": server_url}
    return name, f"Successfully connected to {name} with {len(tools)} available tools"
def list_mcp_tools(server_name):
    """Return a human-readable tool listing for a connected MCP server.

    Each tool is rendered as one ``"- <name>: <description>"`` line. The
    strings ``"Server not connected"`` and ``"No tools available for this
    server"`` are returned for an unknown server or an empty tool list.
    """
    if server_name not in mcp_connections:
        return "Server not connected"
    lines = [
        f"- {tool.name}: {tool.description}"
        for tool in mcp_connections[server_name]["tools"]
    ]
    return "\n".join(lines) if lines else "No tools available for this server"
def call_mcp_tool(server_name, tool_name, **kwargs):
    """Invoke a tool exposed by a connected MCP server.

    Args:
        server_name: Key of the connection in ``mcp_connections``.
        tool_name: Tool name as reported by the server's ``get_tools()``.
        **kwargs: Arguments forwarded to the tool as keyword arguments.

    Returns:
        Whatever the tool returns (e.g. a string such as base64 audio, or a
        dict). If the server or tool is unknown, or the call raises, an error
        message string is returned instead; this function never raises.
    """
    if server_name not in mcp_connections:
        return f"Server '{server_name}' not connected"
    tools = mcp_connections[server_name]["tools"]
    tool = next((t for t in tools if t.name == tool_name), None)
    if tool is None:
        return f"Tool '{tool_name}' not found on server '{server_name}'"
    try:
        # smolagents' MCPClient has no call_tool() method. The Tool objects
        # returned by get_tools() are themselves callable and forward the call
        # to the MCP server, so invoke the tool directly.
        return tool(**kwargs)
    except Exception as e:
        print(f"Error calling MCP tool: {e}")
        return f"Error calling MCP tool: {str(e)}"
def _parse_tool_call_json(text):
    """Return the JSON object embedded in *text* as a dict, or ``None``.

    The whole string is parsed first. If that fails (e.g. the model wrapped
    the JSON in prose or a code fence), the span from the first ``{`` to the
    last ``}`` is tried. JSON that is not an object is rejected.
    """
    try:
        parsed = json.loads(text)
    except json.JSONDecodeError:
        print(f"Failed to parse tool call JSON directly from: {text}")
        json_start = text.find("{")
        json_end = text.rfind("}") + 1
        if json_start == -1 or json_end <= json_start:
            print(f"No JSON object found in analysis: {text}")
            return None
        json_str = text[json_start:json_end]
        try:
            parsed = json.loads(json_str)
        except json.JSONDecodeError:
            print(f"Failed to parse extracted tool call JSON: {json_str}")
            return None
    if not isinstance(parsed, dict):
        print(f"Tool call JSON is not an object: {text}")
        return None
    return parsed


def analyze_message_for_tool_call(message, active_mcp_servers, client_for_llm, model_to_use, system_message_for_llm):
    """Ask the LLM whether *message* should be handled by one of the active MCP tools.

    Args:
        message: The user's chat message.
        active_mcp_servers: Names of servers (keys of ``mcp_connections``)
            whose tools may be used. Unknown names are ignored.
        client_for_llm: Client exposing ``chat_completion`` (an
            ``InferenceClient`` at the call site).
        model_to_use: Model id passed to ``chat_completion``.
        system_message_for_llm: Currently unused. The analysis uses its own
            system prompt. Kept for call-site compatibility.

    Returns:
        ``(server_name, {"tool_name": ..., "parameters": {...}})`` when the
        model selected a tool that is actually offered on an active server.
        Otherwise ``(None, None)``, including on any LLM or parsing error.
    """
    # Skip analysis if message is empty
    if not message or not message.strip():
        return None, None
    tool_info = []
    for server_name in active_mcp_servers or []:
        if server_name not in mcp_connections:
            continue
        for tool in mcp_connections[server_name]["tools"]:
            tool_info.append({
                "server_name": server_name,
                "tool_name": tool.name,
                "description": tool.description,
            })
    if not tool_info:
        return None, None
    # Only the (server, tool) pairs shown to the model may be invoked. This
    # rejects hallucinated names and servers the user has not activated
    # (the examples below mention "kokoroTTS" regardless of what is connected).
    offered = {(info["server_name"], info["tool_name"]) for info in tool_info}
    tools_string = "\n".join(
        f"{info['server_name']}.{info['tool_name']}: {info['description']}" for info in tool_info
    )
    analysis_system_prompt = f"""You are an assistant that helps determine if a user message requires using an external tool.
Available tools:
{tools_string}
Your job is to:
1. Analyze the user's message.
2. Determine if they're asking to use one of the tools.
3. If yes, respond ONLY with a JSON object with "server_name", "tool_name", and "parameters".
4. If no, respond ONLY with the exact string "NO_TOOL_NEEDED".
Only use server and tool names from the list above; the names in the examples below are illustrative.
Example 1 (for TTS that returns base64 audio):
User: "Please turn this text into speech: Hello world"
Response: {{"server_name": "kokoroTTS", "tool_name": "text_to_audio_b64", "parameters": {{"text": "Hello world", "speed": 1.0}}}}
Example 2 (for TTS with different speed):
User: "Read 'This is faster' at speed 1.5"
Response: {{"server_name": "kokoroTTS", "tool_name": "text_to_audio_b64", "parameters": {{"text": "This is faster", "speed": 1.5}}}}
Example 3 (general, non-tool):
User: "What is the capital of France?"
Response: NO_TOOL_NEEDED"""
    try:
        response = client_for_llm.chat_completion(
            model=model_to_use,
            messages=[
                {"role": "system", "content": analysis_system_prompt},
                {"role": "user", "content": message},
            ],
            temperature=0.1,  # Low temperature for deterministic tool selection
            max_tokens=300,
        )
        # content can be None (e.g. empty completion); treat it as "no tool".
        analysis = (response.choices[0].message.content or "").strip()
        print(f"Tool analysis raw response: '{analysis}'")
        if analysis == "NO_TOOL_NEEDED":
            return None, None
        tool_call = _parse_tool_call_json(analysis)
        if tool_call is None:
            return None, None
        server_name = tool_call.get("server_name")
        tool_name = tool_call.get("tool_name")
        parameters = tool_call.get("parameters") or {}
        if (server_name, tool_name) not in offered:
            print(f"Ignoring tool call for unavailable tool: {server_name}.{tool_name}")
            return None, None
        if not isinstance(parameters, dict):
            print(f"Ignoring tool call with non-object parameters: {parameters!r}")
            return None, None
        return server_name, {"tool_name": tool_name, "parameters": parameters}
    except Exception as e:
        print(f"Error analyzing message for tool calls: {str(e)}")
        return None, None
def respond(
    message,
    image_files,
    history: list[tuple[str, str]],
    system_message,
    max_tokens,
    temperature,
    top_p,
    frequency_penalty,
    seed,
    provider,
    custom_api_key,
    custom_model,
    model_search_term,
    selected_model,
    mcp_enabled=False,
    active_mcp_servers=None,
    mcp_interaction_mode="Natural Language"
):
    """Produce the assistant reply for one chat turn, as a generator.

    Every ``yield`` is the complete reply text so far (not a delta), so a
    consumer should keep only the most recent value.

    Flow:
      1. Build an ``InferenceClient`` using ``custom_api_key`` when it is
         non-blank, otherwise the module-level ``ACCESS_TOKEN``. Resolve the
         model: a non-blank ``custom_model`` overrides ``selected_model``.
      2. If MCP is enabled and the message starts with ``/mcp``, run the named
         tool directly, yield its formatted result and stop.
      3. Otherwise, in "Natural Language" mode with active servers, ask
         ``analyze_message_for_tool_call`` whether a tool fits. If one is
         chosen, call it, yield its result and stop. If the call raises, an
         error notice is yielded and normal generation continues.
      4. Build the chat payload: the system prompt (augmented with the active
         MCP tool list when MCP is enabled), the reconstructed history, and
         the current text plus images. Then stream the model's reply.

    Args:
        message: Current user text (may be empty when only images are sent).
        image_files: Images attached to this turn, in any form
            ``encode_image`` accepts.
        history: Previous turns as ``[user, assistant]`` pairs. Despite the
            annotation, ``user`` may be a MultimodalTextbox dict
            (``{"text": ..., "files": [...]}``) or a plain string, and
            ``assistant`` may be ``None``.
        system_message: Base system prompt.
        max_tokens, temperature, top_p, frequency_penalty: Forwarded to
            ``chat_completion``.
        seed: Sampling seed; ``-1`` means "do not send a seed".
        provider: Inference provider name for ``InferenceClient``.
        custom_api_key: User-supplied token; blank means use ``ACCESS_TOKEN``.
        custom_model: Optional model id overriding ``selected_model``.
        model_search_term: Only logged; not used for inference.
        selected_model: Model id chosen in the featured-model radio.
        mcp_enabled: Whether MCP tool handling is active.
        active_mcp_servers: Server names (keys of ``mcp_connections``) whose
            tools may be used.
        mcp_interaction_mode: "Natural Language" or "Command Mode".

    Yields:
        str: The reply so far. This may be a tool result (possibly an HTML
        ``<audio>`` tag), an error message, or streamed model text.
    """
    print(f"Received message: {message}")
    print(f"Received {len(image_files) if image_files else 0} images")
    # print(f"History: {history}") # Can be very verbose
    print(f"System message: {system_message}")
    print(f"Max tokens: {max_tokens}, Temperature: {temperature}, Top-P: {top_p}")
    print(f"Frequency Penalty: {frequency_penalty}, Seed: {seed}")
    print(f"Selected provider: {provider}")
    print(f"Custom API Key provided: {bool(custom_api_key.strip())}")
    print(f"Selected model (custom_model): {custom_model}")
    print(f"Model search term: {model_search_term}")
    print(f"Selected model from radio: {selected_model}")
    print(f"MCP enabled: {mcp_enabled}")
    print(f"Active MCP servers: {active_mcp_servers}")
    print(f"MCP interaction mode: {mcp_interaction_mode}")
    # A non-blank user key (BYOK) takes precedence over the HF_TOKEN default.
    token_to_use = custom_api_key if custom_api_key.strip() != "" else ACCESS_TOKEN
    if custom_api_key.strip() != "":
        print("USING CUSTOM API KEY: BYOK token provided by user is being used for authentication")
    else:
        print("USING DEFAULT API KEY: Environment variable HF_TOKEN is being used for authentication")
    # A fresh client per call, so changes to the key/provider take effect immediately.
    client_for_llm = InferenceClient(token=token_to_use, provider=provider)
    print(f"Hugging Face Inference Client initialized with {provider} provider.")
    # -1 is the UI's "random" sentinel; a None seed is left out of the request below.
    if seed == -1:
        seed = None
    model_to_use = custom_model.strip() if custom_model.strip() != "" else selected_model
    print(f"Model selected for inference: {model_to_use}")
    if mcp_enabled and message:
        if message.startswith("/mcp"):
            # Expected form: /mcp <server_name> <tool_name> [arguments_json].
            # maxsplit=3 keeps any spaces inside the JSON arguments intact.
            # NOTE(review): consecutive spaces yield empty parts and misparse.
            command_parts = message.split(" ", 3)
            if len(command_parts) < 3:
                yield "Invalid MCP command. Format: /mcp <server_name> <tool_name> [arguments_json]"
                return
            _, server_name, tool_name = command_parts[:3]
            args_json_str = "{}" if len(command_parts) < 4 else command_parts[3]
            try:
                args_dict = json.loads(args_json_str)
                result = call_mcp_tool(server_name, tool_name, **args_dict)
                # Heuristic: a tool whose name mentions "audio" and "b64" is
                # assumed to return base64 audio, which is embedded as a player.
                # NOTE(review): the WAV MIME type is assumed; confirm per tool.
                if "audio" in tool_name.lower() and "b64" in tool_name.lower() and isinstance(result, str):
                    audio_html = f'<audio controls src="data:audio/wav;base64,{result}"></audio>'
                    yield f"Executed {tool_name} from {server_name}.\n\nResult:\n{audio_html}"
                elif isinstance(result, dict):
                    yield json.dumps(result, indent=2)
                else:
                    yield str(result)
                return # MCP command handled, exit
            except json.JSONDecodeError:
                yield f"Invalid JSON arguments: {args_json_str}"
                return
            except Exception as e:
                # e.g. arguments that are valid JSON but not an object (** fails)
                yield f"Error executing MCP command: {str(e)}"
                return
        elif mcp_interaction_mode == "Natural Language" and active_mcp_servers:
            server_name, tool_info = analyze_message_for_tool_call(
                message,
                active_mcp_servers,
                client_for_llm,
                model_to_use,
                system_message # Original system message for context, LLM uses its own for analysis
            )
            if server_name and tool_info and tool_info.get("tool_name"):
                try:
                    print(f"Calling tool via natural language: {server_name}.{tool_info['tool_name']} with parameters: {tool_info.get('parameters', {})}")
                    result = call_mcp_tool(server_name, tool_info['tool_name'], **tool_info.get('parameters', {}))
                    tool_display_name = tool_info['tool_name']
                    # Same audio heuristic as command mode, plus a minimum length
                    # so short error strings are not rendered as audio.
                    if "audio" in tool_display_name.lower() and "b64" in tool_display_name.lower() and isinstance(result, str) and len(result) > 100: # Heuristic for base64 audio
                        audio_html = f'<audio controls src="data:audio/wav;base64,{result}"></audio>'
                        yield f"I used the {tool_display_name} tool from {server_name} with your request.\n\nResult:\n{audio_html}"
                    elif isinstance(result, dict):
                        result_str = json.dumps(result, indent=2)
                        yield f"I used the {tool_display_name} tool from {server_name} with your request.\n\nResult:\n{result_str}"
                    else:
                        result_str = str(result)
                        yield f"I used the {tool_display_name} tool from {server_name} with your request.\n\nResult:\n{result_str}"
                    return # MCP tool call handled via natural language
                except Exception as e:
                    print(f"Error executing MCP tool via natural language: {str(e)}")
                    yield f"I tried to use a tool but encountered an error: {str(e)}. I will try to respond without it."
                    # Fall through to normal LLM response if tool call fails
                    # (the streamed reply below restarts from "", so its yields
                    # supersede this notice).
    # Build the multimodal content for the current turn: text part first, then
    # each image as a base64 JPEG data URL.
    user_content = []
    if message and message.strip():
        user_content.append({"type": "text", "text": message})
    if image_files and len(image_files) > 0:
        for img_path in image_files:
            if img_path is not None:
                try:
                    encoded_image = encode_image(img_path)
                    if encoded_image:
                        user_content.append({
                            "type": "image_url",
                            "image_url": {"url": f"data:image/jpeg;base64,{encoded_image}"}
                        })
                except Exception as e:
                    print(f"Error encoding image for user content: {e}")
    if not user_content: # If message was empty and no images, or only MCP command handled
        if not message.startswith("/mcp"): # Avoid yielding empty if it was an MCP command
            yield "" # Or handle appropriately, maybe return if no content
        return
    augmented_system_message = system_message
    if mcp_enabled and active_mcp_servers:
        tool_desc_list = []
        for server_name_active in active_mcp_servers:
            if server_name_active in mcp_connections:
                # Get tools for this specific server
                # Assuming list_mcp_tools returns a string like "- tool1: desc1\n- tool2: desc2"
                server_tools_str = list_mcp_tools(server_name_active)
                # Skip list_mcp_tools' sentinel strings for "unknown" / "no tools".
                if server_tools_str != "Server not connected" and server_tools_str != "No tools available for this server":
                    for line in server_tools_str.split('\n'):
                        # Continuation lines of multi-line descriptions lack the
                        # "- " prefix and are therefore dropped here.
                        if line.startswith("- "):
                            tool_desc_list.append(f"{server_name_active}.{line[2:]}") # e.g., kokoroTTS.text_to_audio_b64: Convert text...
        if tool_desc_list:
            mcp_tools_description_for_llm = "\n".join(tool_desc_list)
            # This informs the main LLM about available tools for general conversation,
            # distinct from the specialized analyzer LLM.
            # The main LLM doesn't call tools directly but can use this info to guide the user.
            if mcp_interaction_mode == "Command Mode":
                augmented_system_message += f"\n\nYou have access to the following MCP tools which the user can invoke:\n{mcp_tools_description_for_llm}\n\nTo use these tools, the user can type a command in the format: /mcp <server_name> <tool_name> <arguments_json>"
            else: # Natural Language
                augmented_system_message += f"\n\nYou have access to the following MCP tools. The system will try to use them automatically if the user's request matches their capability:\n{mcp_tools_description_for_llm}\n\nIf the user asks to do something a tool can do, the system will attempt to use it. For example, if a 'text_to_audio_b64' tool is available, and the user says 'read this text aloud', the system will try to use that tool."
    messages_for_llm = [{"role": "system", "content": augmented_system_message}]
    print("Initial messages array constructed.")
    for hist_user, hist_assistant in history:
        # hist_user can be complex if it included images from MultimodalTextbox
        # We need to reconstruct it properly for the LLM
        # (history images are re-encoded from their file paths on every call).
        current_hist_user_content = []
        if isinstance(hist_user, dict) and 'text' in hist_user and 'files' in hist_user: # From MultimodalTextbox
            if hist_user['text'] and hist_user['text'].strip():
                current_hist_user_content.append({"type": "text", "text": hist_user['text']})
            if hist_user['files']:
                for img_file_path in hist_user['files']:
                    encoded_img = encode_image(img_file_path)
                    if encoded_img:
                        current_hist_user_content.append({
                            "type": "image_url",
                            "image_url": {"url": f"data:image/jpeg;base64,{encoded_img}"}
                        })
        elif isinstance(hist_user, str): # Simple text history
            current_hist_user_content.append({"type": "text", "text": hist_user})
        if current_hist_user_content:
            messages_for_llm.append({"role": "user", "content": current_hist_user_content})
        if hist_assistant: # Assistant message is always text
            # Check if assistant message was an HTML audio tag, if so, send a placeholder to LLM
            # (avoids sending the whole base64 payload back to the model).
            if "<audio controls src=" in hist_assistant:
                messages_for_llm.append({"role": "assistant", "content": "[Audio was played in response to the previous message]"})
            else:
                messages_for_llm.append({"role": "assistant", "content": hist_assistant})
    messages_for_llm.append({"role": "user", "content": user_content})
    print(f"Latest user message appended (content type: {type(user_content)})")
    # print(f"Messages for LLM: {json.dumps(messages_for_llm, indent=2)}") # Very verbose
    response_text = ""
    print(f"Sending request to {provider} provider for general response.")
    parameters = {
        "max_tokens": max_tokens,
        "temperature": temperature,
        "top_p": top_p,
        "frequency_penalty": frequency_penalty,
    }
    # Only send a seed when the user fixed one (see the -1 sentinel above).
    if seed is not None:
        parameters["seed"] = seed
    try:
        stream = client_for_llm.chat_completion(
            model=model_to_use,
            messages=messages_for_llm,
            stream=True,
            **parameters
        )
        print("Streaming LLM response: ", end="", flush=True)
        for chunk in stream:
            # Defensive attribute checks: chunks without choices or delta
            # content are skipped.
            if hasattr(chunk, 'choices') and len(chunk.choices) > 0:
                if hasattr(chunk.choices[0], 'delta') and hasattr(chunk.choices[0].delta, 'content'):
                    token_text = chunk.choices[0].delta.content
                    if token_text:
                        print(token_text, end="", flush=True)
                        response_text += token_text
                        # Yield the cumulative text, not just the new token.
                        yield response_text
        print() # Newline after streaming
    except Exception as e:
        print(f"Error during LLM inference: {e}")
        # Keep any partial output and append the error to it.
        response_text += f"\nError during LLM response generation: {str(e)}"
        yield response_text
    print("Completed LLM response generation.")
# GRADIO UI
with gr.Blocks(theme="Nymbo/Nymbo_Theme") as demo:
    chatbot = gr.Chatbot(
        height=600,
        show_copy_button=True,
        placeholder="Select a model and begin chatting. Now supports multiple inference providers, multimodal inputs, and MCP tools",
        layout="panel",
        show_label=False,
        render=False  # Delay rendering
    )
    print("Chatbot interface created.")
    with gr.Row():
        msg = gr.MultimodalTextbox(
            placeholder="Type a message or upload images...",
            show_label=False,
            container=True,  # Ensure it's a container for proper layout
            scale=12,
            file_types=["image"],
            file_count="multiple",
            sources=["upload"],
            render=False  # Delay rendering
        )
    # Render chatbot and message box after defining them
    chatbot.render()
    msg.render()

    with gr.Accordion("Settings", open=False):
        system_message_box = gr.Textbox(
            value="You are a helpful AI assistant that can understand images and text.",
            placeholder="You are a helpful assistant.",
            label="System Prompt"
        )
        with gr.Row():
            with gr.Column():
                max_tokens_slider = gr.Slider(minimum=1, maximum=4096, value=512, step=1, label="Max tokens")
                temperature_slider = gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature")
                top_p_slider = gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-P")
            with gr.Column():
                frequency_penalty_slider = gr.Slider(minimum=-2.0, maximum=2.0, value=0.0, step=0.1, label="Frequency Penalty")
                seed_slider = gr.Slider(minimum=-1, maximum=65535, value=-1, step=1, label="Seed (-1 for random)")
        providers_list = ["hf-inference", "cerebras", "together", "sambanova", "novita", "cohere", "fireworks-ai", "hyperbolic", "nebius"]
        provider_radio = gr.Radio(choices=providers_list, value="hf-inference", label="Inference Provider")
        byok_textbox = gr.Textbox(value="", label="BYOK (Bring Your Own Key)", info="Enter a custom Hugging Face API key here. If empty, only 'hf-inference' provider can be used with the default token.", placeholder="Enter your Hugging Face API token", type="password")
        custom_model_box = gr.Textbox(value="", label="Custom Model", info="(Optional) Provide a Hugging Face model path. Overrides selected featured model.", placeholder="meta-llama/Llama-3.1-70B-Instruct")
        model_search_box = gr.Textbox(label="Filter Models", placeholder="Search for a featured model...", lines=1)
        models_list = [
            "meta-llama/Llama-3.2-11B-Vision-Instruct", "meta-llama/Llama-3.1-405B-Instruct", "meta-llama/Llama-3.1-70B-Instruct", "meta-llama/Llama-3.1-8B-Instruct",
            "meta-llama/Llama-3-70B-Instruct", "meta-llama/Llama-3-8B-Instruct",
            "NousResearch/Hermes-3-Llama-3.1-8B", "NousResearch/Nous-Hermes-2-Mixtral-8x7B-DPO",
            "mistralai/Mistral-Nemo-Instruct-2407", "mistralai/Mixtral-8x7B-Instruct-v0.1", "mistralai/Mistral-7B-Instruct-v0.3", "mistralai/Mistral-7B-Instruct-v0.2",
            "Qwen/Qwen2.5-72B-Instruct", "Qwen/Qwen2-72B-Instruct", "Qwen/Qwen2-57B-A14B-Instruct", "Qwen/Qwen1.5-110B-Chat",
            "microsoft/Phi-3-medium-128k-instruct", "microsoft/Phi-3-mini-128k-instruct", "microsoft/Phi-3-small-128k-instruct",
            "google/gemma-2-27b-it", "google/gemma-2-9b-it",
            "CohereForAI/c4ai-command-r-plus",
            "deepseek-ai/DeepSeek-V2-Chat",
            "Snowflake/snowflake-arctic-instruct"
        ]  # Keeping your original list, just formatted for readability
        featured_model_radio = gr.Radio(label="Select a featured model", choices=models_list, value="meta-llama/Llama-3.2-11B-Vision-Instruct", interactive=True)
        gr.Markdown("[View all Text-to-Text models](https://huggingface.co/models?pipeline_tag=text-generation&sort=trending) | [View all multimodal models](https://huggingface.co/models?pipeline_tag=image-text-to-text&sort=trending)")

    with gr.Accordion("MCP Settings", open=False):
        mcp_enabled_checkbox = gr.Checkbox(label="Enable MCP Support", value=False, info="Enable Model Context Protocol support for external tools")
        with gr.Row():
            mcp_server_url = gr.Textbox(label="MCP Server URL", placeholder="https://your-mcp-server.hf.space/gradio_api/mcp/sse")
            mcp_server_name = gr.Textbox(label="Server Name (Optional)", placeholder="e.g., kokoroTTS")
        mcp_connect_button = gr.Button("Connect to MCP Server")
        mcp_status = gr.Textbox(label="MCP Connection Status", placeholder="No MCP servers connected", interactive=False)
        active_mcp_servers = gr.Dropdown(label="Active MCP Servers for Chat", choices=[], multiselect=True, info="Select which connected MCP servers to use")
        mcp_mode = gr.Radio(label="MCP Interaction Mode", choices=["Natural Language", "Command Mode"], value="Natural Language", info="How to trigger MCP tools")
        # Blank lines keep the two modes as separate Markdown paragraphs.
        gr.Markdown("""
### MCP Interaction Modes

**Natural Language**: Describe what you want. E.g., "Convert 'Hello' to speech".

**Command Mode**: Use `/mcp <server_name> <tool_name> {"param": "value"}`. E.g., `/mcp kokoroTTS text_to_audio_b64 {"text": "Hello world"}`.
""")

    # Per-session history for the LLM: a list of [user_input_dict, bot_reply] pairs,
    # where bot_reply is None until the assistant has answered that turn.
    chat_history_state = gr.State([])

    def _build_visual_history(history_state):
        """Convert the LLM history state into tuples-format Chatbot rows.

        Each user turn becomes one row for its text (if non-blank) and one
        ``(file_path,)`` row per attached file. The assistant reply is placed
        on the turn's last row only, so it is not repeated for every image.
        """
        rows = []
        for user_turn, bot_turn in history_state:
            if isinstance(user_turn, dict):
                text = user_turn.get("text") or ""
                files = user_turn.get("files") or []
            else:
                text, files = str(user_turn or ""), []
            turn_rows = [[text, None]] if text.strip() else []
            turn_rows.extend([(f_path,), None] for f_path in files)
            if not turn_rows:
                turn_rows = [["", None]]
            turn_rows[-1][1] = bot_turn
            rows.extend(turn_rows)
        return rows

    def filter_models_choices(search_term, current_model=None):
        """Narrow the featured-model radio to models containing *search_term*.

        The current selection is kept when it is still among the matches;
        otherwise the first match is selected (a Radio value must be one of
        its choices). An empty term, or one that matches nothing, restores
        the full list without touching the selection.
        """
        print(f"Filtering models with search term: {search_term}")
        if not search_term:
            return gr.update(choices=models_list)
        filtered = [m for m in models_list if search_term.lower() in m.lower()]
        print(f"Filtered models: {filtered}")
        if not filtered:
            return gr.update(choices=models_list)
        return gr.update(choices=filtered, value=current_model if current_model in filtered else filtered[0])

    def update_custom_model_from_radio(selected_featured_model):
        """Return the featured-model selection unchanged.

        Not wired to any event: the featured-model radio feeds respond()'s
        ``selected_model`` directly, and a non-empty Custom Model box
        overrides it. Wiring ``featured_model_radio.change`` to this function
        with ``custom_model_box`` as output would copy the selection into
        that box instead.
        """
        print(f"Featured model selected: {selected_featured_model}")
        return selected_featured_model

    def handle_connect_mcp_server(url, name_suggestion, current_selection=None):
        """Connect to an MCP server and refresh the active-server dropdown.

        ``current_selection`` is the dropdown's live value. It is preserved
        (minus servers no longer connected) and the newly connected server is
        appended to it.
        """
        actual_name, status_msg = connect_to_mcp_server(url, name_suggestion)
        all_server_names = list(mcp_connections.keys())
        new_selection = [s for s in (current_selection or []) if s in all_server_names]
        if actual_name and actual_name not in new_selection:  # Auto-select newly connected server
            new_selection.append(actual_name)
        return status_msg, gr.update(choices=all_server_names, value=new_selection)

    def handle_user_message(user_input_dict, current_chat_history_state):
        """Record a submitted MultimodalTextbox value as a new, unanswered turn.

        The raw ``{"text": ..., "files": [...]}`` dict is stored so respond()
        can rebuild text and images later. Submissions with no text and no
        files are ignored. Returns ``(chatbot_rows, history_state)``.
        """
        user_input_dict = user_input_dict or {}
        text_content = (user_input_dict.get("text") or "").strip()
        files = user_input_dict.get("files") or []  # List of file paths
        if text_content or files:
            current_chat_history_state.append([user_input_dict, None])
        # Show the whole conversation, not just the message being submitted.
        return _build_visual_history(current_chat_history_state), current_chat_history_state

    def handle_bot_response(
        current_chat_history_state,  # This is the state with the latest user message
        sys_msg, max_tok, temp, top_p_val, freq_pen, seed_val, prov, api_key_val, cust_model,
        search, selected_feat_model, mcp_on, active_servs, mcp_interact_mode
    ):
        """Stream respond()'s reply for the last unanswered turn into the UI.

        Yields ``(chatbot_rows, history_state)`` after every partial reply so
        the chatbot updates while tokens arrive.
        """
        if not current_chat_history_state or current_chat_history_state[-1][1] is not None:
            # Nothing to answer; leave the chatbot as it is.
            yield gr.update(), current_chat_history_state
            return
        user_message_dict = current_chat_history_state[-1][0]
        # History for LLM should exclude the current un-responded user message
        history_for_llm = current_chat_history_state[:-1]
        for partial_reply in respond(
            message=user_message_dict.get("text") or "",
            image_files=user_message_dict.get("files") or [],
            history=history_for_llm,  # Pass history BEFORE current turn
            system_message=sys_msg,
            max_tokens=max_tok,
            temperature=temp,
            top_p=top_p_val,
            frequency_penalty=freq_pen,
            seed=seed_val,
            provider=prov,
            custom_api_key=api_key_val,
            custom_model=cust_model,
            model_search_term=search,
            selected_model=selected_feat_model,  # This is the value from the radio
            mcp_enabled=mcp_on,
            active_mcp_servers=active_servs,
            mcp_interaction_mode=mcp_interact_mode
        ):
            # respond() yields the cumulative reply; overwrite and push each update.
            current_chat_history_state[-1][1] = partial_reply
            yield _build_visual_history(current_chat_history_state), current_chat_history_state
        if current_chat_history_state[-1][1] is None:
            # respond() yielded nothing; close the turn with an empty reply.
            current_chat_history_state[-1][1] = ""
            yield _build_visual_history(current_chat_history_state), current_chat_history_state

    # Event handlers
    msg.submit(
        handle_user_message,
        [msg, chat_history_state],
        [chatbot, chat_history_state],  # Update visual chatbot and state
        queue=True  # Use queue for streaming
    ).then(
        handle_bot_response,
        [chat_history_state, system_message_box, max_tokens_slider, temperature_slider, top_p_slider,
         frequency_penalty_slider, seed_slider, provider_radio, byok_textbox, custom_model_box,
         model_search_box, featured_model_radio, mcp_enabled_checkbox, active_mcp_servers, mcp_mode],
        [chatbot, chat_history_state]  # Stream the bot response into the chatbot and state
    ).then(
        lambda: gr.update(value={"text": "", "files": []}),  # Clear MultimodalTextbox
        None,
        [msg],
        queue=False  # No queue for simple UI update
    )
    mcp_connect_button.click(
        handle_connect_mcp_server,
        [mcp_server_url, mcp_server_name, active_mcp_servers],  # live selection, not the initial value
        [mcp_status, active_mcp_servers]
    )
    model_search_box.change(fn=filter_models_choices, inputs=[model_search_box, featured_model_radio], outputs=featured_model_radio)

    def validate_provider_choice(api_key_val, current_provider_val):
        """Force the 'hf-inference' provider when no BYOK key is entered."""
        if not api_key_val.strip() and current_provider_val != "hf-inference":
            gr.Info("No custom API key provided. Only 'hf-inference' provider can be used. Switching to 'hf-inference'.")
            return gr.update(value="hf-inference")
        return gr.update()  # No change needed if valid or key provided

    byok_textbox.change(fn=validate_provider_choice, inputs=[byok_textbox, provider_radio], outputs=provider_radio)
    provider_radio.change(fn=validate_provider_choice, inputs=[byok_textbox, provider_radio], outputs=provider_radio)
print("Gradio interface initialized.")
if __name__ == "__main__":
    print("Launching the demo application.")
    # GRADIO_SHARE=true (any letter case) requests a public share link.
    share_link = os.environ.get("GRADIO_SHARE", "").lower() == "true"
    queued_app = demo.queue()
    queued_app.launch(show_api=False, mcp_server=False, share=share_link)