form chat prompt
Browse files- model/custom_agent.py +58 -0
model/custom_agent.py
CHANGED
@@ -11,16 +11,28 @@ Dependencies:
|
|
11 |
|
12 |
Classes:
|
13 |
- CustomHfAgent: A custom class for interacting with the Hugging Face model API.
|
|
|
|
|
|
|
|
|
|
|
|
|
14 |
"""
|
15 |
|
16 |
import time
|
17 |
import requests
|
18 |
from transformers import Agent
|
19 |
from utils.logger import log_response
|
|
|
|
|
|
|
20 |
|
21 |
class CustomHfAgent(Agent):
|
22 |
"""A custom class for interacting with the Hugging Face model API."""
|
23 |
|
|
|
|
|
|
|
24 |
def __init__(self, url_endpoint, token, chat_prompt_template=None, run_prompt_template=None, additional_tools=None, input_params=None):
|
25 |
"""
|
26 |
Initialize the CustomHfAgent.
|
@@ -82,3 +94,49 @@ class CustomHfAgent(Agent):
|
|
82 |
if result.endswith(stop_seq):
|
83 |
return result[: -len(stop_seq)]
|
84 |
return result
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
11 |
|
12 |
Classes:
|
13 |
- CustomHfAgent: A custom class for interacting with the Hugging Face model API.
|
14 |
+
|
15 |
+
Reasono for making this https://github.com/huggingface/transformers/issues/28217
|
16 |
+
Based on https://github.com/huggingface/transformers/blob/main/src/transformers/tools/agents.py
|
17 |
+
|
18 |
+
"return_full_text": False,
|
19 |
+
|
20 |
"""
|
21 |
|
22 |
import time
|
23 |
import requests
|
24 |
from transformers import Agent
|
25 |
from utils.logger import log_response
|
26 |
+
|
27 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer
|
28 |
+
|
29 |
|
30 |
class CustomHfAgent(Agent):
|
31 |
"""A custom class for interacting with the Hugging Face model API."""
|
32 |
|
33 |
+
|
34 |
+
|
35 |
+
|
36 |
def __init__(self, url_endpoint, token, chat_prompt_template=None, run_prompt_template=None, additional_tools=None, input_params=None):
|
37 |
"""
|
38 |
Initialize the CustomHfAgent.
|
|
|
94 |
if result.endswith(stop_seq):
|
95 |
return result[: -len(stop_seq)]
|
96 |
return result
|
97 |
+
###
|
98 |
+
###
|
99 |
+
### https://github.com/huggingface/transformers/blob/main/src/transformers/tools/prompts.py -> run chat_template.txt
|
100 |
+
### https://huggingface.co/datasets/huggingface-tools/default-prompts/blob/main/chat_prompt_template.txt
|
101 |
+
###
|
102 |
+
def format_prompt(self, task, chat_mode=False):
|
103 |
+
|
104 |
+
checkpoint = "bigcode/starcoder"
|
105 |
+
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
|
106 |
+
#model = AutoModelForCausalLM.from_pretrained(checkpoint) # You may want to use bfloat16 and/or move to GPU here
|
107 |
+
|
108 |
+
|
109 |
+
|
110 |
+
description = "\n".join([f"- {name}: {tool.description}" for name, tool in self.toolbox.items()])
|
111 |
+
if chat_mode:
|
112 |
+
if self.chat_history is None:
|
113 |
+
prompt = self.chat_prompt_template.replace("<<all_tools>>", description)
|
114 |
+
messages = [
|
115 |
+
{
|
116 |
+
"role": "user",
|
117 |
+
"content": prompt,
|
118 |
+
}
|
119 |
+
]
|
120 |
+
prompt = tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
|
121 |
+
else:
|
122 |
+
prompt = self.chat_history
|
123 |
+
cmp = CHAT_MESSAGE_PROMPT.replace("<<task>>", task)
|
124 |
+
messages = [
|
125 |
+
{
|
126 |
+
"role": "user",
|
127 |
+
"content": cmp,
|
128 |
+
}
|
129 |
+
]
|
130 |
+
cmp = tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
|
131 |
+
prompt += cmp
|
132 |
+
else:
|
133 |
+
prompt = self.run_prompt_template.replace("<<all_tools>>", description)
|
134 |
+
prompt = prompt.replace("<<prompt>>", task)
|
135 |
+
messages = [
|
136 |
+
{
|
137 |
+
"role": "user",
|
138 |
+
"content": prompt,
|
139 |
+
}
|
140 |
+
]
|
141 |
+
prompt = tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
|
142 |
+
return prompt
|