|
import gradio as gr |
|
import torch |
|
from transformers import AutoModelForCausalLM, AutoTokenizer |
|
import requests |
|
import io |
|
from PIL import Image |
|
import re |
|
import json |
|
import xml.etree.ElementTree as ET |
|
|
|
class SmolLMWithTools:
    """SmolLM3 chat assistant that can call an image-generation tool.

    The tokenizer and model are loaded once in ``__init__``. Each turn is
    rendered with the tokenizer's chat template, exposing ``self.tools`` via
    ``xml_tools``; any ``<tool_call>`` blocks in the reply are executed and a
    summary of their results is sent back to the model for a final answer.
    """

    # Seconds to wait for the FLUX inference endpoint. Without a timeout a
    # stalled HTTP request would block the calling UI handler indefinitely.
    REQUEST_TIMEOUT = 120

    # One <tool_call>...</tool_call> span; its payload is parsed separately.
    _TOOL_CALL_RE = re.compile(r"<tool_call>(.*?)</tool_call>", re.DOTALL)
    # XML payload form: <invoke name="..."> ...<parameter> elements... </invoke>
    _INVOKE_RE = re.compile(r'<invoke name="([^"]+)">(.*)</invoke>', re.DOTALL)
    # A single <parameter name="...">value</parameter> inside an <invoke> payload.
    _PARAMETER_RE = re.compile(r'<parameter name="([^"]+)">([^<]*)</parameter>')

    def __init__(self):
        """Load the SmolLM3 tokenizer and model and declare the available tools.

        Uses CUDA with float16 weights when available, otherwise CPU with
        float32. The Hugging Face API token starts unset (see ``set_hf_token``).
        """
        self.checkpoint = "HuggingFaceTB/SmolLM3-3B"
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        print(f"Loading SmolLM3 on {self.device}...")

        self.tokenizer = AutoTokenizer.from_pretrained(self.checkpoint)
        self.model = AutoModelForCausalLM.from_pretrained(
            self.checkpoint,
            torch_dtype=torch.float16 if self.device == "cuda" else torch.float32,
        ).to(self.device)

        # Token for the FLUX Inference API; supplied later from the UI.
        self.hf_token = None
        self.flux_api_url = "https://api-inference.huggingface.co/models/black-forest-labs/FLUX.1-schnell"

        # Tool schema handed to the chat template through ``xml_tools``.
        self.tools = [
            {
                "name": "generate_image",
                "description": "Generate an image using AI based on a text description. Use this when the user asks for images, pictures, drawings, or visual content.",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "prompt": {
                            "type": "string",
                            "description": "A detailed description of the image to generate. Be specific and descriptive."
                        }
                    },
                    "required": ["prompt"]
                }
            }
        ]

        print("Model loaded successfully!")

    def set_hf_token(self, token):
        """Store the Hugging Face API token used for FLUX requests.

        Args:
            token: The API token string.

        Returns:
            A status message suitable for display in the UI.
        """
        self.hf_token = token
        return "✅ HF Token set successfully!"

    def generate_image_tool(self, prompt):
        """Generate an image for ``prompt`` via the FLUX Inference API.

        Args:
            prompt: Text description of the image.

        Returns:
            A dict with ``"success"`` and ``"image"`` (a PIL image or ``None``),
            plus ``"message"`` on success or ``"error"`` on failure. Failures,
            including network errors and timeouts, are reported in the dict
            rather than raised, so the model can be told what went wrong.
        """
        if not self.hf_token:
            return {"success": False, "error": "HF token not set", "image": None}

        headers = {"Authorization": f"Bearer {self.hf_token}"}
        try:
            response = requests.post(
                self.flux_api_url,
                headers=headers,
                json={"inputs": prompt},
                timeout=self.REQUEST_TIMEOUT,
            )
            if response.status_code == 200:
                image = Image.open(io.BytesIO(response.content))
                return {"success": True, "message": f"Successfully generated image: {prompt}", "image": image}
            if response.status_code == 503:
                return {"success": False, "error": "Model is loading, please try again", "image": None}
            return {"success": False, "error": f"API error: {response.status_code}", "image": None}
        except Exception as e:  # deliberate tool boundary: report, don't crash the chat turn
            return {"success": False, "error": str(e), "image": None}

    def parse_tool_calls(self, text):
        """Extract tool calls from a model reply.

        Each ``<tool_call>...</tool_call>`` span is parsed in one of two forms:

        * XML: ``<invoke name="..."><parameter name="...">value</parameter></invoke>``
          (one or more parameters).
        * JSON: ``{"name": "...", "arguments": {...}}``, the shape requested by
          SmolLM3's tool-calling chat template. ``"parameters"`` is accepted in
          place of ``"arguments"``, and the arguments may be a JSON-encoded string.

        Spans matching neither form are ignored. String argument values are
        stripped of surrounding whitespace.

        Args:
            text: Raw model output; ``None`` or empty yields no calls.

        Returns:
            A list of ``{"name": str, "parameters": dict}`` in order of appearance.
        """
        tool_calls = []
        for payload in self._TOOL_CALL_RE.findall(text or ""):
            payload = payload.strip()
            call = self._parse_xml_call(payload)
            if call is None:
                call = self._parse_json_call(payload)
            if call is not None:
                tool_calls.append(call)
        return tool_calls

    @classmethod
    def _parse_xml_call(cls, payload):
        """Parse an ``<invoke>`` payload into a tool call; return None if it is not one."""
        match = cls._INVOKE_RE.fullmatch(payload)
        if match is None:
            return None
        name, body = match.groups()
        parameters = {key: value.strip() for key, value in cls._PARAMETER_RE.findall(body)}
        return {"name": name, "parameters": parameters}

    @staticmethod
    def _parse_json_call(payload):
        """Parse a JSON ``{"name", "arguments"}`` payload; return None if invalid."""
        try:
            data = json.loads(payload)
        except json.JSONDecodeError:
            return None
        if not isinstance(data, dict) or not isinstance(data.get("name"), str):
            return None

        arguments = data.get("arguments", data.get("parameters"))
        if arguments is None:
            arguments = {}
        # Accept the arguments object given as a JSON-encoded string as well.
        if isinstance(arguments, str):
            try:
                arguments = json.loads(arguments)
            except json.JSONDecodeError:
                return None
        if not isinstance(arguments, dict):
            return None

        parameters = {
            key: value.strip() if isinstance(value, str) else value
            for key, value in arguments.items()
        }
        return {"name": data["name"], "parameters": parameters}

    def execute_tool_call(self, tool_call):
        """Dispatch a parsed tool call to its implementation.

        Args:
            tool_call: A dict with ``"name"`` and ``"parameters"`` keys, as
                produced by ``parse_tool_calls``.

        Returns:
            A result dict with ``"success"`` and ``"image"`` keys plus either
            ``"message"`` or ``"error"``.
        """
        tool_name = tool_call["name"]
        parameters = tool_call.get("parameters") or {}

        if tool_name == "generate_image":
            prompt = parameters.get("prompt", "")
            # Don't spend an API call on a missing/empty description.
            if not isinstance(prompt, str) or not prompt.strip():
                return {"success": False, "error": "Missing required parameter: prompt", "image": None}
            return self.generate_image_tool(prompt)

        return {"success": False, "error": f"Unknown tool: {tool_name}", "image": None}

    def chat_with_tools(self, messages):
        """Generate one assistant reply for ``messages`` with the tools exposed.

        Args:
            messages: Chat history as a list of ``{"role", "content"}`` dicts.

        Returns:
            Only the newly generated text (the prompt is excluded), or an
            ``"Error generating response: ..."`` string if anything fails.
        """
        try:
            inputs = self.tokenizer.apply_chat_template(
                messages,
                enable_thinking=False,
                xml_tools=self.tools,
                add_generation_prompt=True,
                tokenize=True,
                return_dict=True,  # also yields the attention mask for generate()
                return_tensors="pt",
            ).to(self.device)
            prompt_length = inputs["input_ids"].shape[-1]

            with torch.no_grad():
                outputs = self.model.generate(
                    **inputs,
                    max_new_tokens=1024,
                    temperature=0.7,
                    do_sample=True,
                    pad_token_id=self.tokenizer.eos_token_id,
                )

            # Drop the prompt by token count. Slicing the decoded text by the
            # decoded prompt's length breaks whenever the two decodes disagree.
            new_tokens = outputs[0][prompt_length:]
            # Keep special tokens so <tool_call> tags survive even if the
            # tokenizer registers them as special; strip end/padding markers.
            reply = self.tokenizer.decode(new_tokens, skip_special_tokens=False)
            for marker in (self.tokenizer.eos_token, self.tokenizer.pad_token):
                if marker:
                    reply = reply.replace(marker, "")
            return reply.strip()
        except Exception as e:
            return f"Error generating response: {str(e)}"

    def process_conversation(self, user_message, history, hf_token):
        """Run one chat turn, executing any tool calls the model makes.

        Args:
            user_message: The new user message.
            history: Prior turns as ``[user, assistant]`` pairs; the new turn
                is appended to it. ``None`` is treated as an empty history.
            hf_token: Token from the UI; stored if it differs from the current one.

        Returns:
            ``(history, "", image)``: the updated history, an empty string to
            clear the input box, and the generated PIL image or ``None``.
        """
        # Compare instead of "only if unset" so a corrected token replaces a bad one.
        if hf_token and hf_token != self.hf_token:
            self.set_hf_token(hf_token)

        if history is None:
            history = []

        messages = []
        for turn in history:
            messages.append({"role": "user", "content": turn[0]})
            if turn[1]:
                messages.append({"role": "assistant", "content": turn[1]})
        messages.append({"role": "user", "content": user_message})

        assistant_response = self.chat_with_tools(messages)
        tool_calls = self.parse_tool_calls(assistant_response)
        generated_image = None
        final_response = assistant_response

        if tool_calls:
            tool_results = []
            for tool_call in tool_calls:
                result = self.execute_tool_call(tool_call)
                tool_results.append(result)
                if tool_call["name"] == "generate_image" and result.get("image") is not None:
                    generated_image = result["image"]

            # Replay the model's tool call, then feed back the results and ask
            # for a user-facing reply.
            messages.append({"role": "assistant", "content": assistant_response})
            tool_summary = "\n".join(
                f"Tool {index} result: {result.get('message', result.get('error', 'Unknown result'))}"
                for index, result in enumerate(tool_results, start=1)
            )
            messages.append({"role": "user", "content": f"Tool execution results: {tool_summary}\n\nPlease respond to the user about the results."})
            final_response = self.chat_with_tools(messages)

        history.append([user_message, final_response])
        return history, "", generated_image
