File size: 5,529 Bytes
e11a899
3a69f52
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e11a899
3a69f52
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
import spaces
import json
import torch
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer

# Load model and tokenizer
# Both are fetched/loaded at import time, so importing this module downloads
# (or reads from the local cache) the "Salesforce/xLAM-7b-r" checkpoint.
# NOTE(review): trust_remote_code=True is passed for the model but not the
# tokenizer; it allows executing code shipped with the checkpoint repo —
# confirm it is actually required for this checkpoint.
model_name = "Salesforce/xLAM-7b-r"
model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto", torch_dtype="auto", trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Set random seed for reproducibility
# Note: generate_response decodes greedily (do_sample=False), so this seed does
# not affect that generation path; it only matters if sampling is enabled.
torch.random.manual_seed(0)

# Task and format instructions
# generate_response passes both strings to build_prompt, which embeds them
# verbatim in the [TASK INSTRUCTION] and [FORMAT INSTRUCTION] sections.
task_instruction = """
Based on the previous context and API request history, generate an API request or a response as an AI assistant.""".strip()

# Asks the model to answer with one JSON object holding "thought" and
# "tool_calls" (an empty list when no call is needed).
# NOTE(review): the line break after "please make" is mid-sentence; it is
# presumably copied verbatim from the upstream xLAM prompt — confirm before
# changing it, since prompt edits can change model output.
format_instruction = """
The output should be of the JSON format, which specifies a list of generated function calls. The example format is as follows, please make sure the parameter type is correct. If no function call is needed, please make 
tool_calls an empty list "[]".
```
{"thought": "the thought process, or an empty string", "tool_calls": [{"name": "api_name1", "arguments": {"argument1": "value1", "argument2": "value2"}}]}
```
""".strip()

def convert_to_xlam_tool(tools):
    """Convert OpenAI-style tool definition(s) to the flat xLAM tool format.

    A tool dict of the form ``{"name", "description", "parameters": {"type":
    "object", "properties": {...}, "required": [...]}}`` becomes
    ``{"name", "description", "parameters": {<property name>: <schema>, ...}}``;
    only the JSON-schema ``properties`` mapping is kept.

    Args:
        tools: A tool dict, a list of tool dicts (converted element-wise,
            recursively), or any other value (returned unchanged).

    Returns:
        The converted dict, a list of converted items, or ``tools`` itself.

    Raises:
        KeyError: If a tool dict has no ``"name"`` key.
    """
    if isinstance(tools, dict):
        # "description" and "parameters" are optional in OpenAI-style function
        # specs (e.g. a tool that takes no arguments), so default them instead
        # of raising KeyError. A null value is treated the same as a missing one.
        parameters = tools.get("parameters") or {}
        return {
            "name": tools["name"],
            "description": tools.get("description", ""),
            # Shallow copy so mutating the result never touches the input.
            "parameters": dict(parameters.get("properties") or {}),
        }
    if isinstance(tools, list):
        return [convert_to_xlam_tool(tool) for tool in tools]
    return tools

def build_conversation_history_prompt(conversation_history: list):
    """Serialize prior agent steps into the history section of the prompt.

    Args:
        conversation_history: Iterable of step dicts. Each must provide the
            keys ``step_id``, ``thought``, ``tool_calls``, ``next_observation``
            and ``user_input``; any other keys are dropped.

    Returns:
        The JSON-encoded steps wrapped in ``[BEGIN OF HISTORY STEPS]`` /
        ``[END OF HISTORY STEPS]`` markers, with a leading and trailing newline.

    Raises:
        KeyError: If a step is missing one of the required keys.
    """
    step_keys = ("step_id", "thought", "tool_calls", "next_observation", "user_input")
    # Rebuild each step with a fixed key order so the serialized JSON does not
    # depend on how the caller ordered (or padded) its dicts.
    parsed_history = [{key: step[key] for key in step_keys} for step in conversation_history]
    history_string = json.dumps(parsed_history)
    return f"\n[BEGIN OF HISTORY STEPS]\n{history_string}\n[END OF HISTORY STEPS]\n"

