|
```python |
|
import json |
|
import torch |
|
import gradio as gr |
|
from transformers import AutoModelForCausalLM, AutoTokenizer |
|
|
|
|
|
# Hugging Face checkpoint identifier shared by the model and the tokenizer.
model_name = "Salesforce/xLAM-7b-r"

# NOTE(review): trust_remote_code=True lets code shipped with the checkpoint
# run locally, so model_name should only point at a trusted repository.
# device_map/torch_dtype are left on "auto" so the library picks them.
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map="auto",
    torch_dtype="auto",
    trust_remote_code=True,
)
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Fixed seed for torch's random number generator.
torch.random.manual_seed(0)
|
|
|
|
|
# Text placed between the [BEGIN/END OF TASK INSTRUCTION] markers of the prompt.
task_instruction = """
Based on the previous context and API request history, generate an API request or a response as an AI assistant.""".strip()

# Text placed between the [BEGIN/END OF FORMAT INSTRUCTION] markers of the
# prompt. It describes the JSON object the model is expected to emit and
# includes a fenced one-line example; the literal must stay free of stray
# markup because it is sent to the model verbatim.
format_instruction = """
The output should be of the JSON format, which specifies a list of generated function calls. The example format is as follows, please make sure the parameter type is correct. If no function call is needed, please make
tool_calls an empty list "[]".
```
{"thought": "the thought process, or an empty string", "tool_calls": [{"name": "api_name1", "arguments": {"argument1": "value1", "argument2": "value2"}}]}
```
""".strip()
|
|
|
def convert_to_xlam_tool(tools):
    """Flatten tool schema(s) into the layout embedded in the prompt.

    A tool dict of the form
    ``{"name": ..., "description": ..., "parameters": {"properties": {...}}}``
    becomes ``{"name": ..., "description": ..., "parameters": {...}}`` where
    the new "parameters" value is the content of "properties"; every other
    key under "parameters" (such as "type" or "required") is dropped.

    Args:
        tools: A single tool dict, a list of tools (converted element by
            element, recursively), or any other value.

    Returns:
        The converted dict, a list of converted items, or ``tools`` itself
        when it is neither a dict nor a list.

    Raises:
        KeyError: If a tool dict has no "name" key.
    """
    if isinstance(tools, list):
        return [convert_to_xlam_tool(tool) for tool in tools]
    if not isinstance(tools, dict):
        return tools

    # A tool that takes no arguments may omit "parameters" (or "properties")
    # entirely; treat that as an empty parameter set instead of failing.
    parameters = tools.get("parameters") or {}
    properties = parameters.get("properties") or {}
    return {
        "name": tools["name"],
        "description": tools.get("description", ""),
        "parameters": dict(properties),
    }
|
|
|
def build_conversation_history_prompt(conversation_history: list) -> str:
    """Render previous steps as the history section of the prompt.

    Args:
        conversation_history: Iterable of step dicts. Each must provide the
            keys "step_id", "thought", "tool_calls", "next_observation" and
            "user_input"; any other keys are ignored.

    Returns:
        The JSON-encoded list of steps wrapped in
        ``[BEGIN OF HISTORY STEPS]`` / ``[END OF HISTORY STEPS]`` markers.

    Raises:
        KeyError: If a step lacks one of the required keys.
    """
    # Key order here determines the key order in the serialized JSON.
    step_keys = ("step_id", "thought", "tool_calls", "next_observation", "user_input")
    parsed_history = [
        {key: step_data[key] for key in step_keys}
        for step_data in conversation_history
    ]
    history_string = json.dumps(parsed_history)
    return f"\n[BEGIN OF HISTORY STEPS]\n{history_string}\n[END OF HISTORY STEPS]\n"
|
|
|
def build_prompt(task_instruction: str, format_instruction: str, tools: list, query: str, conversation_history: list) -> str:
    """Assemble the full prompt text from its bracket-delimited sections.

    Sections appear in this order: task instruction, available tools
    (JSON-encoded), format instruction, query, and -- only when there is
    history -- the history steps.

    Args:
        task_instruction: Text for the task-instruction section.
        format_instruction: Text for the format-instruction section.
        tools: JSON-serializable tool description(s).
        query: The user query.
        conversation_history: Previous steps; an empty list or ``None``
            omits the history section.

    Returns:
        The concatenated prompt string.
    """
    sections = [
        f"[BEGIN OF TASK INSTRUCTION]\n{task_instruction}\n[END OF TASK INSTRUCTION]\n\n",
        f"[BEGIN OF AVAILABLE TOOLS]\n{json.dumps(tools)}\n[END OF AVAILABLE TOOLS]\n\n",
        f"[BEGIN OF FORMAT INSTRUCTION]\n{format_instruction}\n[END OF FORMAT INSTRUCTION]\n\n",
        f"[BEGIN OF QUERY]\n{query}\n[END OF QUERY]\n\n",
    ]

    # Truthiness instead of len(): a None history is treated like an empty one.
    if conversation_history:
        sections.append(build_conversation_history_prompt(conversation_history))
    return "".join(sections)
|
|
|
def generate_response(tools_input, query):
    """Generate the model's reply for a user query and a set of tools.

    Args:
        tools_input: JSON text holding one tool object or a list of them.
        query: The user query inserted into the prompt.

    Returns:
        The decoded text of the newly generated tokens, or a string starting
        with "Error:" when ``tools_input`` cannot be used.
    """
    try:
        tools = json.loads(tools_input)
    except (json.JSONDecodeError, TypeError):
        # TypeError covers a non-string input such as None.
        return "Error: Invalid JSON format for tools input."

    # Well-formed JSON can still be the wrong shape (number, string, null).
    if not isinstance(tools, (dict, list)):
        return "Error: Tools input must be a JSON object or a list of JSON objects."

    try:
        xlam_format_tools = convert_to_xlam_tool(tools)
    except (KeyError, TypeError, AttributeError) as exc:
        # Report a malformed tool definition the same way as bad JSON rather
        # than letting the exception escape to the UI.
        return f"Error: Invalid tool definition ({exc!r})."

    # Each request is handled on its own, so there is no prior history.
    conversation_history = []
    content = build_prompt(task_instruction, format_instruction, xlam_format_tools, query, conversation_history)

    messages = [
        {'role': 'user', 'content': content}
    ]

    inputs = tokenizer.apply_chat_template(messages, add_generation_prompt=True, return_tensors="pt").to(model.device)
    # Sampling is disabled (do_sample=False) and a single sequence is requested.
    outputs = model.generate(inputs, max_new_tokens=512, do_sample=False, num_return_sequences=1, eos_token_id=tokenizer.eos_token_id)

    # Drop the prompt tokens so only the newly generated part is decoded.
    agent_action = tokenizer.decode(outputs[0][len(inputs[0]):], skip_special_tokens=True)
    return agent_action
|
|
|
|
|
# Example tool schemas used to pre-fill the tools textbox.
_example_tools = [
    {
        "name": "get_weather",
        "description": "Get the current weather for a location",
        "parameters": {
            "type": "object",
            "properties": {
                "location": {
                    "type": "string",
                    "description": "The city and state, e.g. San Francisco, New York",
                },
                "unit": {
                    "type": "string",
                    "enum": ["celsius", "fahrenheit"],
                    "description": "The unit of temperature to return",
                },
            },
            "required": ["location"],
        },
    },
    {
        "name": "search",
        "description": "Search for information on the internet",
        "parameters": {
            "type": "object",
            "properties": {
                "query": {
                    "type": "string",
                    "description": "The search query, e.g. 'latest news on AI'",
                },
            },
            "required": ["query"],
        },
    },
]

# Input widgets, in the order generate_response takes its arguments.
_tools_box = gr.Textbox(
    label="Available Tools (JSON format)",
    lines=10,
    value=json.dumps(_example_tools, indent=2),
)
_query_box = gr.Textbox(
    label="User Query",
    lines=2,
    value="What's the weather like in New York in fahrenheit?",
)

# Web UI wiring the two textboxes to generate_response.
iface = gr.Interface(
    fn=generate_response,
    inputs=[_tools_box, _query_box],
    outputs=gr.Textbox(label="Generated Response", lines=10),
    title="xLAM-7b-r API Request Generator",
    description="Enter available tools in JSON format and a user query to generate an API request or response.",
)


# Start the UI only when the file is run as a script.
if __name__ == "__main__":
    iface.launch()