|
|
|
|
|
# Single shared assistant instance. Constructing it loads the SmolLM3 tokenizer
# and model, so this runs once at import time; the UI callback `respond` in
# create_interface() routes every chat turn through this object.
chat_system: SmolLMWithTools = SmolLMWithTools()
|
|
|
def create_interface():
    """Build the Gradio Blocks UI around the module-level ``chat_system``.

    The layout has a token box, a chat panel and a message box on the left,
    and the most recently generated image on the right. Send and Enter both
    run one conversation turn; Clear resets the chat, the message box and the
    image.

    Returns:
        The ``gr.Blocks`` app, not yet launched.
    """
    with gr.Blocks(title="SmolLM3 Tool Calling + FLUX", theme=gr.themes.Soft()) as app:
        gr.Markdown("""
        # 🤖🛠️ SmolLM3 with Tool Calling + FLUX

        SmolLM3 can autonomously decide when to generate images based on your conversation!

        Just chat naturally - the model will call the image generation tool when appropriate.

        **Examples:**

        - "Can you create a picture of a sunset?"
        - "I need an image of a robot for my presentation"
        - "Draw me a fantasy landscape"
        - "Show me what a purple elephant would look like"
        """)

        with gr.Row():
            with gr.Column(scale=2):
                hf_token_input = gr.Textbox(
                    label="🔑 Hugging Face API Token",
                    placeholder="Enter your HF token for image generation",
                    type="password"
                )

                chatbot = gr.Chatbot(
                    label="Chat with SmolLM3 (Tool Calling Enabled)",
                    height=500,
                    show_copy_button=True
                )

                msg_input = gr.Textbox(
                    label="Message",
                    placeholder="Ask for anything - SmolLM3 will decide if it needs to generate an image...",
                    lines=3
                )

                with gr.Row():
                    send_btn = gr.Button("Send 📤", variant="primary")
                    clear_btn = gr.Button("Clear 🗑️")

            with gr.Column(scale=1):
                image_output = gr.Image(
                    label="Generated Images",
                    height=500
                )

                gr.Markdown("""
                ### 🔧 Available Tools:

                - **generate_image**: Creates images from text descriptions

                The model decides autonomously when to use tools based on context!
                """)

        def respond(message, history, hf_token):
            # Ignore blank submissions without wiping the currently shown image.
            if not (message or "").strip():
                return history, "", gr.update()
            return chat_system.process_conversation(message, history, hf_token)

        # The Send button and pressing Enter in the message box do the same thing.
        turn_io = dict(
            inputs=[msg_input, chatbot, hf_token_input],
            outputs=[chatbot, msg_input, image_output],
        )
        send_btn.click(respond, **turn_io)
        msg_input.submit(respond, **turn_io)

        # Reset the chat history, any half-typed message, and the image panel.
        clear_btn.click(
            lambda: ([], "", None),
            outputs=[chatbot, msg_input, image_output]
        )

        gr.Markdown("""
        ### 📋 Setup Instructions:

        1. **Get HF Token**: Visit [HuggingFace Tokens](https://huggingface.co/settings/tokens)
        2. **Create Token**: Generate a token with "Read" permissions
        3. **Enter Token**: Paste it in the field above
        4. **Start Chatting**: Ask for anything - images, questions, explanations!

        ### 🔧 How it Works:

        - SmolLM3 analyzes your message
        - Decides if it needs to call tools
        - Generates appropriate tool calls
        - Executes the tools and responds with results

        **The AI is in full control of when and how to use tools!**
        """)

    return app
|
|
|
if __name__ == "__main__":
    # Serve the UI on all network interfaces at port 7860 with Gradio's debug
    # output enabled; share=False means no public share link is requested.
    launch_options = {
        "server_name": "0.0.0.0",
        "server_port": 7860,
        "share": False,
        "debug": True,
    }
    create_interface().launch(**launch_options)