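"""Serverless TextGen Hub.

Gradio chat UI over huggingface_hub's InferenceClient. Optionally connects to
MCP servers via smolagents; when tools are loaded, requests are routed through
a CodeAgent that can call those tools, otherwise the LLM is streamed directly.
"""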
import gradio as gr
from huggingface_hub import InferenceClient
import os
import json
import base64
from PIL import Image
import io
import atexit
from smolagents import ToolCollection, CodeAgent
from smolagents.mcp_client import MCPClient as SmolMCPClient
ACCESS_TOKEN = os.getenv("HF_TOKEN")
print("Access token loaded.")
mcp_tools_collection = ToolCollection(tools=[])
mcp_client_instances = []
DEFAULT_MCP_SERVERS = [
{"name": "KokoroTTS (Example)", "type": "sse", "url": "https://fdaudens-kokoro-mcp.hf.space/gradio_api/mcp/sse"}
]
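# Each server config is a plain dict with "name" (used for logging), "type", and "url".
# Only type "sse" is handled by load_mcp_tools() below; other types are skipped.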
def load_mcp_tools(server_configs_list):
global mcp_tools_collection, mcp_client_instances
    # SmolMCPClient does not expose an explicit close(), so previous instances are
    # simply dereferenced and left to garbage collection / script termination.
    # (With one ToolCollection per server, tc.close() would be the cleanup hook.)
print(f"Clearing {len(mcp_client_instances)} previous MCP client instance references.")
mcp_client_instances = [] # Clear references; old objects will be GC'd if not referenced elsewhere
all_discovered_tools = []
if not server_configs_list:
print("No MCP server configurations provided. Clearing MCP tools.")
mcp_tools_collection = ToolCollection(tools=all_discovered_tools)
return
print(f"Loading MCP tools from {len(server_configs_list)} server configurations...")
for config in server_configs_list:
server_name = config.get('name', config.get('url', 'Unknown Server'))
try:
if config.get("type") == "sse":
sse_url = config["url"]
print(f"Attempting to connect to MCP SSE server: {server_name} at {sse_url}")
smol_mcp_client = SmolMCPClient(server_parameters={"url": sse_url})
mcp_client_instances.append(smol_mcp_client)
discovered_tools_from_server = smol_mcp_client.get_tools()
if discovered_tools_from_server:
all_discovered_tools.extend(list(discovered_tools_from_server))
print(f"Discovered {len(discovered_tools_from_server)} tools from {server_name}.")
else:
print(f"No tools discovered from {server_name}.")
else:
print(f"Unsupported MCP server type '{config.get('type')}' for {server_name}. Skipping.")
except Exception as e:
print(f"Error loading MCP tools from {server_name}: {e}")
mcp_tools_collection = ToolCollection(tools=all_discovered_tools)
if mcp_tools_collection and len(mcp_tools_collection.tools) > 0:
print(f"Successfully loaded a total of {len(mcp_tools_collection.tools)} MCP tools:")
for tool in mcp_tools_collection.tools:
print(f" - {tool.name}: {tool.description[:100]}...")
else:
print("No MCP tools were loaded, or an error occurred.")
def cleanup_mcp_client_instances_on_exit():
global mcp_client_instances
print("Attempting to clear MCP client instance references on application exit...")
# No explicit close called here as per previous fix
mcp_client_instances = []
print("MCP client instance reference cleanup finished.")
atexit.register(cleanup_mcp_client_instances_on_exit)
def encode_image(image_path):
if not image_path: return None
try:
image = Image.open(image_path) if not isinstance(image_path, Image.Image) else image_path
if image.mode == 'RGBA': image = image.convert('RGB')
buffered = io.BytesIO()
image.save(buffered, format="JPEG")
return base64.b64encode(buffered.getvalue()).decode("utf-8")
except Exception as e:
print(f"Error encoding image {image_path}: {e}")
return None
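# encode_image() returns a raw base64 string; respond() below wraps it into a data URL
# of the form "data:image/jpeg;base64,<...>". Images are re-encoded as JPEG, so any
# alpha channel is dropped by the RGBA -> RGB conversion above.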
def respond(
message_input_text,
image_files_list,
history: list[tuple[str, str]], # history will be list of (user_str_display, assistant_str_display)
system_message,
max_tokens,
temperature,
top_p,
frequency_penalty,
seed,
provider,
custom_api_key,
custom_model,
model_search_term,
selected_model
):
global mcp_tools_collection
print(f"Respond: Text='{message_input_text}', Images={len(image_files_list) if image_files_list else 0}")
token_to_use = custom_api_key if custom_api_key.strip() else ACCESS_TOKEN
hf_inference_client = InferenceClient(token=token_to_use, provider=provider)
if seed == -1: seed = None
current_user_content_parts = []
if message_input_text and message_input_text.strip():
current_user_content_parts.append({"type": "text", "text": message_input_text.strip()})
if image_files_list:
for img_path in image_files_list:
encoded_img = encode_image(img_path)
if encoded_img:
current_user_content_parts.append({
"type": "image_url",
"image_url": {"url": f"data:image/jpeg;base64,{encoded_img}"}
})
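    # For a text-plus-image turn, current_user_content_parts now looks roughly like:
    #   [{"type": "text", "text": "..."},
    #    {"type": "image_url", "image_url": {"url": "data:image/jpeg;base64,..."}}]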
if not current_user_content_parts:
        # Nothing to send; handle_submit() filters out empty submissions, so this is
        # only a safeguard. Yield an empty string to keep the output type consistent.
        yield ""
        return
llm_messages = [{"role": "system", "content": system_message}]
for hist_user_str, hist_assistant in history: # hist_user_str is display string
        # Only the text of past turns is sent back to the LLM; images are handled for
        # the current turn only. Re-sending historical images would need extra logic here.
if hist_user_str:
llm_messages.append({"role": "user", "content": hist_user_str})
if hist_assistant:
llm_messages.append({"role": "assistant", "content": hist_assistant})
llm_messages.append({"role": "user", "content": current_user_content_parts if len(current_user_content_parts) > 1 else (current_user_content_parts[0] if current_user_content_parts else "")})
# FIX for Issue 1: 'NoneType' object has no attribute 'strip'
model_to_use = (custom_model.strip() if custom_model else "") or selected_model
print(f"Model selected for inference: {model_to_use}")
active_mcp_tools = list(mcp_tools_collection.tools) if mcp_tools_collection else []
if active_mcp_tools:
print(f"MCP tools are active ({len(active_mcp_tools)} tools). Using CodeAgent.")
class HFClientWrapperForAgent:
def __init__(self, hf_client, model_id, outer_scope_params):
self.client = hf_client
self.model_id = model_id
self.params = outer_scope_params
def generate(self, agent_llm_messages, tools=None, tool_choice=None, **kwargs):
api_params = {
"model": self.model_id, "messages": agent_llm_messages, "stream": False,
"max_tokens": self.params['max_tokens'], "temperature": self.params['temperature'],
"top_p": self.params['top_p'], "frequency_penalty": self.params['frequency_penalty'],
}
if self.params['seed'] is not None: api_params["seed"] = self.params['seed']
if tools: api_params["tools"] = tools
if tool_choice: api_params["tool_choice"] = tool_choice
print(f"Agent's HFClientWrapper calling LLM: {self.model_id} with params: {api_params}")
completion = self.client.chat_completion(**api_params)
# FIX for Issue 2 (Potential): Ensure content is not None for text responses
                if completion.choices and completion.choices[0].message and \
                        completion.choices[0].message.content is None and \
                        not completion.choices[0].message.tool_calls:
print("Warning (HFClientWrapperForAgent): Model returned None content. Setting to empty string.")
completion.choices[0].message.content = ""
return completion
outer_scope_llm_params = {
"max_tokens": max_tokens, "temperature": temperature, "top_p": top_p,
"frequency_penalty": frequency_penalty, "seed": seed
}
agent_model_adapter = HFClientWrapperForAgent(hf_inference_client, model_to_use, outer_scope_llm_params)
agent = CodeAgent(tools=active_mcp_tools, model=agent_model_adapter, messages_constructor=lambda: llm_messages[:-1].copy()) # Prime with history
current_query_for_agent = message_input_text.strip() if message_input_text else "User provided image(s)."
if not current_query_for_agent and image_files_list:
current_query_for_agent = "Process the provided image(s) or follow related instructions."
elif not current_query_for_agent and not image_files_list:
current_query_for_agent = "..." # Should be caught by earlier check
print(f"Query for CodeAgent.run: '{current_query_for_agent}' with {len(llm_messages)-1} history messages for priming.")
try:
agent_final_text_response = agent.run(current_query_for_agent)
yield agent_final_text_response
print("Completed response generation via CodeAgent.")
except Exception as e:
print(f"Error during CodeAgent execution: {e}") # This will now print the actual underlying error
yield f"Error using tools: {str(e)}" # The str(e) might be the user-facing error
return
else:
print("No MCP tools active. Proceeding with direct LLM call (streaming).")
response_stream_content = ""
try:
stream = hf_inference_client.chat_completion(
model=model_to_use, messages=llm_messages, stream=True,
max_tokens=max_tokens, temperature=temperature, top_p=top_p,
frequency_penalty=frequency_penalty, seed=seed
)
for chunk in stream:
if hasattr(chunk, 'choices') and len(chunk.choices) > 0:
delta = chunk.choices[0].delta
if hasattr(delta, 'content') and delta.content:
token_text = delta.content
response_stream_content += token_text
yield response_stream_content
print("\nCompleted streaming response generation.")
except Exception as e:
print(f"Error during direct LLM inference: {e}")
yield response_stream_content + f"\nError: {str(e)}"
def validate_provider(api_key, provider):
if not api_key.strip() and provider != "hf-inference":
return gr.update(value="hf-inference")
return gr.update(value=provider)
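# validate_provider() is wired below to both the BYOK textbox and the provider radio:
# if the key field is empty, any non-"hf-inference" selection is reset to "hf-inference".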
with gr.Blocks(theme="Nymbo/Nymbo_Theme") as demo:
# UserWarning for type='tuples' is known. Consider changing to type='messages' later for robustness.
chatbot = gr.Chatbot(
label="Serverless TextGen Hub", height=600, show_copy_button=True,
placeholder="Select a model, (optionally) load MCP Tools, and begin chatting.",
layout="panel", bubble_full_width=False
)
msg_input_box = gr.MultimodalTextbox(
placeholder="Type a message or upload images...", show_label=False,
container=False, scale=12, file_types=["image"],
file_count="multiple", sources=["upload"]
)
with gr.Accordion("Settings", open=False):
system_message_box = gr.Textbox(value="You are a helpful AI assistant.", label="System Prompt")
with gr.Row():
max_tokens_slider = gr.Slider(1, 4096, value=512, step=1, label="Max tokens")
temperature_slider = gr.Slider(0.1, 4.0, value=0.7, step=0.1, label="Temperature")
top_p_slider = gr.Slider(0.1, 1.0, value=0.95, step=0.05, label="Top-P")
with gr.Row():
frequency_penalty_slider = gr.Slider(-2.0, 2.0, value=0.0, step=0.1, label="Frequency Penalty")
seed_slider = gr.Slider(-1, 65535, value=-1, step=1, label="Seed (-1 for random)")
providers_list = ["hf-inference", "cerebras", "together", "sambanova", "novita", "cohere", "fireworks-ai", "hyperbolic", "nebius"]
provider_radio = gr.Radio(choices=providers_list, value="hf-inference", label="Inference Provider")
byok_textbox = gr.Textbox(label="BYOK (Hugging Face API Key)", type="password", placeholder="Enter token if not using 'hf-inference'")
custom_model_box = gr.Textbox(label="Custom Model ID", placeholder="org/model-name (overrides selection below)")
model_search_box = gr.Textbox(label="Filter Featured Models", placeholder="Search...")
models_list = [
"meta-llama/Llama-3.2-11B-Vision-Instruct", "meta-llama/Llama-3.3-70B-Instruct",
"meta-llama/Llama-3.1-70B-Instruct", "meta-llama/Llama-3.0-70B-Instruct",
"meta-llama/Llama-3.2-3B-Instruct", "meta-llama/Llama-3.2-1B-Instruct",
"meta-llama/Llama-3.1-8B-Instruct", "NousResearch/Hermes-3-Llama-3.1-8B",
"NousResearch/Nous-Hermes-2-Mixtral-8x7B-DPO", "mistralai/Mistral-Nemo-Instruct-2407",
"mistralai/Mixtral-8x7B-Instruct-v0.1", "mistralai/Mistral-7B-Instruct-v0.3",
"mistralai/Mistral-7B-Instruct-v0.2", "Qwen/Qwen3-235B-A22B", "Qwen/Qwen3-32B",
"Qwen/Qwen2.5-72B-Instruct", "Qwen/Qwen2.5-3B-Instruct", "Qwen/Qwen2.5-0.5B-Instruct",
"Qwen/QwQ-32B", "Qwen/Qwen2.5-Coder-32B-Instruct", "microsoft/Phi-3.5-mini-instruct",
"microsoft/Phi-3-mini-128k-instruct", "microsoft/Phi-3-mini-4k-instruct",
]
featured_model_radio = gr.Radio(label="Select a Featured Model", choices=models_list, value="meta-llama/Llama-3.2-11B-Vision-Instruct", interactive=True)
gr.Markdown("[All Text models](https://huggingface.co/models?pipeline_tag=text-generation) | [All Multimodal models](https://huggingface.co/models?pipeline_tag=image-text-to-text)")
with gr.Accordion("MCP Client Settings (Connect to External Tools)", open=False):
gr.Markdown("Configure connections to MCP Servers to allow the LLM to use external tools. The LLM will decide when to use these tools based on your prompts.")
mcp_server_config_input = gr.Textbox(
label="MCP Server Configurations (JSON Array)",
info='Example: [{"name": "MyToolServer", "type": "sse", "url": "http://server_url/gradio_api/mcp/sse"}]',
lines=3, placeholder='Enter a JSON list of server configurations here.',
value=json.dumps(DEFAULT_MCP_SERVERS, indent=2)
)
mcp_load_status_display = gr.Textbox(label="MCP Load Status", interactive=False)
load_mcp_tools_btn = gr.Button("Load/Reload MCP Tools")
def handle_load_mcp_tools_click(config_str_from_ui):
if not config_str_from_ui:
load_mcp_tools([])
return "MCP tool loading attempted with empty config. Tools cleared."
try:
parsed_configs = json.loads(config_str_from_ui)
if not isinstance(parsed_configs, list): return "Error: MCP configuration must be a valid JSON list."
load_mcp_tools(parsed_configs)
if mcp_tools_collection and len(mcp_tools_collection.tools) > 0:
loaded_tool_names = [t.name for t in mcp_tools_collection.tools]
return f"Successfully loaded {len(loaded_tool_names)} MCP tools: {', '.join(loaded_tool_names)}"
else: return "No MCP tools loaded, or an error occurred. Check console for details."
except json.JSONDecodeError: return "Error: Invalid JSON format in MCP server configurations."
except Exception as e:
print(f"Unhandled error in handle_load_mcp_tools_click: {e}")
return f"Error loading MCP tools: {str(e)}. Check console."
load_mcp_tools_btn.click(handle_load_mcp_tools_click, inputs=[mcp_server_config_input], outputs=mcp_load_status_display)
def filter_models(search_term):
return gr.update(choices=[m for m in models_list if search_term.lower() in m.lower()])
def set_custom_model_from_radio(selected):
return selected
    # Settings are passed in as event inputs so the handler sees the current UI state
    # rather than the components' construction-time defaults.
    def handle_submit(msg_content_dict, current_chat_history,
                      system_msg, max_tokens, temperature, top_p,
                      frequency_penalty, seed, provider, api_key,
                      custom_model, model_search, selected_model):
text = msg_content_dict.get("text", "").strip()
files = msg_content_dict.get("files", []) # list of file paths
if not text and not files: # Skip if both are empty
print("Skipping empty submission from multimodal textbox.")
# Yield current history to prevent Gradio from complaining about no output
yield current_chat_history, {"text": "", "files": []} # Clear input
return
# FIX for Issue 4: Pydantic FileMessage error by ensuring user part of history is a string
user_display_parts = []
if text:
user_display_parts.append(text)
if files:
for f_path in files:
base_name = os.path.basename(f_path) if f_path else "file"
f_path_str = f_path if f_path else ""
user_display_parts.append(f"\n")
user_display_message_for_chatbot = " ".join(user_display_parts).strip()
current_chat_history.append([user_display_message_for_chatbot, None])
# Prepare history for respond function (ensure user part is string)
history_for_respond = []
for user_h, assistant_h in current_chat_history[:-1]: # History before current turn
history_for_respond.append((str(user_h) if user_h is not None else "", assistant_h))
assistant_response_accumulator = ""
for streamed_chunk in respond(
text, files,
history_for_respond,
            system_msg, max_tokens, temperature,
            top_p, frequency_penalty, seed,
            provider, api_key, custom_model,
            model_search, selected_model
):
assistant_response_accumulator = streamed_chunk
current_chat_history[-1][1] = assistant_response_accumulator
yield current_chat_history, {"text": "", "files": []}
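    # handle_submit is a generator: each yield pushes the partially streamed assistant
    # reply into the chatbot and returns a cleared value for the multimodal input box.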
    msg_input_box.submit(
        handle_submit,
        [msg_input_box, chatbot,
         system_message_box, max_tokens_slider, temperature_slider, top_p_slider,
         frequency_penalty_slider, seed_slider, provider_radio, byok_textbox,
         custom_model_box, model_search_box, featured_model_radio],
        [chatbot, msg_input_box]
    )
model_search_box.change(filter_models, model_search_box, featured_model_radio)
featured_model_radio.change(set_custom_model_from_radio, featured_model_radio, custom_model_box)
byok_textbox.change(validate_provider, [byok_textbox, provider_radio], provider_radio)
provider_radio.change(validate_provider, [byok_textbox, provider_radio], provider_radio)
load_mcp_tools(DEFAULT_MCP_SERVERS) # Load defaults on startup
print(f"Initial MCP tools loaded: {len(mcp_tools_collection.tools) if mcp_tools_collection else 0} tools.")
print("Gradio interface initialized.")
if __name__ == "__main__":
print("Launching the Serverless TextGen Hub demo application.")
demo.launch(show_api=False) |