import gradio as gr
from huggingface_hub import InferenceClient # Keep for direct use if needed, though agent will use its own model
import os
import json
import base64
from PIL import Image
import io

# Smolagents imports
from smolagents import CodeAgent, Tool
from smolagents.models import InferenceClientModel as SmolInferenceClientModel
# We'll use PIL.Image directly for opening, AgentImage is for agent's internal typing if needed by a tool
from smolagents.gradio_ui import pull_messages_from_step # For formatting agent steps
from smolagents.memory import ActionStep, FinalAnswerStep, PlanningStep, MemoryStep # For type checking steps
from smolagents.models import ChatMessageStreamDelta # For type checking stream deltas


ACCESS_TOKEN = os.getenv("HF_TOKEN")
print("Access token loaded.")

# Function to encode image to base64 (remains useful if we ever need to pass base64 to a non-smolagent component)
def encode_image(image_path_or_pil):
    if not image_path_or_pil:
        print("No image path or PIL Image provided")
        return None
    
    try:
        # print(f"Encoding image: {type(image_path_or_pil)}") # Debug
        
        if isinstance(image_path_or_pil, Image.Image):
            image = image_path_or_pil
        else: # Assuming it's a path
            image = Image.open(image_path_or_pil)
        
        if image.mode == 'RGBA':
            image = image.convert('RGB')
        
        buffered = io.BytesIO()
        image.save(buffered, format="JPEG") # JPEG is generally smaller for transfer
        img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
        # print("Image encoded successfully") # Debug
        return img_str
    except Exception as e:
        print(f"Error encoding image: {e}")
        return None
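
# Example (hypothetical): embedding the base64 string in an OpenAI-style
# multimodal message payload, should a non-smolagents client ever need it:
#   b64 = encode_image("example.jpg")
#   image_part = {"type": "image_url",
#                 "image_url": {"url": f"data:image/jpeg;base64,{b64}"}}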

# This function will now set up and run the smolagent
def respond(
    message_text, # Text from MultimodalTextbox
    image_file_paths,  # List of file paths from MultimodalTextbox
    gradio_history: list[tuple[str, str]], # Gradio history (for context if needed, agent is stateless per call here)
    system_message_for_agent, # System prompt for the main LLM agent
    max_tokens,
    temperature,
    top_p,
    frequency_penalty,
    seed,
    provider_for_agent_llm, 
    api_key_for_agent_llm, 
    model_id_for_agent_llm,
    model_search_term, # Unused directly by agent logic
    selected_model_for_agent_llm # Fallback model ID
):
    print(f"Respond function called. Message: '{message_text}', Images: {image_file_paths}")

    token_to_use = api_key_for_agent_llm if api_key_for_agent_llm.strip() != "" else ACCESS_TOKEN
    model_to_use = model_id_for_agent_llm.strip() if model_id_for_agent_llm.strip() != "" else selected_model_for_agent_llm

    # --- Initialize the LLM for the CodeAgent ---
    agent_llm_params = {
        "model_id": model_to_use,
        "token": token_to_use,
        # smolagents's InferenceClientModel uses max_tokens for max_new_tokens
        "max_tokens": max_tokens,
        "temperature": temperature if temperature > 0.01 else None, # Some models require temp > 0
        "top_p": top_p if top_p < 1.0 else None, # Often 1.0 means no top_p
        "seed": seed if seed != -1 else None,
    }
    if provider_for_agent_llm and provider_for_agent_llm != "hf-inference":
        agent_llm_params["provider"] = provider_for_agent_llm
    
    # InferenceClient-specific params; only pass them when non-default and supported
    if frequency_penalty != 0.0:
        agent_llm_params["frequency_penalty"] = frequency_penalty
         
    agent_llm = SmolInferenceClientModel(**agent_llm_params)
    print(f"Smolagents LLM for agent initialized: model='{model_to_use}', provider='{provider_for_agent_llm or 'default'}'")

    # --- Define Tools for the Agent ---
    agent_tools = []
    try:
        image_gen_tool = Tool.from_space(
            space_id="black-forest-labs/FLUX.1-schnell",
            name="image_generator",
            description="Generates an image from a textual prompt. Input is a single string argument named 'prompt'. Output is an image file path.",
            token=token_to_use 
        )
        agent_tools.append(image_gen_tool)
        print("Image generation tool loaded: black-forest-labs/FLUX.1-schnell")
    except Exception as e:
        print(f"Error loading image generation tool: {e}")
        yield f"Error: Could not load image generation tool. {e}"
        return
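
    # In the agent's generated Python, the tool is callable by name; roughly
    # (illustrative prompt):
    #   image_path = image_generator(prompt="a watercolor fox in a forest")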

    # --- Initialize the CodeAgent ---
    # If system_message_for_agent is empty, CodeAgent will use its default.
    # The default is usually good as it explains how to use tools.
    agent = CodeAgent(
        tools=agent_tools,
        model=agent_llm,
        system_prompt=system_message_for_agent if system_message_for_agent and system_message_for_agent.strip() else None,
        # add_base_tools=True, # Consider adding Python interpreter, etc.
        stream_outputs=True # Important for Gradio streaming
    )
    print("Smolagents CodeAgent initialized.")

    # --- Prepare task and image inputs for the agent ---
    agent_task_text = message_text
    
    pil_images_for_agent = []
    if image_file_paths:
        for file_path in image_file_paths:
            try:
                pil_images_for_agent.append(Image.open(file_path))
            except Exception as e:
                print(f"Error opening image file {file_path} for agent: {e}")
    
    print(f"Agent task: '{agent_task_text}'")
    if pil_images_for_agent:
        print(f"Passing {len(pil_images_for_agent)} image(s) to agent.")

    # --- Run the agent and stream response ---
    # Agent is reset each turn. For conversational memory, agent instance
    # would need to be stored in session_state and agent.run(..., reset=False) used.
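    # A minimal sketch of that pattern (hypothetical wiring, not used here):
    #   agent_state = gr.State(None)          # created once inside gr.Blocks
    #   agent = agent_state.value or CodeAgent(tools=agent_tools, model=agent_llm)
    #   agent.run(task=agent_task_text, images=pil_images_for_agent,
    #             stream=True, reset=False)   # keep memory across turns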
    
    current_agent_response_text = ""
    try:
        # The agent.run method returns a generator when stream=True
        for step_item in agent.run(
            task=agent_task_text, 
            images=pil_images_for_agent, 
            stream=True, 
            reset=True # Explicitly reset for stateless operation per call
        ):
            if isinstance(step_item, ChatMessageStreamDelta):
                if step_item.content:
                    current_agent_response_text += step_item.content
                    yield current_agent_response_text # Yield accumulated text
            
            elif isinstance(step_item, (ActionStep, PlanningStep, FinalAnswerStep)):
                # A structured step. Format it for Gradio.
                # pull_messages_from_step yields gr.ChatMessage objects.
                for gradio_chat_msg in pull_messages_from_step(step_item, skip_model_outputs=agent.stream_outputs):
                    # The 'bot' function will handle these gr.ChatMessage objects.
                    yield gradio_chat_msg # Yield the gr.ChatMessage object directly
                current_agent_response_text = "" # Reset text buffer after a structured step
            
            # else:
                # print(f"Unhandled stream item type: {type(step_item)}") # Debug

        # Flush any trailing text deltas that were not emitted as part of a
        # structured step (e.g. a final textual answer arriving purely as deltas).
        if current_agent_response_text and not isinstance(step_item, FinalAnswerStep):
            yield current_agent_response_text


    except Exception as e:
        error_message = f"Error during agent execution: {str(e)}"
        print(error_message)
        yield error_message # Yield the error message to be displayed in UI

    print("Agent run completed.")


# Function to validate provider selection based on BYOK
def validate_provider(api_key, provider):
    if not api_key.strip() and provider != "hf-inference":
        return gr.update(value="hf-inference")
    return gr.update(value=provider)

# GRADIO UI
with gr.Blocks(theme="Nymbo/Nymbo_Theme") as demo:
    chatbot = gr.Chatbot(
        height=600, 
        show_copy_button=True, 
        placeholder="Select a model and begin chatting. Now uses smolagents with tools!",
        layout="panel",
        bubble_full_width=False # For better display of images/files
    )
    print("Chatbot interface created.")
    
    msg = gr.MultimodalTextbox(
        placeholder="Type a message or upload images...",
        show_label=False,
        container=False,
        scale=12,
        file_types=["image"],
        file_count="multiple",
        sources=["upload"]
    )
    
    with gr.Accordion("Settings", open=False):
        system_message_box = gr.Textbox(
            value="You are a helpful AI assistant. You can generate images if asked. Be precise with your prompts for image generation.", 
            placeholder="You are a helpful AI assistant.",
            label="System Prompt for Agent"
        )
        
        with gr.Row():
            with gr.Column():
                max_tokens_slider = gr.Slider(minimum=1, maximum=4096, value=1024, step=1, label="Max New Tokens")
                temperature_slider = gr.Slider(minimum=0.0, maximum=2.0, value=0.7, step=0.01, label="Temperature")
                top_p_slider = gr.Slider(minimum=0.0, maximum=1.0, value=0.95, step=0.01, label="Top-P")
            with gr.Column():
                frequency_penalty_slider = gr.Slider(minimum=-2.0, maximum=2.0, value=0.0, step=0.1, label="Frequency Penalty")
                seed_slider = gr.Slider(minimum=-1, maximum=65535, value=-1, step=1, label="Seed (-1 for random)")
        
        providers_list = [
            "hf-inference", "cerebras", "together", "sambanova", "novita", 
            "cohere", "fireworks-ai", "hyperbolic", "nebius",
        ]
        provider_radio = gr.Radio(choices=providers_list, value="hf-inference", label="Inference Provider for Agent's LLM")
        byok_textbox = gr.Textbox(value="", label="BYOK (Your HF Token or Provider API Key)", info="Enter API key for the selected provider. Uses HF_TOKEN if empty.", placeholder="Enter your API token", type="password")
        custom_model_box = gr.Textbox(value="", label="Custom Model ID for Agent's LLM", info="(Optional) Provide a custom model ID. Overrides featured model.", placeholder="meta-llama/Llama-3.3-70B-Instruct")
        model_search_box = gr.Textbox(label="Filter Featured Models", placeholder="Search for a featured model...", lines=1)
        
        models_list = [
            "meta-llama/Llama-3.3-70B-Instruct", "meta-llama/Llama-3.1-70B-Instruct", "meta-llama/Llama-3.0-70B-Instruct",
            "meta-llama/Llama-3.2-11B-Vision-Instruct", "meta-llama/Llama-3.2-3B-Instruct", "meta-llama/Llama-3.2-1B-Instruct",
            "meta-llama/Llama-3.1-8B-Instruct", "NousResearch/Hermes-3-Llama-3.1-8B", "NousResearch/Nous-Hermes-2-Mixtral-8x7B-DPO",
            "mistralai/Mistral-Nemo-Instruct-2407", "mistralai/Mixtral-8x7B-Instruct-v0.1", "mistralai/Mistral-7B-Instruct-v0.3",
            "Qwen/Qwen3-235B-A22B", "Qwen/Qwen3-32B", "Qwen/Qwen2.5-72B-Instruct", "Qwen/Qwen2.5-3B-Instruct",
            "Qwen/Qwen2.5-Coder-32B-Instruct", "microsoft/Phi-3.5-mini-instruct", "microsoft/Phi-3-mini-128k-instruct",
        ]
        featured_model_radio = gr.Radio(label="Select a Featured Model for Agent's LLM", choices=models_list, value="meta-llama/Llama-3.3-70B-Instruct", interactive=True)
        
        gr.Markdown("[View all Text-to-Text models](https://huggingface.co/models?inference_provider=all&pipeline_tag=text-generation&sort=trending) | [View all multimodal models](https://huggingface.co/models?inference_provider=all&pipeline_tag=image-text-to-text&sort=trending)")

    # Chat history state: the chatbot's value itself serves as the history display.
    # A separate gr.State would be needed to keep the agent conversational across
    # turns; for now the agent is stateless per turn.
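
    # Minimal helper implementations for the model search / selection controls,
    # assumed from how they are wired up in the event handlers below.
    def filter_models(search_term):
        # Keep only featured models whose ID contains the search term (case-insensitive).
        filtered = [m for m in models_list if search_term.lower() in m.lower()]
        return gr.update(choices=filtered)

    def set_custom_model_from_radio(selected):
        # Mirror the selected featured model into the custom model textbox.
        return selected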

    # Function for the chat interface
    def user(user_multimodal_input_dict, history):
        print(f"User input: {user_multimodal_input_dict}")
        text_content = user_multimodal_input_dict.get("text", "")
        files = user_multimodal_input_dict.get("files", [])

        if not (text_content and text_content.strip()) and not files:
            return history

        # Append the user's turn to the history for display: the text as a string
        # and each uploaded file as a (path, filename) tuple. The raw
        # MultimodalTextbox dict itself is passed to `bot` separately via `msg`.
        if text_content and text_content.strip():
            history.append([text_content, None])
        for file_item in files:
            # Files may be plain temp-file paths or objects exposing .name,
            # depending on the Gradio version.
            file_path = file_item.name if hasattr(file_item, "name") else file_item
            history.append([(file_path, os.path.basename(file_path)), None])
        return history

    def bot(user_multimodal_input_dict, history, system_msg, max_tokens, temperature, top_p, freq_penalty, seed, provider, api_key, custom_model, search_term, selected_model):
        # The raw MultimodalTextbox dict is passed in directly (the `msg` box is
        # only cleared after `bot` has run), so the agent receives the original
        # text and file paths rather than the display-formatted history.
        user_input = user_multimodal_input_dict or {}
        text_input_for_agent = user_input.get("text", "") or ""
        # Files may be plain temp-file paths or objects exposing .name,
        # depending on the Gradio version.
        image_file_paths_for_agent = [
            f.name if hasattr(f, "name") else f
            for f in user_input.get("files", [])
        ]

        if not history or (not text_input_for_agent.strip() and not image_file_paths_for_agent):
            yield history
            return

        history[-1][1] = "" # Initialize assistant's part for streaming
        
        # Buffer for current text stream from agent
        # Handles both pure text deltas and text content from gr.ChatMessage
        current_text_for_turn = ""

        for item in respond(
            message_text=text_input_for_agent,
            image_file_paths=image_file_paths_for_agent,
            gradio_history=history[:-1], # Pass previous turns for context if agent uses it
            system_message_for_agent=system_msg,
            max_tokens=max_tokens, temperature=temperature, top_p=top_p,
            frequency_penalty=freq_penalty, seed=seed,
            provider_for_agent_llm=provider, api_key_for_agent_llm=api_key,
            model_id_for_agent_llm=custom_model,
            model_search_term=search_term, # unused
            selected_model_for_agent_llm=selected_model
        ):
            if isinstance(item, str): # LLM text delta from agent's thought or textual answer
                current_text_for_turn = item 
                history[-1][1] = current_text_for_turn
            elif isinstance(item, gr.ChatMessage):
                # A structured step (thought, tool output, generated image, etc.).
                # String content replaces the current turn's text; file content is
                # appended to the assistant message as a (path, alt_text) tuple.
                if isinstance(item.content, str):
                    current_text_for_turn = item.content # Full message, not a delta
                elif isinstance(item.content, dict) and "path" in item.content:
                    # Typically an image or audio file produced by a tool. Gradio
                    # needs a servable path (or URL); we assume the tool's local
                    # path is accessible and display it as a (path, alt_text) tuple,
                    # which means history[-1][1] must become a list.
                    file_path = item.content["path"]
                    
                    # Ensure history[-1][1] is a list so text and files can coexist.
                    if not isinstance(history[-1][1], list):
                        history[-1][1] = [current_text_for_turn] if current_text_for_turn else []

                    alt_text = item.metadata.get("title", os.path.basename(file_path)) if item.metadata else os.path.basename(file_path)
                    
                    # Add as new component to the list for current assistant message
                    if isinstance(history[-1][1], list):
                        history[-1][1].append((file_path, alt_text))
                    else: # Should have been made a list above
                        history[-1][1] = [(file_path, alt_text)]
                    
                    current_text_for_turn = "" # Reset text buffer after a file
                
                # If it's not a delta, but a full message, replace the current text
                if not isinstance(history[-1][1], list): # if it hasn't become a list due to file
                    history[-1][1] = current_text_for_turn

            yield history

    # Event handlers
    # `msg.submit`'s first argument is the function to call; its `inputs` are the
    # components whose values are passed to that function, and its `outputs` are
    # the components updated by the function's return value.
    #
    # When msg is submitted:
    # 1. Call `user` to append the user's turn to the history. Output is `chatbot`.
    # 2. Call `bot` with the raw `msg` dict and the updated history, streaming to `chatbot`.
    # 3. Clear `msg`.
    msg.submit(
        user,
        [msg, chatbot],
        [chatbot], # `user` returns the new history, updating the chatbot display
        queue=False
    ).then(
        bot,
        [msg, chatbot, system_message_box, max_tokens_slider, temperature_slider, top_p_slider, 
         frequency_penalty_slider, seed_slider, provider_radio, byok_textbox, custom_model_box, 
         model_search_box, featured_model_radio],
        [chatbot] # `bot` yields history updates, streaming to chatbot
    ).then(
        lambda: {"text": "", "files": []},  # Clear MultimodalTextbox
        None,
        [msg]
    )
    
    model_search_box.change(fn=filter_models, inputs=model_search_box, outputs=featured_model_radio)
    featured_model_radio.change(fn=set_custom_model_from_radio, inputs=featured_model_radio, outputs=custom_model_box)
    byok_textbox.change(fn=validate_provider, inputs=[byok_textbox, provider_radio], outputs=provider_radio)
    provider_radio.change(fn=validate_provider, inputs=[byok_textbox, provider_radio], outputs=provider_radio)

print("Gradio interface initialized.")

if __name__ == "__main__":
    print("Launching the demo application.")
    demo.launch(show_api=False) # show_api=False for cleaner launch, True for API docs