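"""Serverless TextGen Hub: a Gradio chat interface for Hugging Face inference
providers, with optional MCP (Model Context Protocol) tool use via smolagents."""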
import gradio as gr
from huggingface_hub import InferenceClient
import os
import json
import base64
from PIL import Image
import io
import atexit

from smolagents import ToolCollection, CodeAgent
from smolagents.mcp_client import MCPClient as SmolMCPClient

ACCESS_TOKEN = os.getenv("HF_TOKEN")
print("Access token loaded.")

mcp_tools_collection = ToolCollection(tools=[])
mcp_client_instances = []

DEFAULT_MCP_SERVERS = [
    {"name": "KokoroTTS (Example)", "type": "sse", "url": "https://fdaudens-kokoro-mcp.hf.space/gradio_api/mcp/sse"}
]

def load_mcp_tools(server_configs_list):
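    """(Re)connect to the configured MCP servers, gather their tools, and rebuild
    the global ToolCollection. Unsupported server types are skipped and per-server
    failures are logged to the console."""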
    global mcp_tools_collection, mcp_client_instances
    
    # SmolMCPClient instances are not closed explicitly here (no close API is used
    # directly); we drop our references and rely on GC or script termination.
    # If a ToolCollection were opened per server, tc.close() would be the place to clean up.
    print(f"Clearing {len(mcp_client_instances)} previous MCP client instance references.")
    mcp_client_instances = []  # Old clients are garbage-collected once nothing else references them.
    
    all_discovered_tools = []
    if not server_configs_list:
        print("No MCP server configurations provided. Clearing MCP tools.")
        mcp_tools_collection = ToolCollection(tools=all_discovered_tools)
        return

    print(f"Loading MCP tools from {len(server_configs_list)} server configurations...")
    for config in server_configs_list:
        server_name = config.get('name', config.get('url', 'Unknown Server'))
        try:
            if config.get("type") == "sse":
                sse_url = config["url"]
                print(f"Attempting to connect to MCP SSE server: {server_name} at {sse_url}")
                smol_mcp_client = SmolMCPClient(server_parameters={"url": sse_url})
                mcp_client_instances.append(smol_mcp_client)
                discovered_tools_from_server = smol_mcp_client.get_tools()
                if discovered_tools_from_server:
                    all_discovered_tools.extend(list(discovered_tools_from_server))
                    print(f"Discovered {len(discovered_tools_from_server)} tools from {server_name}.")
                else:
                    print(f"No tools discovered from {server_name}.")
            else:
                print(f"Unsupported MCP server type '{config.get('type')}' for {server_name}. Skipping.")
        except Exception as e:
            print(f"Error loading MCP tools from {server_name}: {e}")
    
    mcp_tools_collection = ToolCollection(tools=all_discovered_tools)
    if mcp_tools_collection and len(mcp_tools_collection.tools) > 0:
        print(f"Successfully loaded a total of {len(mcp_tools_collection.tools)} MCP tools:")
        for tool in mcp_tools_collection.tools:
            print(f"  - {tool.name}: {tool.description[:100]}...")
    else:
        print("No MCP tools were loaded, or an error occurred.")

def cleanup_mcp_client_instances_on_exit():
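    """atexit hook: drop any remaining MCP client references at shutdown (no explicit close is performed)."""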
    global mcp_client_instances
    print("Attempting to clear MCP client instance references on application exit...")
    # No explicit close called here as per previous fix
    mcp_client_instances = []
    print("MCP client instance reference cleanup finished.")

atexit.register(cleanup_mcp_client_instances_on_exit)

def encode_image(image_path):
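    """Return a base64-encoded JPEG for a file path or PIL image, or None if encoding fails."""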
    if not image_path: return None
    try:
        image = Image.open(image_path) if not isinstance(image_path, Image.Image) else image_path
        if image.mode == 'RGBA': image = image.convert('RGB')
        buffered = io.BytesIO()
        image.save(buffered, format="JPEG")
        return base64.b64encode(buffered.getvalue()).decode("utf-8")
    except Exception as e:
        print(f"Error encoding image {image_path}: {e}")
        return None

def respond(
    message_input_text,
    image_files_list,
    history: list[tuple[str, str]], # history will be list of (user_str_display, assistant_str_display)
    system_message,
    max_tokens,
    temperature,
    top_p,
    frequency_penalty,
    seed,
    provider,
    custom_api_key,
    custom_model,    
    model_search_term,
    selected_model
):
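    """Generate the assistant reply for one turn.

    Streams text directly from the selected model when no MCP tools are loaded;
    otherwise routes the query through a smolagents CodeAgent so the model can
    call the loaded tools. Yields the accumulated response text.
    """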
    global mcp_tools_collection
    print(f"Respond: Text='{message_input_text}', Images={len(image_files_list) if image_files_list else 0}")

    token_to_use = custom_api_key if custom_api_key.strip() else ACCESS_TOKEN
    hf_inference_client = InferenceClient(token=token_to_use, provider=provider)
    if seed == -1: seed = None

    current_user_content_parts = []
    if message_input_text and message_input_text.strip():
        current_user_content_parts.append({"type": "text", "text": message_input_text.strip()})
    if image_files_list:
        for img_path in image_files_list:
            encoded_img = encode_image(img_path)
            if encoded_img:
                current_user_content_parts.append({
                    "type": "image_url",
                    "image_url": {"url": f"data:image/jpeg;base64,{encoded_img}"}
                })
    if not current_user_content_parts:
        # Nothing to send; handle_submit should already filter out empty submissions.
        yield ""
        return

    llm_messages = [{"role": "system", "content": system_message}]
    for hist_user_str, hist_assistant in history: # hist_user_str is display string
        # For LLM context, we only care about the text part of history if it was multimodal.
        # Current image handling is only for the *current* turn.
        # If you need to re-process history for multimodal context for LLM, this part needs more logic.
        # For now, assuming hist_user_str is sufficient as text context from past turns.
        if hist_user_str:
             llm_messages.append({"role": "user", "content": hist_user_str})
        if hist_assistant:
            llm_messages.append({"role": "assistant", "content": hist_assistant})
            
    llm_messages.append({"role": "user", "content": current_user_content_parts if len(current_user_content_parts) > 1 else (current_user_content_parts[0] if current_user_content_parts else "")})
    
    # FIX for Issue 1: 'NoneType' object has no attribute 'strip'
    model_to_use = (custom_model.strip() if custom_model else "") or selected_model
    print(f"Model selected for inference: {model_to_use}")
    
    active_mcp_tools = list(mcp_tools_collection.tools) if mcp_tools_collection else []

    if active_mcp_tools:
        print(f"MCP tools are active ({len(active_mcp_tools)} tools). Using CodeAgent.")
        class HFClientWrapperForAgent:
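            """Thin adapter exposing a generate() entry point backed by
            InferenceClient.chat_completion, using the sampling parameters chosen
            in the UI. Assumes a smolagents version that accepts a model object
            with this interface."""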
            def __init__(self, hf_client, model_id, outer_scope_params):
                self.client = hf_client
                self.model_id = model_id
                self.params = outer_scope_params
            def generate(self, agent_llm_messages, tools=None, tool_choice=None, **kwargs):
                api_params = {
                    "model": self.model_id, "messages": agent_llm_messages, "stream": False,
                    "max_tokens": self.params['max_tokens'], "temperature": self.params['temperature'],
                    "top_p": self.params['top_p'], "frequency_penalty": self.params['frequency_penalty'],
                }
                if self.params['seed'] is not None: api_params["seed"] = self.params['seed']
                if tools: api_params["tools"] = tools
                if tool_choice: api_params["tool_choice"] = tool_choice
                
                print(f"Agent's HFClientWrapper calling LLM: {self.model_id} with params: {api_params}")
                completion = self.client.chat_completion(**api_params)
                
                # FIX for Issue 2 (Potential): Ensure content is not None for text responses
                msg = completion.choices[0].message if completion.choices else None
                if msg is not None and msg.content is None and not msg.tool_calls:
                    print("Warning (HFClientWrapperForAgent): Model returned None content. Setting to empty string.")
                    msg.content = ""
                return completion

        outer_scope_llm_params = {
            "max_tokens": max_tokens, "temperature": temperature, "top_p": top_p,
            "frequency_penalty": frequency_penalty, "seed": seed
        }
        agent_model_adapter = HFClientWrapperForAgent(hf_inference_client, model_to_use, outer_scope_llm_params)
        agent = CodeAgent(tools=active_mcp_tools, model=agent_model_adapter, messages_constructor=lambda: llm_messages[:-1].copy()) # Prime with history

        current_query_for_agent = message_input_text.strip() if message_input_text else "User provided image(s)."
        if not current_query_for_agent and image_files_list:
            current_query_for_agent = "Process the provided image(s) or follow related instructions."
        elif not current_query_for_agent and not image_files_list:
             current_query_for_agent = "..." # Should be caught by earlier check

        print(f"Query for CodeAgent.run: '{current_query_for_agent}' with {len(llm_messages)-1} history messages for priming.")
        try:
            agent_final_text_response = agent.run(current_query_for_agent)
            yield agent_final_text_response
            print("Completed response generation via CodeAgent.")
        except Exception as e:
            print(f"Error during CodeAgent execution: {e}") # This will now print the actual underlying error
            yield f"Error using tools: {str(e)}" # The str(e) might be the user-facing error
        return
    else:
        print("No MCP tools active. Proceeding with direct LLM call (streaming).")
        response_stream_content = ""
        try:
            stream = hf_inference_client.chat_completion(
                model=model_to_use, messages=llm_messages, stream=True,
                max_tokens=max_tokens, temperature=temperature, top_p=top_p,
                frequency_penalty=frequency_penalty, seed=seed
            )
            for chunk in stream:
                if hasattr(chunk, 'choices') and len(chunk.choices) > 0:
                    delta = chunk.choices[0].delta
                    if hasattr(delta, 'content') and delta.content:
                        token_text = delta.content
                        response_stream_content += token_text
                        yield response_stream_content
            print("\nCompleted streaming response generation.")
        except Exception as e:
            print(f"Error during direct LLM inference: {e}")
            yield response_stream_content + f"\nError: {str(e)}"

def validate_provider(api_key, provider):
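    """Force the provider back to 'hf-inference' when no custom API key is provided."""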
    if not api_key.strip() and provider != "hf-inference":
        return gr.update(value="hf-inference")
    return gr.update(value=provider)

with gr.Blocks(theme="Nymbo/Nymbo_Theme") as demo:
    # UserWarning for type='tuples' is known. Consider changing to type='messages' later for robustness.
    chatbot = gr.Chatbot(
        label="Serverless TextGen Hub", height=600, show_copy_button=True,
        placeholder="Select a model, (optionally) load MCP Tools, and begin chatting.",
        layout="panel", bubble_full_width=False
    )
    msg_input_box = gr.MultimodalTextbox(
        placeholder="Type a message or upload images...", show_label=False,
        container=False, scale=12, file_types=["image"],
        file_count="multiple", sources=["upload"]
    )
    with gr.Accordion("Settings", open=False):
        system_message_box = gr.Textbox(value="You are a helpful AI assistant.", label="System Prompt")
        with gr.Row():
            max_tokens_slider = gr.Slider(1, 4096, value=512, step=1, label="Max tokens")
            temperature_slider = gr.Slider(0.1, 4.0, value=0.7, step=0.1, label="Temperature")
            top_p_slider = gr.Slider(0.1, 1.0, value=0.95, step=0.05, label="Top-P")
        with gr.Row():
            frequency_penalty_slider = gr.Slider(-2.0, 2.0, value=0.0, step=0.1, label="Frequency Penalty")
            seed_slider = gr.Slider(-1, 65535, value=-1, step=1, label="Seed (-1 for random)")
        providers_list = ["hf-inference", "cerebras", "together", "sambanova", "novita", "cohere", "fireworks-ai", "hyperbolic", "nebius"]
        provider_radio = gr.Radio(choices=providers_list, value="hf-inference", label="Inference Provider")
        byok_textbox = gr.Textbox(label="BYOK (Hugging Face API Key)", type="password", placeholder="Enter token if not using 'hf-inference'")
        custom_model_box = gr.Textbox(label="Custom Model ID", placeholder="org/model-name (overrides selection below)")
        model_search_box = gr.Textbox(label="Filter Featured Models", placeholder="Search...")
        models_list = [
            "meta-llama/Llama-3.2-11B-Vision-Instruct", "meta-llama/Llama-3.3-70B-Instruct",
            "meta-llama/Llama-3.1-70B-Instruct", "meta-llama/Llama-3.0-70B-Instruct",
            "meta-llama/Llama-3.2-3B-Instruct", "meta-llama/Llama-3.2-1B-Instruct",
            "meta-llama/Llama-3.1-8B-Instruct", "NousResearch/Hermes-3-Llama-3.1-8B",
            "NousResearch/Nous-Hermes-2-Mixtral-8x7B-DPO", "mistralai/Mistral-Nemo-Instruct-2407",
            "mistralai/Mixtral-8x7B-Instruct-v0.1", "mistralai/Mistral-7B-Instruct-v0.3",
            "mistralai/Mistral-7B-Instruct-v0.2", "Qwen/Qwen3-235B-A22B", "Qwen/Qwen3-32B",
            "Qwen/Qwen2.5-72B-Instruct", "Qwen/Qwen2.5-3B-Instruct", "Qwen/Qwen2.5-0.5B-Instruct",
            "Qwen/QwQ-32B", "Qwen/Qwen2.5-Coder-32B-Instruct", "microsoft/Phi-3.5-mini-instruct",
            "microsoft/Phi-3-mini-128k-instruct", "microsoft/Phi-3-mini-4k-instruct",
        ]
        featured_model_radio = gr.Radio(label="Select a Featured Model", choices=models_list, value="meta-llama/Llama-3.2-11B-Vision-Instruct", interactive=True)
        gr.Markdown("[All Text models](https://huggingface.co/models?pipeline_tag=text-generation) | [All Multimodal models](https://huggingface.co/models?pipeline_tag=image-text-to-text)")

    with gr.Accordion("MCP Client Settings (Connect to External Tools)", open=False):
        gr.Markdown("Configure connections to MCP Servers to allow the LLM to use external tools. The LLM will decide when to use these tools based on your prompts.")
        mcp_server_config_input = gr.Textbox(
            label="MCP Server Configurations (JSON Array)",
            info='Example: [{"name": "MyToolServer", "type": "sse", "url": "http://server_url/gradio_api/mcp/sse"}]',
            lines=3, placeholder='Enter a JSON list of server configurations here.',
            value=json.dumps(DEFAULT_MCP_SERVERS, indent=2)
        )
        mcp_load_status_display = gr.Textbox(label="MCP Load Status", interactive=False)
        load_mcp_tools_btn = gr.Button("Load/Reload MCP Tools")
        
        def handle_load_mcp_tools_click(config_str_from_ui):
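            """Parse the JSON server list from the UI, reload MCP tools, and return a status message."""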
            if not config_str_from_ui:
                load_mcp_tools([])
                return "MCP tool loading attempted with empty config. Tools cleared."
            try:
                parsed_configs = json.loads(config_str_from_ui)
                if not isinstance(parsed_configs, list): return "Error: MCP configuration must be a valid JSON list."
                load_mcp_tools(parsed_configs)
                if mcp_tools_collection and len(mcp_tools_collection.tools) > 0:
                    loaded_tool_names = [t.name for t in mcp_tools_collection.tools]
                    return f"Successfully loaded {len(loaded_tool_names)} MCP tools: {', '.join(loaded_tool_names)}"
                else: return "No MCP tools loaded, or an error occurred. Check console for details."
            except json.JSONDecodeError: return "Error: Invalid JSON format in MCP server configurations."
            except Exception as e:
                print(f"Unhandled error in handle_load_mcp_tools_click: {e}")
                return f"Error loading MCP tools: {str(e)}. Check console."
        load_mcp_tools_btn.click(handle_load_mcp_tools_click, inputs=[mcp_server_config_input], outputs=mcp_load_status_display)

    def filter_models(search_term):
        return gr.update(choices=[m for m in models_list if search_term.lower() in m.lower()])
    def set_custom_model_from_radio(selected):
        return selected

    def handle_submit(
        msg_content_dict, current_chat_history,
        system_msg, max_tokens, temperature, top_p, freq_penalty, seed,
        provider, api_key, custom_model, search_term, selected_model
    ):
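        """Append the user turn to the chat history, stream the assistant reply from
        respond(), and clear the multimodal input box on each update."""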
        text = msg_content_dict.get("text", "").strip()
        files = msg_content_dict.get("files", []) # list of file paths

        if not text and not files: # Skip if both are empty
            print("Skipping empty submission from multimodal textbox.")
            # Yield current history to prevent Gradio from complaining about no output
            yield current_chat_history, {"text": "", "files": []} # Clear input
            return

        # FIX for Issue 4: Pydantic FileMessage error by ensuring user part of history is a string
        user_display_parts = []
        if text:
            user_display_parts.append(text)
        if files:
            for f_path in files:
                base_name = os.path.basename(f_path) if f_path else "file"
                f_path_str = f_path if f_path else ""
                user_display_parts.append(f"\n![{base_name}]({f_path_str})")
        user_display_message_for_chatbot = " ".join(user_display_parts).strip()
        
        current_chat_history.append([user_display_message_for_chatbot, None])
        
        # Prepare history for respond function (ensure user part is string)
        history_for_respond = []
        for user_h, assistant_h in current_chat_history[:-1]: # History before current turn
            history_for_respond.append((str(user_h) if user_h is not None else "", assistant_h))

        assistant_response_accumulator = ""
        # Live component values arrive as event inputs; reading .value inside the
        # handler would only return the construction-time defaults.
        for streamed_chunk in respond(
            text, files,
            history_for_respond,
            system_msg, max_tokens, temperature,
            top_p, freq_penalty, seed,
            provider, api_key, custom_model,
            search_term, selected_model
        ):
            assistant_response_accumulator = streamed_chunk
            current_chat_history[-1][1] = assistant_response_accumulator
            yield current_chat_history, {"text": "", "files": []}
    
    # All setting components are passed as inputs so handle_submit sees their current values.
    msg_input_box.submit(
        handle_submit,
        [
            msg_input_box, chatbot,
            system_message_box, max_tokens_slider, temperature_slider, top_p_slider,
            frequency_penalty_slider, seed_slider, provider_radio, byok_textbox,
            custom_model_box, model_search_box, featured_model_radio
        ],
        [chatbot, msg_input_box]
    )
    model_search_box.change(filter_models, model_search_box, featured_model_radio)
    featured_model_radio.change(set_custom_model_from_radio, featured_model_radio, custom_model_box)
    byok_textbox.change(validate_provider, [byok_textbox, provider_radio], provider_radio)
    provider_radio.change(validate_provider, [byok_textbox, provider_radio], provider_radio)

load_mcp_tools(DEFAULT_MCP_SERVERS) # Load defaults on startup
print(f"Initial MCP tools loaded: {len(mcp_tools_collection.tools) if mcp_tools_collection else 0} tools.")
print("Gradio interface initialized.")

if __name__ == "__main__":
    print("Launching the Serverless TextGen Hub demo application.")
    demo.launch(show_api=False)