def build_prompt(task_instruction: str, format_instruction: str, tools: list, query: str, conversation_history: "list | None" = None):
    """Assemble the full xLAM prompt from its sections.

    Sections are emitted in this order: task instruction, available tools
    (JSON-encoded), format instruction, query, and, only when there is prior
    history, the serialized history steps.

    Args:
        task_instruction: Text for the task-instruction section.
        format_instruction: Text for the format-instruction section.
        tools: Tool definitions; any value ``json.dumps`` can encode.
        query: The user query.
        conversation_history: Prior step dicts, serialized by
            ``build_conversation_history_prompt``. ``None`` or an empty
            sequence omits the history section.

    Returns:
        The complete prompt string.
    """
    sections = [
        f"[BEGIN OF TASK INSTRUCTION]\n{task_instruction}\n[END OF TASK INSTRUCTION]\n\n",
        f"[BEGIN OF AVAILABLE TOOLS]\n{json.dumps(tools)}\n[END OF AVAILABLE TOOLS]\n\n",
        f"[BEGIN OF FORMAT INSTRUCTION]\n{format_instruction}\n[END OF FORMAT INSTRUCTION]\n\n",
        f"[BEGIN OF QUERY]\n{query}\n[END OF QUERY]\n\n",
    ]
    # Truthiness check also accepts None (len(None) would raise TypeError).
    if conversation_history:
        sections.append(build_conversation_history_prompt(conversation_history))
    return "".join(sections)

@spaces.GPU
def generate_response(tools_input, query):
    """Generate an xLAM tool-call response for a user query.

    Args:
        tools_input: JSON text holding one tool definition or a list of them,
            in the OpenAI-style schema used by the UI's default value.
        query: The user's natural-language request.

    Returns:
        The decoded model output (newly generated tokens only), or an
        ``"Error: ..."`` string when ``tools_input`` is not valid JSON or is
        not a usable tool definition. Errors are returned rather than raised
        so they appear in the output textbox.
    """
    try:
        tools = json.loads(tools_input)
    except json.JSONDecodeError:
        return "Error: Invalid JSON format for tools input."

    try:
        xlam_format_tools = convert_to_xlam_tool(tools)
    except (KeyError, TypeError, ValueError, AttributeError) as err:
        # Valid JSON but not a usable tool schema, e.g. a tool without "name"
        # or a "parameters" value that is not an object.
        return f"Error: Invalid tool definition ({type(err).__name__}: {err})."

    # Single-turn demo: no prior agent steps are included in the prompt.
    conversation_history = []
    content = build_prompt(task_instruction, format_instruction, xlam_format_tools, query, conversation_history)

    messages = [
        {'role': 'user', 'content': content}
    ]

    input_ids = tokenizer.apply_chat_template(messages, add_generation_prompt=True, return_tensors="pt").to(model.device)
    # One unpadded sequence, so every position is a real token; pass the mask
    # explicitly instead of letting generate() infer it.
    attention_mask = torch.ones_like(input_ids)
    outputs = model.generate(
        input_ids,
        attention_mask=attention_mask,
        max_new_tokens=512,
        do_sample=False,  # greedy decoding
        num_return_sequences=1,
        eos_token_id=tokenizer.eos_token_id,
    )
    # Decode only the tokens produced after the prompt.
    prompt_length = input_ids.shape[-1]
    agent_action = tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True)

    return agent_action

def _build_interface():
    """Build the Gradio UI: tool definitions (JSON) + user query in, model reply out."""
    # Example tools pre-filled in the tools textbox, in OpenAI-style schema.
    weather_tool = {
        "name": "get_weather",
        "description": "Get the current weather for a location",
        "parameters": {
            "type": "object",
            "properties": {
                "location": {
                    "type": "string",
                    "description": "The city and state, e.g. San Francisco, New York",
                },
                "unit": {
                    "type": "string",
                    "enum": ["celsius", "fahrenheit"],
                    "description": "The unit of temperature to return",
                },
            },
            "required": ["location"],
        },
    }
    search_tool = {
        "name": "search",
        "description": "Search for information on the internet",
        "parameters": {
            "type": "object",
            "properties": {
                "query": {
                    "type": "string",
                    "description": "The search query, e.g. 'latest news on AI'",
                },
            },
            "required": ["query"],
        },
    }

    tools_box = gr.Textbox(
        label="Available Tools (JSON format)",
        lines=10,
        value=json.dumps([weather_tool, search_tool], indent=2),
    )
    query_box = gr.Textbox(
        label="User Query",
        lines=2,
        value="What's the weather like in New York in fahrenheit?",
    )
    return gr.Interface(
        fn=generate_response,
        inputs=[tools_box, query_box],
        outputs=gr.Textbox(label="Generated Response", lines=10),
        title="xLAM-7b-r API Request Generator",
        description="Enter available tools in JSON format and a user query to generate an API request or response.",
    )


# Gradio interface, built once at import time.
iface = _build_interface()

if __name__ == "__main__":
    iface.launch()