import gradio as gr
from huggingface_hub import InferenceClient
import os
import json
import base64
from PIL import Image
import io
import atexit

# Ensure smolagents and mcp are installed: pip install "smolagents[mcp]" mcp
from smolagents import ToolCollection, CodeAgent
from smolagents.mcp_client import MCPClient as SmolMCPClient # For connecting to MCP SSE servers

ACCESS_TOKEN = os.getenv("HF_TOKEN")
print("Access token loaded.")

# --- MCP Client Integration ---
mcp_tools_collection = ToolCollection(tools=[]) # Global store for loaded MCP tools
mcp_client_instances = [] # To keep track of client instances for proper closing

DEFAULT_MCP_SERVERS = [
    {"name": "KokoroTTS (Example)", "type": "sse", "url": "https://fdaudens-kokoro-mcp.hf.space/gradio_api/mcp/sse"}
]
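# Each entry is a plain dict; only "type": "sse" is handled below, and "url" should point at the
# server's /gradio_api/mcp/sse endpoint. Illustrative example (hypothetical URL):
#   {"name": "MyImageTools", "type": "sse", "url": "https://example.com/gradio_api/mcp/sse"}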

def load_mcp_tools(server_configs_list):
    global mcp_tools_collection, mcp_client_instances
    
    # Close any existing client instances before loading new ones
    for client_instance in mcp_client_instances:
        try:
            client_instance.close()
            print(f"Closed existing MCP client: {client_instance}")
        except Exception as e:
            print(f"Error closing existing MCP client {client_instance}: {e}")
    mcp_client_instances = []
    
    all_discovered_tools = []
    if not server_configs_list:
        print("No MCP server configurations provided. Clearing MCP tools.")
        mcp_tools_collection = ToolCollection(tools=all_discovered_tools)
        return

    print(f"Loading MCP tools from {len(server_configs_list)} server configurations...")
    for config in server_configs_list:
        server_name = config.get('name', config.get('url', 'Unknown Server'))
        try:
            if config.get("type") == "sse":
                sse_url = config["url"]
                print(f"Attempting to connect to MCP SSE server: {server_name} at {sse_url}")
                
                # Using SmolMCPClient for SSE servers as shown in documentation
                # The constructor expects server_parameters={"url": sse_url}
                smol_mcp_client = SmolMCPClient(server_parameters={"url": sse_url})
                mcp_client_instances.append(smol_mcp_client) # Keep track to close later
                
                discovered_tools_from_server = smol_mcp_client.get_tools() # Returns a list of Tool objects
                
                if discovered_tools_from_server:
                    all_discovered_tools.extend(list(discovered_tools_from_server))
                    print(f"Discovered {len(discovered_tools_from_server)} tools from {server_name}.")
                else:
                    print(f"No tools discovered from {server_name}.")
            # Add elif for "stdio" type if needed in the future, though it's more complex for Gradio apps
            else:
                print(f"Unsupported MCP server type '{config.get('type')}' for {server_name}. Skipping.")
        except Exception as e:
            print(f"Error loading MCP tools from {server_name}: {e}")
    
    mcp_tools_collection = ToolCollection(tools=all_discovered_tools)
    if mcp_tools_collection and len(mcp_tools_collection.tools) > 0:
        print(f"Successfully loaded a total of {len(mcp_tools_collection.tools)} MCP tools:")
        for tool in mcp_tools_collection.tools:
            print(f"  - {tool.name}: {tool.description[:100]}...") # Print short description
    else:
        print("No MCP tools were loaded, or an error occurred.")

def cleanup_mcp_client_instances_on_exit():
    global mcp_client_instances
    print("Attempting to clean up MCP client instances on application exit...")
    for client_instance in mcp_client_instances:
        try:
            client_instance.close()
            print(f"Closed MCP client: {client_instance}")
        except Exception as e:
            print(f"Error closing MCP client {client_instance} on exit: {e}")
    mcp_client_instances = []
    print("MCP client cleanup finished.")

atexit.register(cleanup_mcp_client_instances_on_exit)
# --- End MCP Client Integration ---

# Function to encode image to base64 (remains the same)
def encode_image(image_path):
    if not image_path:
        print("No image path provided")
        return None
    
    try:
        print(f"Encoding image from path: {image_path}")
        if isinstance(image_path, Image.Image):
            image = image_path
        else:
            image = Image.open(image_path)
        
        if image.mode == 'RGBA':
            image = image.convert('RGB')
        
        buffered = io.BytesIO()
        image.save(buffered, format="JPEG")
        img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
        print("Image encoded successfully")
        return img_str
    except Exception as e:
        print(f"Error encoding image: {e}")
        return None
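# Usage sketch: the base64 string returned by encode_image() is embedded as an OpenAI-style
# image content part, e.g. (illustrative filename):
#   {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{encode_image('photo.jpg')}"}}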

# Modified respond function
def respond(
    message_input_text, # From multimodal textbox's text part
    image_files_list,   # From multimodal textbox's files part
    history: list[tuple[str, str]],
    system_message,
    max_tokens,
    temperature,
    top_p,
    frequency_penalty,
    seed,
    provider,
    custom_api_key,
    custom_model,    
    model_search_term, # Not directly used in this function but passed by UI
    selected_model     # From radio
):
    global mcp_tools_collection # Access the loaded MCP tools

    print(f"Received message text: {message_input_text}")
    print(f"Received {len(image_files_list) if image_files_list else 0} images")
    # ... (keep other prints for debugging)

    token_to_use = custom_api_key if custom_api_key.strip() != "" else ACCESS_TOKEN
    hf_inference_client = InferenceClient(token=token_to_use, provider=provider)
    print(f"Hugging Face Inference Client initialized with {provider} provider.")

    if seed == -1: seed = None

    # --- Prepare current user message (potentially multimodal) ---
    current_user_content_parts = []
    if message_input_text and message_input_text.strip():
        current_user_content_parts.append({"type": "text", "text": message_input_text.strip()})
    
    if image_files_list:
        for img_path in image_files_list:
            if img_path: # img_path is the path to the uploaded file
                encoded_img = encode_image(img_path)
                if encoded_img:
                    current_user_content_parts.append({
                        "type": "image_url",
                        "image_url": {"url": f"data:image/jpeg;base64,{encoded_img}"}
                    })
    
    if not current_user_content_parts: # If message is truly empty
        print("Skipping empty message.")
        yield ""  # Yield an empty assistant turn so the UI still updates
        return

    # --- Construct messages for LLM ---
    llm_messages = [{"role": "system", "content": system_message}]
    for hist_user, hist_assistant in history:
        # Assuming history user part is already formatted (string or list of dicts)
        if hist_user:
            # Handle complex history items (tuples of text, list_of_image_paths)
            if isinstance(hist_user, tuple) and len(hist_user) == 2:
                hist_user_text, hist_user_images = hist_user
                hist_user_parts = []
                if hist_user_text: hist_user_parts.append({"type": "text", "text": hist_user_text})
                for img_p in hist_user_images:
                    enc_img = encode_image(img_p)
                    if enc_img: hist_user_parts.append({"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{enc_img}"}})
                if hist_user_parts: llm_messages.append({"role": "user", "content": hist_user_parts})
            elif isinstance(hist_user, str): # Simple text history
                 llm_messages.append({"role": "user", "content": hist_user})
            # else: could be already formatted list of dicts from previous multimodal turn

        if hist_assistant:
            llm_messages.append({"role": "assistant", "content": hist_assistant})
            
    llm_messages.append({"role": "user", "content": current_user_content_parts if len(current_user_content_parts) > 1 else current_user_content_parts[0] if current_user_content_parts else ""})
    
    model_to_use = custom_model.strip() if custom_model.strip() else selected_model
    print(f"Model selected for inference: {model_to_use}")
    
    # --- Agent Logic or Direct LLM Call ---
    active_mcp_tools = list(mcp_tools_collection.tools) if mcp_tools_collection else []

    if active_mcp_tools:
        print(f"MCP tools are active ({len(active_mcp_tools)} tools). Using CodeAgent.")

        # Wrapper for smolagents.CodeAgent to use our configured HF InferenceClient
        class HFClientWrapperForAgent:
            def __init__(self, hf_client, model_id, outer_scope_params):
                self.client = hf_client
                self.model_id = model_id
                self.params = outer_scope_params

            def generate(self, agent_llm_messages, tools=None, tool_choice=None, **kwargs):
                # agent_llm_messages is from the agent. tools/tool_choice also from agent.
                api_params = {
                    "model": self.model_id,
                    "messages": agent_llm_messages,
                    "stream": False, # CodeAgent's .run() expects a full response object
                    "max_tokens": self.params['max_tokens'],
                    "temperature": self.params['temperature'],
                    "top_p": self.params['top_p'],
                    "frequency_penalty": self.params['frequency_penalty'],
                }
                if self.params['seed'] is not None: api_params["seed"] = self.params['seed']
                if tools: api_params["tools"] = tools
                if tool_choice: api_params["tool_choice"] = tool_choice
                
                print(f"Agent's HFClientWrapper calling LLM: {self.model_id}")
                completion = self.client.chat_completion(**api_params)
                return completion
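        # Assumption: CodeAgent can consume the raw chat_completion response returned above.
        # Depending on the smolagents version, the model adapter may instead need to return a
        # message-like object exposing .content (e.g. completion.choices[0].message), or a
        # built-in wrapper such as smolagents' InferenceClientModel/HfApiModel may be preferable.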

        outer_scope_llm_params = {
            "max_tokens": max_tokens, "temperature": temperature, "top_p": top_p,
            "frequency_penalty": frequency_penalty, "seed": seed
        }
        agent_model_adapter = HFClientWrapperForAgent(hf_inference_client, model_to_use, outer_scope_llm_params)
        
        agent = CodeAgent(tools=active_mcp_tools, model=agent_model_adapter)
        
        # Prime agent with history (all messages except the current user query)
        agent.messages = llm_messages[:-1] 
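        # (Assumption: this attribute is honored by the installed smolagents version; recent
        # releases track conversation state on agent.memory rather than agent.messages.)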
        
        # CodeAgent.run expects a string query. Extract text from current user message.
        current_query_for_agent = message_input_text.strip() if message_input_text else ""
        if not current_query_for_agent and image_files_list: # If only image(s), provide a generic text query
            current_query_for_agent = "Describe the image(s) or follow instructions related to them."
        elif not current_query_for_agent: # Should not happen due to the earlier empty-message check
            current_query_for_agent = "..."


        print(f"Query for CodeAgent.run: '{current_query_for_agent}' with {len(agent.messages)} history messages.")
        try:
            agent_final_text_response = agent.run(current_query_for_agent)
            # Note: agent.run() is blocking and returns the final string. 
            # It won't stream token by token if tools are used.
            yield agent_final_text_response
            print("Completed response generation via CodeAgent.")
        except Exception as e:
            print(f"Error during CodeAgent execution: {e}")
            yield f"Error using tools: {str(e)}"
        return

    else: # No MCP tools, use original streaming logic
        print("No MCP tools active. Proceeding with direct LLM call (streaming).")
        response_stream_content = ""
        try:
            stream = hf_inference_client.chat_completion(
                model=model_to_use,
                messages=llm_messages,
                stream=True,
                max_tokens=max_tokens, temperature=temperature, top_p=top_p,
                frequency_penalty=frequency_penalty, seed=seed
            )
            for chunk in stream:
                if hasattr(chunk, 'choices') and len(chunk.choices) > 0:
                    delta = chunk.choices[0].delta
                    if hasattr(delta, 'content') and delta.content:
                        token_text = delta.content
                        response_stream_content += token_text
                        yield response_stream_content
            print("\nCompleted streaming response generation.")
        except Exception as e:
            print(f"Error during direct LLM inference: {e}")
            yield response_stream_content + f"\nError: {str(e)}"

# Function to validate provider (remains the same)
def validate_provider(api_key, provider):
    if not api_key.strip() and provider != "hf-inference":
        return gr.update(value="hf-inference")
    return gr.update(value=provider)

# GRADIO UI
with gr.Blocks(theme="Nymbo/Nymbo_Theme") as demo:
    chatbot = gr.Chatbot(
        label="Serverless TextGen Hub",
        height=600, show_copy_button=True, 
        placeholder="Select a model, (optionally) load MCP Tools, and begin chatting.",
        layout="panel",
        bubble_full_width=False
    )
    
    msg_input_box = gr.MultimodalTextbox(
        placeholder="Type a message or upload images...",
        show_label=False, container=False, scale=12,
        file_types=["image"], file_count="multiple", sources=["upload"]
    )
    
    with gr.Accordion("Settings", open=False):
        system_message_box = gr.Textbox(value="You are a helpful AI assistant.", label="System Prompt")
        with gr.Row():
            # ... (max_tokens, temperature, top_p sliders remain the same)
            max_tokens_slider = gr.Slider(1, 4096, value=512, step=1, label="Max tokens")
            temperature_slider = gr.Slider(0.1, 4.0, value=0.7, step=0.1, label="Temperature")
            top_p_slider = gr.Slider(0.1, 1.0, value=0.95, step=0.05, label="Top-P")
        with gr.Row():
            # ... (frequency_penalty, seed sliders remain the same)
            frequency_penalty_slider = gr.Slider(-2.0, 2.0, value=0.0, step=0.1, label="Frequency Penalty")
            seed_slider = gr.Slider(-1, 65535, value=-1, step=1, label="Seed (-1 for random)")

        providers_list = ["hf-inference", "cerebras", "together", "sambanova", "novita", "cohere", "fireworks-ai", "hyperbolic", "nebius"]
        provider_radio = gr.Radio(choices=providers_list, value="hf-inference", label="Inference Provider")
        byok_textbox = gr.Textbox(label="BYOK (Hugging Face API Key)", type="password", placeholder="Enter token if not using 'hf-inference'")
        custom_model_box = gr.Textbox(label="Custom Model ID", placeholder="org/model-name (overrides selection below)")
        model_search_box = gr.Textbox(label="Filter Featured Models", placeholder="Search...")
        
        models_list = [ # Keep your extensive model list
            "meta-llama/Llama-3.2-11B-Vision-Instruct", "meta-llama/Llama-3.3-70B-Instruct", 
            # ... (include all your models) ...
            "microsoft/Phi-3-mini-4k-instruct",
        ]
        featured_model_radio = gr.Radio(label="Select a Featured Model", choices=models_list, value="meta-llama/Llama-3.2-11B-Vision-Instruct", interactive=True)
        gr.Markdown("[All Text models](https://huggingface.co/models?pipeline_tag=text-generation) | [All Multimodal models](https://huggingface.co/models?pipeline_tag=image-text-to-text)")

    # --- MCP Client Settings UI ---
    with gr.Accordion("MCP Client Settings (Connect to External Tools)", open=False):
        gr.Markdown("Configure connections to MCP Servers to allow the LLM to use external tools. The LLM will decide when to use these tools based on your prompts.")
        mcp_server_config_input = gr.Textbox(
            label="MCP Server Configurations (JSON Array)",
            info='Example: [{"name": "MyToolServer", "type": "sse", "url": "http://server_url/gradio_api/mcp/sse"}]',
            lines=3,
            placeholder='Enter a JSON list of server configurations here.',
            value=json.dumps(DEFAULT_MCP_SERVERS, indent=2) # Pre-fill with defaults
        )
        mcp_load_status_display = gr.Textbox(label="MCP Load Status", interactive=False)
        load_mcp_tools_btn = gr.Button("Load/Reload MCP Tools")
        
        def handle_load_mcp_tools_click(config_str_from_ui):
            if not config_str_from_ui:
                load_mcp_tools([]) # Clear tools if config is empty
                return "MCP tool loading attempted with empty config. Tools cleared."
            try:
                parsed_configs = json.loads(config_str_from_ui)
                if not isinstance(parsed_configs, list):
                    return "Error: MCP configuration must be a valid JSON list."
                load_mcp_tools(parsed_configs) # Call the main loading function
                
                if mcp_tools_collection and len(mcp_tools_collection.tools) > 0:
                    loaded_tool_names = [t.name for t in mcp_tools_collection.tools]
                    return f"Successfully loaded {len(loaded_tool_names)} MCP tools: {', '.join(loaded_tool_names)}"
                else:
                    return "No MCP tools loaded, or an error occurred during loading. Check console for details."
            except json.JSONDecodeError:
                return "Error: Invalid JSON format in MCP server configurations."
            except Exception as e:
                print(f"Unhandled error in handle_load_mcp_tools_click: {e}")
                return f"Error loading MCP tools: {str(e)}. Check console."

        load_mcp_tools_btn.click(
            handle_load_mcp_tools_click,
            inputs=[mcp_server_config_input],
            outputs=mcp_load_status_display
        )
    # --- End MCP Client Settings UI ---

    # Chat history state (remains the same)
    # chat_history = gr.State([]) # Not explicitly used if chatbot manages history directly

    # Function to filter models (remains the same)
    def filter_models(search_term):
        return gr.update(choices=[m for m in models_list if search_term.lower() in m.lower()])

    # Function to set custom model from radio (remains the same)
    def set_custom_model_from_radio(selected):
        return selected # Updates custom_model_box with the selected featured model

    # Gradio's MultimodalTextbox submit action
    # The `user` function is simplified as msg_input_box directly gives text and files
    # The `bot` function is where the main logic of `respond` is called.
    
    # Settings components are passed in as explicit event inputs so the handler sees the
    # current UI values (reading component.value here would only return construction-time defaults).
    def handle_submit(
        msg_content_dict, current_chat_history,
        system_msg, max_tok, temp, top_p_val, freq_pen, seed_val,
        provider_val, byok_val, custom_model_val, model_search_val, featured_model_val
    ):
        # msg_content_dict = {"text": "...", "files": ["path1", "path2"]}
        text = msg_content_dict.get("text", "")
        files = msg_content_dict.get("files", [])

        # Add user message to history for display
        # For multimodal, we might want to display text and images separately or combined
        user_display_entry = []
        if text:
            user_display_entry.append(text)
        if files:
            # For display, Gradio chatbot can render markdown images
            for f_path in files:
                user_display_entry.append(f"![{os.path.basename(f_path)}]({f_path})")
        
        # Construct a representation for history that `respond` can unpack
        # For simplicity, let's pass text and files separately to `respond`
        # and the history will store the user input as (text, files_list_for_display)
        
        history_entry_user_part = (text, files) # Store as tuple for `respond` to process easily later
        current_chat_history.append([history_entry_user_part, None]) # Add user part, assistant is None for now
        
        # Prepare for streaming response
        # The `respond` function is a generator
        assistant_response_accumulator = ""
        for streamed_chunk in respond(
            text, files,
            current_chat_history[:-1], # Pass history *before* current turn
            system_msg, max_tok, temp,
            top_p_val, freq_pen, seed_val,
            provider_val, byok_val, custom_model_val,
            model_search_val, featured_model_val
        ):
            assistant_response_accumulator = streamed_chunk
            current_chat_history[-1][1] = assistant_response_accumulator # Update last assistant message
            yield current_chat_history, {"text": "", "files": []} # Update chatbot, clear input
        
        # Final update after stream (already done by last yield)
        # yield current_chat_history, {"text": "", "files": []}


    msg_input_box.submit(
        handle_submit,
        [msg_input_box, chatbot,
         system_message_box, max_tokens_slider, temperature_slider, top_p_slider,
         frequency_penalty_slider, seed_slider, provider_radio, byok_textbox,
         custom_model_box, model_search_box, featured_model_radio],
        [chatbot, msg_input_box] # Output to chatbot and clear msg_input_box
    )
    
    model_search_box.change(filter_models, model_search_box, featured_model_radio)
    featured_model_radio.change(set_custom_model_from_radio, featured_model_radio, custom_model_box)
    byok_textbox.change(validate_provider, [byok_textbox, provider_radio], provider_radio)
    provider_radio.change(validate_provider, [byok_textbox, provider_radio], provider_radio)

# Load default MCP tools on startup
load_mcp_tools(DEFAULT_MCP_SERVERS)
print(f"Initial MCP tools loaded: {len(mcp_tools_collection.tools) if mcp_tools_collection else 0} tools.")

print("Gradio interface initialized.")
if __name__ == "__main__":
    print("Launching the Serverless TextGen Hub demo application.")
    demo.launch(show_api=False) # show_api can be True if needed for other purposes