Upload 8 files
- app.py +93 -0
- conversation.py +159 -0
- demo.py +24 -0
- demo1.py +91 -0
- demotool.py +67 -0
- llmLoader.py +55 -0
- loader.py +171 -0
- singleton.py +24 -0
app.py
ADDED
@@ -0,0 +1,93 @@
import gradio as gr
from langchain.agents import initialize_agent
# from langchain.llms import OpenAI
# from langchain.chat_models import ChatOpenAI

from langchain.tools import BaseTool, StructuredTool, Tool, tool
from PIL import Image
from demotool import *
from loader import *
# from llmLoader import *
import re
from gradio_tools.tools import (StableDiffusionTool, ImageCaptioningTool, StableDiffusionPromptGeneratorTool,
                                TextToVideoTool)

from langchain.memory import ConversationBufferMemory
# from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler

from langchain import PromptTemplate, HuggingFaceHub, LLMChain

def init_model_config():
    llm = ChatLLM()
    llm.model_type = 'chatglm'
    llm.model_name_or_path = llm_model_dict['chatglm'][
        'ChatGLM-6B-int4']
    llm.load_llm()
    return llm

# initialize HF LLM
# flan_t5 = HuggingFaceHub(
#     repo_id="google/flan-t5-xl",
#     model_kwargs={"temperature":1e-10},
#     huggingfacehub_api_token="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"
# )


# llm = ChatOpenAI(openai_api_key="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",temperature=0,streaming=True,callbacks=[StreamingStdOutCallbackHandler()])

# llm = ModelLoader()
# llm.loader()
# chatLLM = ModelLoader()
# chatLLM.loader()
memory = ConversationBufferMemory(memory_key="chat_history")


# tools = [ Text2Image()]

# tools = [ Tool.from_function(
#     func=search,
#     name = "Search",
#     description="useful for when you need to answer questions about current events"
# )]

# tools = [ Tool.from_function(
#     func=optimizationProblem,
#     name = "optimizationProblem",
#     description=" you must use this tool when you need to Add more information"
# )]


# tools = [ StableDiffusionPromptGeneratorTool().langchain]

tools = []




agent = initialize_agent(tools, init_model_config(), memory=memory, agent="conversational-react-description", verbose=True)

def run_text(text, state):
    # print("stat:"+text)
    # res = llm_chain.run(text)
    # print("res:"+res)
    res = agent.run(input=(text))
    # Render generated image paths as inline Markdown images followed by the path
    # (the "![](file=...)" prefix is an assumed reconstruction of the replacement string).
    response = re.sub(r'(image/\S*png)', lambda m: f'![](file={m.group(0)})*{m.group(0)}*', res)
    state = state + [(text, response)]
    return state, state

with gr.Blocks(css="#chatbot {overflow:auto; height:500px;}") as demo:
    chatbot = gr.Chatbot(elem_id="chatbot", show_label=False)
    state = gr.State([])
    with gr.Row() as input_raws:
        with gr.Column(scale=0.6):
            txt = gr.Textbox(show_label=False).style(container=False)
        with gr.Column(scale=0.20, min_width=0):
            run = gr.Button("🏃♂️Run")
        with gr.Column(scale=0.20, min_width=0):
            clear = gr.Button("🔄Clear️")

    txt.submit(run_text, [txt, state], [chatbot, state])
    txt.submit(lambda: "", None, txt)
    run.click(run_text, [txt, state], [chatbot, state])

demo.queue(concurrency_count=10).launch(server_name="0.0.0.0", server_port=7865)
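Note: the agent above starts with an empty tools list. A minimal sketch (not part of the upload) of wiring in the Text2Image tool from demotool.py in place of tools = []:

# Hypothetical variant: give the agent the Stable Diffusion tool defined in demotool.py.
from demotool import Text2Image

tools = [Text2Image()]  # replaces the empty list above

agent = initialize_agent(
    tools,
    init_model_config(),                       # ChatGLM wrapped as a LangChain LLM (loader.py)
    memory=memory,
    agent="conversational-react-description",  # same agent type as in app.py
    verbose=True,
)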
conversation.py
ADDED
@@ -0,0 +1,159 @@
import dataclasses
from enum import auto, Enum
from typing import List, Tuple, Any


class SeparatorStyle(Enum):
    """Different separator style."""
    SINGLE = auto()
    TWO = auto()


@dataclasses.dataclass
class Conversation:
    """A class that keeps all conversation history."""
    system: str
    roles: List[str]
    messages: List[List[str]]
    offset: int
    sep_style: SeparatorStyle = SeparatorStyle.SINGLE
    sep: str = "###"
    sep2: str = None

    skip_next: bool = False
    conv_id: Any = None

    def get_prompt(self):
        if self.sep_style == SeparatorStyle.SINGLE:
            ret = self.system + self.sep
            for role, message in self.messages:
                if message:
                    ret += role + ": " + message + self.sep
                else:
                    ret += role + ":"
            return ret
        elif self.sep_style == SeparatorStyle.TWO:
            seps = [self.sep, self.sep2]
            ret = self.system + seps[0]
            for i, (role, message) in enumerate(self.messages):
                if message:
                    ret += role + ": " + message + seps[i % 2]
                else:
                    ret += role + ":"
            return ret
        else:
            raise ValueError(f"Invalid style: {self.sep_style}")

    def append_message(self, role, message):
        self.messages.append([role, message])

    def to_gradio_chatbot(self):
        ret = []
        for i, (role, msg) in enumerate(self.messages[self.offset:]):
            if i % 2 == 0:
                ret.append([msg, None])
            else:
                ret[-1][-1] = msg
        return ret

    def copy(self):
        return Conversation(
            system=self.system,
            roles=self.roles,
            messages=[[x, y] for x, y in self.messages],
            offset=self.offset,
            sep_style=self.sep_style,
            sep=self.sep,
            sep2=self.sep2,
            conv_id=self.conv_id)

    def dict(self):
        return {
            "system": self.system,
            "roles": self.roles,
            "messages": self.messages,
            "offset": self.offset,
            "sep": self.sep,
            "sep2": self.sep2,
            "conv_id": self.conv_id,
        }


conv_v1 = Conversation(
    system="A chat between a curious human and an artificial intelligence assistant. "
           "The assistant gives helpful, detailed, and polite answers to the human's questions.",
    roles=("Human", "Assistant"),
    messages=(
        ("Human", "Give three tips for staying healthy."),
        ("Assistant",
         "Sure, here are three tips for staying healthy:\n"
         "1. Exercise regularly: Regular physical activity can help improve your overall health and wellbeing. "
         "It can also help reduce your risk of chronic conditions such as obesity, diabetes, heart disease, "
         "and certain cancers. Aim for at least 150 minutes of moderate-intensity aerobic exercise or "
         "75 minutes of vigorous-intensity aerobic exercise per week, along with muscle-strengthening "
         "activities at least two days per week.\n"
         "2. Eat a balanced diet: Eating a balanced diet that is rich in fruits, "
         "vegetables, whole grains, lean proteins, and healthy fats can help support "
         "your overall health. Try to limit your intake of processed and high-sugar foods, "
         "and aim to drink plenty of water throughout the day.\n"
         "3. Get enough sleep: Getting enough quality sleep is essential for your physical "
         "and mental health. Adults should aim for seven to nine hours of sleep per night. "
         "Establish a regular sleep schedule and try to create a relaxing bedtime routine to "
         "help improve the quality of your sleep.")
    ),
    offset=2,
    sep_style=SeparatorStyle.SINGLE,
    sep="###",
)

conv_v1_2 = Conversation(
    system="A chat between a curious human and an artificial intelligence assistant. "
           "The assistant gives helpful, detailed, and polite answers to the human's questions.",
    roles=("Human", "Assistant"),
    messages=(
        ("Human", "What are the key differences between renewable and non-renewable energy sources?"),
        ("Assistant",
         "Renewable energy sources are those that can be replenished naturally in a relatively "
         "short amount of time, such as solar, wind, hydro, geothermal, and biomass. "
         "Non-renewable energy sources, on the other hand, are finite and will eventually be "
         "depleted, such as coal, oil, and natural gas. Here are some key differences between "
         "renewable and non-renewable energy sources:\n"
         "1. Availability: Renewable energy sources are virtually inexhaustible, while non-renewable "
         "energy sources are finite and will eventually run out.\n"
         "2. Environmental impact: Renewable energy sources have a much lower environmental impact "
         "than non-renewable sources, which can lead to air and water pollution, greenhouse gas emissions, "
         "and other negative effects.\n"
         "3. Cost: Renewable energy sources can be more expensive to initially set up, but they typically "
         "have lower operational costs than non-renewable sources.\n"
         "4. Reliability: Renewable energy sources are often more reliable and can be used in more remote "
         "locations than non-renewable sources.\n"
         "5. Flexibility: Renewable energy sources are often more flexible and can be adapted to different "
         "situations and needs, while non-renewable sources are more rigid and inflexible.\n"
         "6. Sustainability: Renewable energy sources are more sustainable over the long term, while "
         "non-renewable sources are not, and their depletion can lead to economic and social instability.\n")
    ),
    offset=2,
    sep_style=SeparatorStyle.SINGLE,
    sep="###",
)

conv_bair_v1 = Conversation(
    system="BEGINNING OF CONVERSATION:",
    roles=("USER", "GPT"),
    messages=(),
    offset=0,
    sep_style=SeparatorStyle.TWO,
    sep=" ",
    sep2="</s>",
)


default_conversation = conv_v1_2
conv_templates = {
    "v1": conv_v1_2,
    "bair_v1": conv_bair_v1,
}


if __name__ == "__main__":
    print(default_conversation.get_prompt())
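A minimal usage sketch (not part of the upload) of how a template is typically consumed: copy it, append a new exchange, then render a prompt or a Gradio-style history.

conv = default_conversation.copy()
conv.append_message(conv.roles[0], "What is the capital of France?")  # Human turn
conv.append_message(conv.roles[1], None)                              # empty Assistant slot to be filled by the model
print(conv.get_prompt())         # "...###Human: What is the capital of France?###Assistant:"
print(conv.to_gradio_chatbot())  # [["What is the capital of France?", None]]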
demo.py
ADDED
@@ -0,0 +1,24 @@
from langchain import PromptTemplate, HuggingFaceHub, LLMChain


# initialize HF LLM
flan_t5 = HuggingFaceHub(
    repo_id="google/flan-t5-xl",
    model_kwargs={"temperature": 1e-10},
    huggingfacehub_api_token="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"
)

# build prompt template for simple question-answering
template = """Question: {question}

Answer: """
prompt = PromptTemplate(template=template, input_variables=["question"])

llm_chain = LLMChain(
    prompt=prompt,
    llm=flan_t5
)

question = "Which NFL team won the Super Bowl in the 2010 season?"

print(llm_chain.run(question))
demo1.py
ADDED
@@ -0,0 +1,91 @@
import gradio as gr
from langchain.agents import initialize_agent
# from langchain.llms import OpenAI
# from langchain.chat_models import ChatOpenAI

from langchain.tools import BaseTool, StructuredTool, Tool, tool
from PIL import Image
from demotool import *
from loader import *
from llmLoader import *
import re
from gradio_tools.tools import (StableDiffusionTool, ImageCaptioningTool, StableDiffusionPromptGeneratorTool,
                                TextToVideoTool)

from langchain.memory import ConversationBufferMemory
# from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler

from langchain import PromptTemplate, HuggingFaceHub, LLMChain

# def init_model_config():
#     llm = ChatLLM()
#     llm.model_type = 'chatglm'
#     llm.model_name_or_path = llm_model_dict['chatglm'][
#         'ChatGLM-6B-int4']
#     llm.load_llm()
#     return llm

# initialize HF LLM
# flan_t5 = HuggingFaceHub(
#     repo_id="google/flan-t5-xl",
#     model_kwargs={"temperature":1e-10},
#     huggingfacehub_api_token="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"
# )


# llm = ChatOpenAI(openai_api_key="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",temperature=0,streaming=True,callbacks=[StreamingStdOutCallbackHandler()])

chatLLMm = ModelLoader()
chatLLMm.load_model()
memory = ConversationBufferMemory(memory_key="chat_history")


# tools = [ Text2Image()]

# tools = [ Tool.from_function(
#     func=search,
#     name = "Search",
#     description="useful for when you need to answer questions about current events"
# )]

# tools = [ Tool.from_function(
#     func=optimizationProblem,
#     name = "optimizationProblem",
#     description=" you must use this tool when you need to Add more information"
# )]


# tools = [ StableDiffusionPromptGeneratorTool().langchain]

tools = []




agent = initialize_agent(tools, chatLLMm, memory=memory, agent="conversational-react-description", verbose=True)

def run_text(text, state):
    # print("stat:"+text)
    # res = llm_chain.run(text)
    # print("res:"+res)
    res = agent.run(input=(text))
    # Render generated image paths as inline Markdown images followed by the path
    # (the "![](file=...)" prefix is an assumed reconstruction of the replacement string).
    response = re.sub(r'(image/\S*png)', lambda m: f'![](file={m.group(0)})*{m.group(0)}*', res)
    state = state + [(text, response)]
    return state, state

with gr.Blocks(css="#chatbot {overflow:auto; height:500px;}") as demo:
    chatbot = gr.Chatbot(elem_id="chatbot", show_label=False)
    state = gr.State([])
    with gr.Row() as input_raws:
        with gr.Column(scale=0.6):
            txt = gr.Textbox(show_label=False).style(container=False)
        with gr.Column(scale=0.20, min_width=0):
            run = gr.Button("🏃♂️Run")
        with gr.Column(scale=0.20, min_width=0):
            clear = gr.Button("🔄Clear️")

    txt.submit(run_text, [txt, state], [chatbot, state])
    txt.submit(lambda: "", None, txt)
    run.click(run_text, [txt, state], [chatbot, state])

demo.queue(concurrency_count=10).launch(server_name="0.0.0.0", server_port=7865)
demotool.py
ADDED
@@ -0,0 +1,67 @@
from langchain.tools import BaseTool, StructuredTool, Tool, tool
from typing import Optional, Type
from langchain.callbacks.manager import AsyncCallbackManagerForToolRun, CallbackManagerForToolRun
import requests
import base64
import os
import uuid
from PIL import Image, ImageOps, ImageDraw, ImageFont

def optimizationProblem(query):
    query = query + " What's the date today?"
    return query

class CustomWeatherTool(BaseTool):
    name = "weather"
    description = "useful for when the input to this tool should be city"

    def _run(self, query: str, run_manager: Optional[CallbackManagerForToolRun] = None) -> str:

        return "The weather in " + query

    async def _arun(self, query: str, run_manager: Optional[AsyncCallbackManagerForToolRun] = None) -> str:
        """Use the tool asynchronously."""
        raise NotImplementedError("custom_search does not support async")


class Text2Image(BaseTool):
    name = "Generate Image From User Input Text"
    description = "useful when you want to generate an image from a user input text and save it to a file. like: generate an image of an object or something, or generate an image that includes some objects. The input to this tool should be a string, representing the text used to generate image. "
    return_direct = True

    def _run(self, query: str, run_manager: Optional[CallbackManagerForToolRun] = None) -> str:
        url = "http://region-9.seetacloud.com:39487/sdapi/v1/txt2img"
        body = {
            "negative_prompt": "",
            "width": "900",
            "prompt": query,
            "steps": "30",
            "cfg_scale": "8",
            "height": "900"
        }

        image_filename = ""  # returned unchanged if the request or decoding fails
        try:
            result = requests.post(url, json=body, stream=True)
            result.raise_for_status()  # Raise an exception if request was not successful
            response_data = result.json()
            images_json = response_data["images"]
            if len(images_json) > 0:
                image_data = images_json[0].split(",", 1)[0]
                image_bytes = base64.b64decode(image_data)
                os.makedirs('image', exist_ok=True)  # make sure the output folder exists
                image_filename = os.path.join('image', f"{str(uuid.uuid4())[:8]}.png")
                with open(image_filename, "wb") as file:
                    file.write(image_bytes)

        except requests.exceptions.RequestException as e:
            print("An error occurred:", e)

        return image_filename

    async def _arun(self, query: str, run_manager: Optional[AsyncCallbackManagerForToolRun] = None) -> str:
        """Use the tool asynchronously."""
        raise NotImplementedError("custom_search does not support async")




@tool("search", return_direct=True)
def search_api(query: str) -> str:
    """Searches the API for the query."""
    return f"Results for query {query}"
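A short sketch (not part of the upload) of exercising the tools directly, outside of an agent; it assumes the txt2img endpoint above is reachable.

if __name__ == "__main__":
    # Direct call, bypassing the agent; returns the saved PNG path or "" on failure.
    path = Text2Image().run("a watercolor painting of a lighthouse at dusk")
    print("saved image:", path)

    # The @tool-decorated function is already a LangChain Tool object.
    print(search_api.run("langchain custom tools"))
    print(CustomWeatherTool().run("Beijing"))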
llmLoader.py
ADDED
@@ -0,0 +1,55 @@
from typing import List, Optional
from langchain.llms.base import LLM
import torch
from transformers import AutoModel, AutoTokenizer
from langchain.llms.utils import enforce_stop_tokens
from fastchat.conversation import (compute_skip_echo_len,
                                   get_default_conv_template)


class ModelLoader(LLM):
    tokenizer: object = None
    model: object = None
    max_token: int = 10000
    temperature: float = 0.1
    top_p = 0.9
    history = []

    def __init__(self):
        super().__init__()

    @property
    def _llm_type(self) -> str:
        return "ChatLLM"

    def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
        conv = get_default_conv_template("/DATA/gpt/lang/model_cache/THUDM/chatglm-6b-int8").copy()
        conv.append_message(conv.roles[0], prompt)
        conv.append_message(conv.roles[1], None)
        prompt = conv.get_prompt()
        inputs = self.tokenizer([prompt])
        output_ids = self.model.generate(
            torch.as_tensor(inputs.input_ids).cuda(),
            do_sample=True,
            temperature=self.temperature,
            max_new_tokens=self.max_token,
        )
        outputs = self.tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0]
        skip_echo_len = compute_skip_echo_len("/DATA/gpt/lang/model_cache/THUDM/chatglm-6b-int8", conv, prompt)
        response = outputs[skip_echo_len:]
        if stop is not None:
            response = enforce_stop_tokens(response, stop)
        self.history = [[None, response]]

        return response


    def load_model(self, model_name_or_path: str = "/DATA/gpt/lang/model_cache/THUDM/chatglm-6b-int8"):

        self.tokenizer = AutoTokenizer.from_pretrained(
            "/DATA/gpt/mingpt-7b/MiniGPT-4-LLaMA-7B",
            trust_remote_code=True
        )
        self.model = AutoModel.from_pretrained(model_name_or_path, trust_remote_code=True)
        self.model = self.model.eval()
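For comparison with ChatLLM in loader.py, a minimal sketch of how demo1.py drives this class; the checkpoint paths are the hard-coded ones above, and a CUDA device is assumed because _call moves the input ids to the GPU.

llm = ModelLoader()
llm.load_model()  # loads the hard-coded ChatGLM checkpoint and tokenizer
print(llm("Introduce yourself in one sentence."))  # LLM.__call__ dispatches to _call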
loader.py
ADDED
@@ -0,0 +1,171 @@
import os
from typing import Dict, List, Optional, Tuple, Union

import torch
from fastchat.conversation import (compute_skip_echo_len,
                                   get_default_conv_template)
from fastchat.serve.inference import load_model as load_fastchat_model
from langchain.llms.base import LLM
from langchain.llms.utils import enforce_stop_tokens
from transformers import AutoModel, AutoModelForCausalLM, AutoTokenizer


MODEL_CACHE_PATH = os.path.join(os.path.dirname(__file__), 'model_cache')

llm_model_dict = {
    "chatglm": {
        "ChatGLM-6B": "THUDM/chatglm-6b",
        "ChatGLM-6B-int4": "THUDM/chatglm-6b-int4",
        "ChatGLM-6B-int8": "THUDM/chatglm-6b-int8",
        "ChatGLM-6b-int4-qe": "THUDM/chatglm-6b-int4-qe"
    },
    "belle": {
        "BELLE-LLaMA-Local": "/pretrainmodel/belle",
    },
    "vicuna": {
        "Vicuna-Local": "/pretrainmodel/vicuna",
    }
}

os.environ["TOKENIZERS_PARALLELISM"] = "false"

DEVICE = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
DEVICE_ID = "0"
CUDA_DEVICE = f"{DEVICE}:{DEVICE_ID}" if DEVICE_ID else DEVICE


def torch_gc():
    if torch.cuda.is_available():
        with torch.cuda.device(CUDA_DEVICE):
            torch.cuda.empty_cache()
            torch.cuda.ipc_collect()


def auto_configure_device_map(num_gpus: int) -> Dict[str, int]:
    num_trans_layers = 28
    per_gpu_layers = 30 / num_gpus

    device_map = {
        'transformer.word_embeddings': 0,
        'transformer.final_layernorm': 0,
        'lm_head': 0
    }

    used = 2
    gpu_target = 0
    for i in range(num_trans_layers):
        if used >= per_gpu_layers:
            gpu_target += 1
            used = 0
        assert gpu_target < num_gpus
        device_map[f'transformer.layers.{i}'] = gpu_target
        used += 1

    return device_map


class ChatLLM(LLM):
    max_token: int = 10000
    temperature: float = 0.1
    top_p = 0.9
    history = []
    model_type: str = "chatglm"
    model_name_or_path: str = "ChatGLM-6B-int4"
    tokenizer: object = None
    model: object = None

    def __init__(self):
        super().__init__()

    @property
    def _llm_type(self) -> str:
        return "ChatLLM"

    def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:

        if self.model_type == 'vicuna':
            conv = get_default_conv_template(self.model_name_or_path).copy()
            conv.append_message(conv.roles[0], prompt)
            conv.append_message(conv.roles[1], None)
            prompt = conv.get_prompt()
            inputs = self.tokenizer([prompt])
            output_ids = self.model.generate(
                torch.as_tensor(inputs.input_ids).cuda(),
                do_sample=True,
                temperature=self.temperature,
                max_new_tokens=self.max_token,
            )
            outputs = self.tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0]
            skip_echo_len = compute_skip_echo_len(self.model_name_or_path, conv, prompt)
            response = outputs[skip_echo_len:]
            torch_gc()
            if stop is not None:
                response = enforce_stop_tokens(response, stop)
            self.history = [[None, response]]

        elif self.model_type == 'belle':
            prompt = "Human: " + prompt + " \n\nAssistant: "
            input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids.to(DEVICE)
            generate_ids = self.model.generate(input_ids, max_new_tokens=self.max_token, do_sample=True, top_k=30, top_p=self.top_p, temperature=self.temperature, repetition_penalty=1., eos_token_id=2, bos_token_id=1, pad_token_id=0)
            output = self.tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
            response = output[len(prompt)+1:]
            torch_gc()
            if stop is not None:
                response = enforce_stop_tokens(response, stop)
            self.history = [[None, response]]

        elif self.model_type == 'chatglm':
            response, _ = self.model.chat(
                self.tokenizer,
                prompt,
                history=self.history,
                max_length=self.max_token,
                temperature=self.temperature,
            )
            torch_gc()
            if stop is not None:
                response = enforce_stop_tokens(response, stop)
            self.history = self.history + [[None, response]]

        return response


    def load_llm(self,
                 llm_device=DEVICE,
                 num_gpus=torch.cuda.device_count(),
                 device_map: Optional[Dict[str, int]] = None,
                 **kwargs):
        if 'chatglm' in self.model_name_or_path.lower():
            self.tokenizer = AutoTokenizer.from_pretrained(self.model_name_or_path,
                                                           trust_remote_code=True, cache_dir=os.path.join(MODEL_CACHE_PATH, self.model_name_or_path))
            if torch.cuda.is_available() and llm_device.lower().startswith("cuda"):

                num_gpus = torch.cuda.device_count()
                if num_gpus < 2 and device_map is None:
                    self.model = (AutoModel.from_pretrained(
                        self.model_name_or_path, trust_remote_code=True, cache_dir=os.path.join(MODEL_CACHE_PATH, self.model_name_or_path),
                        **kwargs).half().cuda())
                else:
                    from accelerate import dispatch_model

                    model = AutoModel.from_pretrained(self.model_name_or_path,
                                                      trust_remote_code=True, cache_dir=os.path.join(MODEL_CACHE_PATH, self.model_name_or_path),
                                                      **kwargs).half()

                    if device_map is None:
                        device_map = auto_configure_device_map(num_gpus)

                    self.model = dispatch_model(model, device_map=device_map)
            else:
                self.model = (AutoModel.from_pretrained(
                    self.model_name_or_path,
                    trust_remote_code=True, cache_dir=os.path.join(MODEL_CACHE_PATH, self.model_name_or_path)).float().to(llm_device))
            self.model = self.model.eval()

        else:
            self.model, self.tokenizer = load_fastchat_model(
                model_path=self.model_name_or_path,
                device=llm_device,
                num_gpus=num_gpus
            )
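A minimal sketch of selecting a different entry from llm_model_dict; it mirrors init_model_config() in app.py, and the local vicuna path is just the placeholder defined above.

llm = ChatLLM()
llm.model_type = 'vicuna'
llm.model_name_or_path = llm_model_dict['vicuna']['Vicuna-Local']  # "/pretrainmodel/vicuna"
llm.load_llm()  # takes the fastchat loader branch because the path does not contain "chatglm"
print(llm("Hello!"))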
singleton.py
ADDED
@@ -0,0 +1,24 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

"""The singleton metaclass for ensuring only one instance of a class."""
import abc
from typing import Any


class Singleton(abc.ABCMeta, type):
    """Singleton metaclass for ensuring only one instance of a class"""

    _instances = {}

    def __call__(cls, *args: Any, **kwargs: Any) -> Any:
        """Call method for the singleton metaclass"""
        if cls not in cls._instances:
            cls._instances[cls] = super(Singleton, cls).__call__(*args, **kwargs)
        return cls._instances[cls]


class AbstractSingleton(abc.ABC, metaclass=Singleton):
    """Abstract singleton class for ensuring only one instance of a class"""

    pass
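A quick illustration (hypothetical class, not part of the upload) of what the metaclass guarantees: every construction returns the same instance.

class ConfigStore(metaclass=Singleton):
    """Hypothetical example; any class using the metaclass behaves the same way."""
    def __init__(self):
        self.values = {}

a = ConfigStore()
b = ConfigStore()
assert a is b  # only one instance is ever created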