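# LangChain LLM wrapper for ChatGLM3. It bridges the structured-chat agent's
# text protocol (Action / Observation / Final Answer blobs) and ChatGLM3's
# native tool-calling chat history.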
import json
from typing import List, Optional

from langchain.llms.base import LLM
from transformers import AutoConfig, AutoModel, AutoTokenizer

from utils import tool_config_from_file
class ChatGLM3(LLM):
    max_token: int = 8192
    do_sample: bool = False
    temperature: float = 0.8
    top_p: float = 0.8
    tokenizer: object = None
    model: object = None
    history: List = []
    tool_names: List = []
    has_search: bool = False

    def __init__(self):
        super().__init__()

    @property
    def _llm_type(self) -> str:
        return "ChatGLM3"
    def load_model(self, model_name_or_path=None):
        model_config = AutoConfig.from_pretrained(
            model_name_or_path,
            trust_remote_code=True
        )
        self.tokenizer = AutoTokenizer.from_pretrained(
            model_name_or_path,
            trust_remote_code=True
        )
        # Load the weights in half precision and move them to the GPU.
        self.model = AutoModel.from_pretrained(
            model_name_or_path, config=model_config, trust_remote_code=True
        ).half().cuda()
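    # Example (the checkpoint path is an assumption; any local or Hub copy of
    # chatglm3-6b should work the same way):
    #
    #     llm = ChatGLM3()
    #     llm.load_model("THUDM/chatglm3-6b")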
    def _tool_history(self, prompt: str):
        ans = []

        # The structured-chat agent prompt lists the available tools between
        # "You have access to the following tools:" and "Use a json blob";
        # recover one "name: description" line per tool from that block.
        tool_prompts = prompt.split(
            "You have access to the following tools:\n\n")[1].split("\n\nUse a json blob")[0].split("\n")
        tool_names = [tool.split(":")[0] for tool in tool_prompts]
        self.tool_names = tool_names

        tools_json = []
        for i, tool in enumerate(tool_names):
            tool_config = tool_config_from_file(tool)
            if tool_config:
                tools_json.append(tool_config)
            else:
                raise ValueError(
                    f"Tool {tool} config not found! Its description is {tool_prompts[i]}"
                )

        ans.append({
            "role": "system",
            "content": "Answer the following questions as best as you can. You have access to the following tools:",
            "tools": tools_json
        })
        query = f"""{prompt.split("Human: ")[-1].strip()}"""
        return ans, query
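    # For an agent prompt containing, e.g. (tool name is only an illustration):
    #
    #     You have access to the following tools:
    #
    #     Calculator: Useful for math questions
    #
    #     Use a json blob ...
    #     Human: what is 2 + 2?
    #
    # _tool_history returns a single system turn carrying the tools' JSON
    # configs, plus the query "what is 2 + 2?".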
    def _extract_observation(self, prompt: str):
        return_json = prompt.split("Observation: ")[-1].split("\nThought:")[0]
        self.history.append({
            "role": "observation",
            "content": return_json
        })
        return
    def _extract_tool(self):
        if len(self.history[-1]["metadata"]) > 0:
            metadata = self.history[-1]["metadata"]
            content = self.history[-1]["content"]
            if "tool_call" in content:
                for tool in self.tool_names:
                    if tool in metadata:
                        # ChatGLM3 emits tool calls as `tool_call(param='value')`;
                        # pull out the single string argument.
                        input_para = content.split("='")[-1].split("'")[0]
                        action_json = {
                            "action": tool,
                            "action_input": input_para
                        }
                        self.has_search = True
                        return f"""
Action:
```
{json.dumps(action_json, ensure_ascii=False)}
```"""
        # No recognized tool call: wrap the raw content as the final answer.
        final_answer_json = {
            "action": "Final Answer",
            "action_input": self.history[-1]["content"]
        }
        self.has_search = False
        return f"""
Action:
```
{json.dumps(final_answer_json, ensure_ascii=False)}
```"""
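    # Illustration (assumed model output; the tool name "weather" is a
    # placeholder): if the last turn's metadata names a registered tool and its
    # content looks like `tool_call(query='Shanghai')`, _extract_tool returns:
    #
    #     Action:
    #     ```
    #     {"action": "weather", "action_input": "Shanghai"}
    #     ```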
    def _call(self, prompt: str, history: List = [], stop: Optional[List[str]] = ["<|user|>"]):
        print("======")
        print(prompt)
        print("======")
        if not self.has_search:
            # First call of an agent step: rebuild the system/tool history
            # from the agent prompt and extract the user query.
            self.history, query = self._tool_history(prompt)
        else:
            # A tool has just run: record its output as an observation turn
            # and let the model continue from the accumulated history.
            self._extract_observation(prompt)
            query = ""
        _, self.history = self.model.chat(
            self.tokenizer,
            query,
            history=self.history,
            do_sample=self.do_sample,
            max_length=self.max_token,
            temperature=self.temperature,
        )
        response = self._extract_tool()
        history.append((prompt, response))
        return response
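
# A minimal end-to-end sketch, not part of the original module: wire the
# wrapper into LangChain's structured-chat agent, whose "Use a json blob"
# prompt is exactly what _tool_history parses. Assumptions: the chatglm3-6b
# weights are reachable at the given path, and a config file matching each
# tool's name (here "Calculator", created by load_tools(["llm-math"])) is
# resolvable by tool_config_from_file.
if __name__ == "__main__":
    from langchain.agents import AgentType, initialize_agent, load_tools

    llm = ChatGLM3()
    llm.load_model("THUDM/chatglm3-6b")

    tools = load_tools(["llm-math"], llm=llm)
    agent = initialize_agent(
        tools,
        llm,
        agent=AgentType.STRUCTURED_CHAT_ZERO_SHOT_REACT_DESCRIPTION,
        verbose=True,
    )
    agent.run("What is 34 * 17?")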