zxcgqq committed on
Commit b5e593e · 1 Parent(s): cedfb75

Upload 8 files

Files changed (8)
  1. app.py +93 -0
  2. conversation.py +159 -0
  3. demo.py +24 -0
  4. demo1.py +91 -0
  5. demotool.py +67 -0
  6. llmLoader.py +55 -0
  7. loader.py +171 -0
  8. singleton.py +24 -0
app.py ADDED
@@ -0,0 +1,93 @@
+ import gradio as gr
+ from langchain.agents import initialize_agent
+ # from langchain.llms import OpenAI
+ # from langchain.chat_models import ChatOpenAI
+
+ from langchain.tools import BaseTool, StructuredTool, Tool, tool
+ from PIL import Image
+ from demotool import *
+ from loader import *
+ # from llmLoader import *
+ import re
+ from gradio_tools.tools import (StableDiffusionTool, ImageCaptioningTool, StableDiffusionPromptGeneratorTool,
+                                 TextToVideoTool)
+
+ from langchain.memory import ConversationBufferMemory
+ # from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
+
+ from langchain import PromptTemplate, HuggingFaceHub, LLMChain
+
+ def init_model_config():
+     llm = ChatLLM()
+     llm.model_type = 'chatglm'
+     llm.model_name_or_path = llm_model_dict['chatglm'][
+         'ChatGLM-6B-int4']
+     llm.load_llm()
+     return llm
+
+ # initialize HF LLM
+ # flan_t5 = HuggingFaceHub(
+ #     repo_id="google/flan-t5-xl",
+ #     model_kwargs={"temperature": 1e-10},
+ #     huggingfacehub_api_token="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"
+ # )
+
+
+ # llm = ChatOpenAI(openai_api_key="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", temperature=0, streaming=True, callbacks=[StreamingStdOutCallbackHandler()])
+
+ # llm = ModelLoader()
+ # llm.loader()
+ # chatLLM = ModelLoader()
+ # chatLLM.loader()
+ memory = ConversationBufferMemory(memory_key="chat_history")
+
+
+ # tools = [Text2Image()]
+
+ # tools = [Tool.from_function(
+ #     func=search,
+ #     name="Search",
+ #     description="useful for when you need to answer questions about current events"
+ # )]
+
+ # tools = [Tool.from_function(
+ #     func=optimizationProblem,
+ #     name="optimizationProblem",
+ #     description="you must use this tool when you need to add more information"
+ # )]
+
+
+ # tools = [StableDiffusionPromptGeneratorTool().langchain]
+
+ tools = []
+
+
+ agent = initialize_agent(tools, init_model_config(), memory=memory, agent="conversational-react-description", verbose=True)
+
+ def run_text(text, state):
+     # print("stat:" + text)
+     # res = llm_chain.run(text)
+     # print("res:" + res)
+     res = agent.run(input=text)
+     response = re.sub(r'(image/\S*png)', lambda m: f'![](/file={m.group(0)})*{m.group(0)}*', res)
+     state = state + [(text, response)]
+     return state, state
+
+ with gr.Blocks(css="#chatbot {overflow:auto; height:500px;}") as demo:
+     chatbot = gr.Chatbot(elem_id="chatbot", show_label=False)
+     state = gr.State([])
+     with gr.Row() as input_raws:
+         with gr.Column(scale=0.6):
+             txt = gr.Textbox(show_label=False).style(container=False)
+         with gr.Column(scale=0.20, min_width=0):
+             run = gr.Button("🏃‍♂️Run")
+         with gr.Column(scale=0.20, min_width=0):
+             clear = gr.Button("🔄Clear️")
+
+     txt.submit(run_text, [txt, state], [chatbot, state])
+     txt.submit(lambda: "", None, txt)
+     run.click(run_text, [txt, state], [chatbot, state])
+
+ demo.queue(concurrency_count=10).launch(server_name="0.0.0.0", server_port=7865)
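
In run_text, the agent's reply is post-processed so that any saved image path of the form "image/<name>.png" (the location Text2Image writes to) is rewritten into Markdown that Gradio can serve through its /file= route. A minimal sketch of that substitution, using a hypothetical path and the same regular expression as above:

import re

res = "Here is your picture: image/3fa2b1c0.png"
response = re.sub(r'(image/\S*png)', lambda m: f'![](/file={m.group(0)})*{m.group(0)}*', res)
print(response)
# Here is your picture: ![](/file=image/3fa2b1c0.png)*image/3fa2b1c0.png*
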
conversation.py ADDED
@@ -0,0 +1,159 @@
+ import dataclasses
+ from enum import auto, Enum
+ from typing import List, Tuple, Any
+
+
+ class SeparatorStyle(Enum):
+     """Different separator style."""
+     SINGLE = auto()
+     TWO = auto()
+
+
+ @dataclasses.dataclass
+ class Conversation:
+     """A class that keeps all conversation history."""
+     system: str
+     roles: List[str]
+     messages: List[List[str]]
+     offset: int
+     sep_style: SeparatorStyle = SeparatorStyle.SINGLE
+     sep: str = "###"
+     sep2: str = None
+
+     skip_next: bool = False
+     conv_id: Any = None
+
+     def get_prompt(self):
+         if self.sep_style == SeparatorStyle.SINGLE:
+             ret = self.system + self.sep
+             for role, message in self.messages:
+                 if message:
+                     ret += role + ": " + message + self.sep
+                 else:
+                     ret += role + ":"
+             return ret
+         elif self.sep_style == SeparatorStyle.TWO:
+             seps = [self.sep, self.sep2]
+             ret = self.system + seps[0]
+             for i, (role, message) in enumerate(self.messages):
+                 if message:
+                     ret += role + ": " + message + seps[i % 2]
+                 else:
+                     ret += role + ":"
+             return ret
+         else:
+             raise ValueError(f"Invalid style: {self.sep_style}")
+
+     def append_message(self, role, message):
+         self.messages.append([role, message])
+
+     def to_gradio_chatbot(self):
+         ret = []
+         for i, (role, msg) in enumerate(self.messages[self.offset:]):
+             if i % 2 == 0:
+                 ret.append([msg, None])
+             else:
+                 ret[-1][-1] = msg
+         return ret
+
+     def copy(self):
+         return Conversation(
+             system=self.system,
+             roles=self.roles,
+             messages=[[x, y] for x, y in self.messages],
+             offset=self.offset,
+             sep_style=self.sep_style,
+             sep=self.sep,
+             sep2=self.sep2,
+             conv_id=self.conv_id)
+
+     def dict(self):
+         return {
+             "system": self.system,
+             "roles": self.roles,
+             "messages": self.messages,
+             "offset": self.offset,
+             "sep": self.sep,
+             "sep2": self.sep2,
+             "conv_id": self.conv_id,
+         }
+
+
+ conv_v1 = Conversation(
+     system="A chat between a curious human and an artificial intelligence assistant. "
+            "The assistant gives helpful, detailed, and polite answers to the human's questions.",
+     roles=("Human", "Assistant"),
+     messages=(
+         ("Human", "Give three tips for staying healthy."),
+         ("Assistant",
+          "Sure, here are three tips for staying healthy:\n"
+          "1. Exercise regularly: Regular physical activity can help improve your overall health and wellbeing. "
+          "It can also help reduce your risk of chronic conditions such as obesity, diabetes, heart disease, "
+          "and certain cancers. Aim for at least 150 minutes of moderate-intensity aerobic exercise or "
+          "75 minutes of vigorous-intensity aerobic exercise per week, along with muscle-strengthening "
+          "activities at least two days per week.\n"
+          "2. Eat a balanced diet: Eating a balanced diet that is rich in fruits, "
+          "vegetables, whole grains, lean proteins, and healthy fats can help support "
+          "your overall health. Try to limit your intake of processed and high-sugar foods, "
+          "and aim to drink plenty of water throughout the day.\n"
+          "3. Get enough sleep: Getting enough quality sleep is essential for your physical "
+          "and mental health. Adults should aim for seven to nine hours of sleep per night. "
+          "Establish a regular sleep schedule and try to create a relaxing bedtime routine to "
+          "help improve the quality of your sleep.")
+     ),
+     offset=2,
+     sep_style=SeparatorStyle.SINGLE,
+     sep="###",
+ )
+
+ conv_v1_2 = Conversation(
+     system="A chat between a curious human and an artificial intelligence assistant. "
+            "The assistant gives helpful, detailed, and polite answers to the human's questions.",
+     roles=("Human", "Assistant"),
+     messages=(
+         ("Human", "What are the key differences between renewable and non-renewable energy sources?"),
+         ("Assistant",
+          "Renewable energy sources are those that can be replenished naturally in a relatively "
+          "short amount of time, such as solar, wind, hydro, geothermal, and biomass. "
+          "Non-renewable energy sources, on the other hand, are finite and will eventually be "
+          "depleted, such as coal, oil, and natural gas. Here are some key differences between "
+          "renewable and non-renewable energy sources:\n"
+          "1. Availability: Renewable energy sources are virtually inexhaustible, while non-renewable "
+          "energy sources are finite and will eventually run out.\n"
+          "2. Environmental impact: Renewable energy sources have a much lower environmental impact "
+          "than non-renewable sources, which can lead to air and water pollution, greenhouse gas emissions, "
+          "and other negative effects.\n"
+          "3. Cost: Renewable energy sources can be more expensive to initially set up, but they typically "
+          "have lower operational costs than non-renewable sources.\n"
+          "4. Reliability: Renewable energy sources are often more reliable and can be used in more remote "
+          "locations than non-renewable sources.\n"
+          "5. Flexibility: Renewable energy sources are often more flexible and can be adapted to different "
+          "situations and needs, while non-renewable sources are more rigid and inflexible.\n"
+          "6. Sustainability: Renewable energy sources are more sustainable over the long term, while "
+          "non-renewable sources are not, and their depletion can lead to economic and social instability.\n")
+     ),
+     offset=2,
+     sep_style=SeparatorStyle.SINGLE,
+     sep="###",
+ )
+
+ conv_bair_v1 = Conversation(
+     system="BEGINNING OF CONVERSATION:",
+     roles=("USER", "GPT"),
+     messages=(),
+     offset=0,
+     sep_style=SeparatorStyle.TWO,
+     sep=" ",
+     sep2="</s>",
+ )
+
+
+ default_conversation = conv_v1_2
+ conv_templates = {
+     "v1": conv_v1_2,
+     "bair_v1": conv_bair_v1,
+ }
+
+
+ if __name__ == "__main__":
+     print(default_conversation.get_prompt())
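
The templates in conv_templates are meant to be copied rather than mutated in place: a copy gets its own message list, new turns are appended with append_message, and the same object can render either a model prompt (get_prompt) or Gradio chat pairs (to_gradio_chatbot). A small usage sketch, assuming conversation.py is importable from the working directory:

from conversation import conv_templates

conv = conv_templates["v1"].copy()                 # copy so the shared template is not modified
conv.append_message(conv.roles[0], "How do solar panels work?")
conv.append_message(conv.roles[1], None)           # empty slot the model is expected to fill
print(conv.get_prompt())                           # ends with "...Human: How do solar panels work?###Assistant:"
print(conv.to_gradio_chatbot())                    # [question, answer] pairs, starting at `offset`
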
demo.py ADDED
@@ -0,0 +1,24 @@
+ from langchain import PromptTemplate, HuggingFaceHub, LLMChain
+
+
+ # initialize HF LLM
+ flan_t5 = HuggingFaceHub(
+     repo_id="google/flan-t5-xl",
+     model_kwargs={"temperature": 1e-10},
+     huggingfacehub_api_token="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"
+ )
+
+ # build prompt template for simple question-answering
+ template = """Question: {question}
+
+ Answer: """
+ prompt = PromptTemplate(template=template, input_variables=["question"])
+
+ llm_chain = LLMChain(
+     prompt=prompt,
+     llm=flan_t5
+ )
+
+ question = "Which NFL team won the Super Bowl in the 2010 season?"
+
+ print(llm_chain.run(question))
demo1.py ADDED
@@ -0,0 +1,91 @@
+ import gradio as gr
+ from langchain.agents import initialize_agent
+ # from langchain.llms import OpenAI
+ # from langchain.chat_models import ChatOpenAI
+
+ from langchain.tools import BaseTool, StructuredTool, Tool, tool
+ from PIL import Image
+ from demotool import *
+ from loader import *
+ from llmLoader import *
+ import re
+ from gradio_tools.tools import (StableDiffusionTool, ImageCaptioningTool, StableDiffusionPromptGeneratorTool,
+                                 TextToVideoTool)
+
+ from langchain.memory import ConversationBufferMemory
+ # from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
+
+ from langchain import PromptTemplate, HuggingFaceHub, LLMChain
+
+ # def init_model_config():
+ #     llm = ChatLLM()
+ #     llm.model_type = 'chatglm'
+ #     llm.model_name_or_path = llm_model_dict['chatglm'][
+ #         'ChatGLM-6B-int4']
+ #     llm.load_llm()
+ #     return llm
+
+ # initialize HF LLM
+ # flan_t5 = HuggingFaceHub(
+ #     repo_id="google/flan-t5-xl",
+ #     model_kwargs={"temperature": 1e-10},
+ #     huggingfacehub_api_token="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"
+ # )
+
+
+ # llm = ChatOpenAI(openai_api_key="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", temperature=0, streaming=True, callbacks=[StreamingStdOutCallbackHandler()])
+
+ chatLLMm = ModelLoader()
+ chatLLMm.load_model()
+ memory = ConversationBufferMemory(memory_key="chat_history")
+
+
+ # tools = [Text2Image()]
+
+ # tools = [Tool.from_function(
+ #     func=search,
+ #     name="Search",
+ #     description="useful for when you need to answer questions about current events"
+ # )]
+
+ # tools = [Tool.from_function(
+ #     func=optimizationProblem,
+ #     name="optimizationProblem",
+ #     description="you must use this tool when you need to add more information"
+ # )]
+
+
+ # tools = [StableDiffusionPromptGeneratorTool().langchain]
+
+ tools = []
+
+
+ agent = initialize_agent(tools, chatLLMm, memory=memory, agent="conversational-react-description", verbose=True)
+
+ def run_text(text, state):
+     # print("stat:" + text)
+     # res = llm_chain.run(text)
+     # print("res:" + res)
+     res = agent.run(input=text)
+     response = re.sub(r'(image/\S*png)', lambda m: f'![](/file={m.group(0)})*{m.group(0)}*', res)
+     state = state + [(text, response)]
+     return state, state
+
+ with gr.Blocks(css="#chatbot {overflow:auto; height:500px;}") as demo:
+     chatbot = gr.Chatbot(elem_id="chatbot", show_label=False)
+     state = gr.State([])
+     with gr.Row() as input_raws:
+         with gr.Column(scale=0.6):
+             txt = gr.Textbox(show_label=False).style(container=False)
+         with gr.Column(scale=0.20, min_width=0):
+             run = gr.Button("🏃‍♂️Run")
+         with gr.Column(scale=0.20, min_width=0):
+             clear = gr.Button("🔄Clear️")
+
+     txt.submit(run_text, [txt, state], [chatbot, state])
+     txt.submit(lambda: "", None, txt)
+     run.click(run_text, [txt, state], [chatbot, state])
+
+ demo.queue(concurrency_count=10).launch(server_name="0.0.0.0", server_port=7865)
demotool.py ADDED
@@ -0,0 +1,67 @@
+ from langchain.tools import BaseTool, StructuredTool, Tool, tool
+ from typing import Optional, Type
+ from langchain.callbacks.manager import AsyncCallbackManagerForToolRun, CallbackManagerForToolRun
+ import requests
+ import base64
+ import os
+ import uuid
+ from PIL import Image, ImageOps, ImageDraw, ImageFont
+
+ def optimizationProblem(query):
+     query = query + " What's the date today?"
+     return query
+
+ class CustomWeatherTool(BaseTool):
+     name = "weather"
+     description = "useful for looking up the weather; the input to this tool should be a city"
+
+     def _run(self, query: str, run_manager: Optional[CallbackManagerForToolRun] = None) -> str:
+         return "The weather in " + query
+
+     async def _arun(self, query: str, run_manager: Optional[AsyncCallbackManagerForToolRun] = None) -> str:
+         """Use the tool asynchronously."""
+         raise NotImplementedError("custom_search does not support async")
+
+
+ class Text2Image(BaseTool):
+     name = "Generate Image From User Input Text"
+     description = "useful when you want to generate an image from a user input text and save it to a file. like: generate an image of an object or something, or generate an image that includes some objects. The input to this tool should be a string, representing the text used to generate image."
+     return_direct = True
+
+     def _run(self, query: str, run_manager: Optional[CallbackManagerForToolRun] = None) -> str:
+         url = "http://region-9.seetacloud.com:39487/sdapi/v1/txt2img"
+         body = {
+             "negative_prompt": "",
+             "width": "900",
+             "prompt": query,
+             "steps": "30",
+             "cfg_scale": "8",
+             "height": "900"
+         }
+
+         image_filename = ""  # fall back to an empty path if the request fails
+         try:
+             result = requests.post(url, json=body, stream=True)
+             result.raise_for_status()  # Raise an exception if request was not successful
+             response_data = result.json()
+             images_json = response_data["images"]
+             if len(images_json) > 0:
+                 image_data = images_json[0].split(",", 1)[0]
+                 image_bytes = base64.b64decode(image_data)
+                 image_filename = os.path.join('image', f"{str(uuid.uuid4())[:8]}.png")
+                 with open(image_filename, "wb") as file:
+                     file.write(image_bytes)
+
+         except requests.exceptions.RequestException as e:
+             print("An error occurred:", e)
+
+         return image_filename
+
+     async def _arun(self, query: str, run_manager: Optional[AsyncCallbackManagerForToolRun] = None) -> str:
+         """Use the tool asynchronously."""
+         raise NotImplementedError("custom_search does not support async")
+
+
+ @tool("search", return_direct=True)
+ def search_api(query: str) -> str:
+     """Searches the API for the query."""
+     return f"Results for query {query}"
llmLoader.py ADDED
@@ -0,0 +1,55 @@
+ from typing import List, Optional
+ from langchain.llms.base import LLM
+ import torch
+ from transformers import AutoModel, AutoTokenizer
+ from langchain.llms.utils import enforce_stop_tokens
+ from fastchat.conversation import (compute_skip_echo_len,
+                                    get_default_conv_template)
+
+
+ class ModelLoader(LLM):
+     tokenizer: object = None
+     model: object = None
+     max_token: int = 10000
+     temperature: float = 0.1
+     top_p = 0.9
+     history = []
+
+     def __init__(self):
+         super().__init__()
+
+     @property
+     def _llm_type(self) -> str:
+         return "ChatLLM"
+
+     def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
+         conv = get_default_conv_template("/DATA/gpt/lang/model_cache/THUDM/chatglm-6b-int8").copy()
+         conv.append_message(conv.roles[0], prompt)
+         conv.append_message(conv.roles[1], None)
+         prompt = conv.get_prompt()
+         inputs = self.tokenizer([prompt])
+         output_ids = self.model.generate(
+             torch.as_tensor(inputs.input_ids).cuda(),
+             do_sample=True,
+             temperature=self.temperature,
+             max_new_tokens=self.max_token,
+         )
+         outputs = self.tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0]
+         skip_echo_len = compute_skip_echo_len("/DATA/gpt/lang/model_cache/THUDM/chatglm-6b-int8", conv, prompt)
+         response = outputs[skip_echo_len:]
+         if stop is not None:
+             response = enforce_stop_tokens(response, stop)
+         self.history = [[None, response]]
+
+         return response
+
+     def load_model(self, model_name_or_path: str = "/DATA/gpt/lang/model_cache/THUDM/chatglm-6b-int8"):
+         self.tokenizer = AutoTokenizer.from_pretrained(
+             "/DATA/gpt/mingpt-7b/MiniGPT-4-LLaMA-7B",
+             trust_remote_code=True
+         )
+         self.model = AutoModel.from_pretrained(model_name_or_path, trust_remote_code=True)
+         self.model = self.model.eval()
+
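
ModelLoader exposes a locally cached ChatGLM checkpoint through the LangChain LLM interface: load_model builds the tokenizer and model, and _call wraps the prompt in a FastChat conversation template before generating. A minimal usage sketch, assuming the hard-coded /DATA/... checkpoints and a CUDA device are available on the host:

from llmLoader import ModelLoader

llm = ModelLoader()
llm.load_model()   # loads tokenizer and model from the hard-coded local paths
print(llm("Summarize what a vector database is in one sentence."))
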
loader.py ADDED
@@ -0,0 +1,171 @@
+ import os
+ from typing import Dict, List, Optional, Tuple, Union
+
+ import torch
+ from fastchat.conversation import (compute_skip_echo_len,
+                                    get_default_conv_template)
+ from fastchat.serve.inference import load_model as load_fastchat_model
+ from langchain.llms.base import LLM
+ from langchain.llms.utils import enforce_stop_tokens
+ from transformers import AutoModel, AutoModelForCausalLM, AutoTokenizer
+
+
+ MODEL_CACHE_PATH = os.path.join(os.path.dirname(__file__), 'model_cache')
+
+ llm_model_dict = {
+     "chatglm": {
+         "ChatGLM-6B": "THUDM/chatglm-6b",
+         "ChatGLM-6B-int4": "THUDM/chatglm-6b-int4",
+         "ChatGLM-6B-int8": "THUDM/chatglm-6b-int8",
+         "ChatGLM-6b-int4-qe": "THUDM/chatglm-6b-int4-qe"
+     },
+     "belle": {
+         "BELLE-LLaMA-Local": "/pretrainmodel/belle",
+     },
+     "vicuna": {
+         "Vicuna-Local": "/pretrainmodel/vicuna",
+     }
+ }
+
+ os.environ["TOKENIZERS_PARALLELISM"] = "false"
+
+ DEVICE = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
+ DEVICE_ID = "0"
+ CUDA_DEVICE = f"{DEVICE}:{DEVICE_ID}" if DEVICE_ID else DEVICE
+
+
+ def torch_gc():
+     if torch.cuda.is_available():
+         with torch.cuda.device(CUDA_DEVICE):
+             torch.cuda.empty_cache()
+             torch.cuda.ipc_collect()
+
+
+ def auto_configure_device_map(num_gpus: int) -> Dict[str, int]:
+     num_trans_layers = 28
+     per_gpu_layers = 30 / num_gpus
+
+     device_map = {
+         'transformer.word_embeddings': 0,
+         'transformer.final_layernorm': 0,
+         'lm_head': 0
+     }
+
+     used = 2
+     gpu_target = 0
+     for i in range(num_trans_layers):
+         if used >= per_gpu_layers:
+             gpu_target += 1
+             used = 0
+         assert gpu_target < num_gpus
+         device_map[f'transformer.layers.{i}'] = gpu_target
+         used += 1
+
+     return device_map
+
+
+ class ChatLLM(LLM):
+     max_token: int = 10000
+     temperature: float = 0.1
+     top_p = 0.9
+     history = []
+     model_type: str = "chatglm"
+     model_name_or_path: str = "ChatGLM-6B-int4"
+     tokenizer: object = None
+     model: object = None
+
+     def __init__(self):
+         super().__init__()
+
+     @property
+     def _llm_type(self) -> str:
+         return "ChatLLM"
+
+     def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
+
+         if self.model_type == 'vicuna':
+             conv = get_default_conv_template(self.model_name_or_path).copy()
+             conv.append_message(conv.roles[0], prompt)
+             conv.append_message(conv.roles[1], None)
+             prompt = conv.get_prompt()
+             inputs = self.tokenizer([prompt])
+             output_ids = self.model.generate(
+                 torch.as_tensor(inputs.input_ids).cuda(),
+                 do_sample=True,
+                 temperature=self.temperature,
+                 max_new_tokens=self.max_token,
+             )
+             outputs = self.tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0]
+             skip_echo_len = compute_skip_echo_len(self.model_name_or_path, conv, prompt)
+             response = outputs[skip_echo_len:]
+             torch_gc()
+             if stop is not None:
+                 response = enforce_stop_tokens(response, stop)
+             self.history = [[None, response]]
+
+         elif self.model_type == 'belle':
+             prompt = "Human: " + prompt + " \n\nAssistant: "
+             input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids.to(DEVICE)
+             generate_ids = self.model.generate(input_ids, max_new_tokens=self.max_token, do_sample=True, top_k=30, top_p=self.top_p, temperature=self.temperature, repetition_penalty=1., eos_token_id=2, bos_token_id=1, pad_token_id=0)
+             output = self.tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+             response = output[len(prompt)+1:]
+             torch_gc()
+             if stop is not None:
+                 response = enforce_stop_tokens(response, stop)
+             self.history = [[None, response]]
+
+         elif self.model_type == 'chatglm':
+             response, _ = self.model.chat(
+                 self.tokenizer,
+                 prompt,
+                 history=self.history,
+                 max_length=self.max_token,
+                 temperature=self.temperature,
+             )
+             torch_gc()
+             if stop is not None:
+                 response = enforce_stop_tokens(response, stop)
+             self.history = self.history + [[None, response]]
+
+         return response
+
+     def load_llm(self,
+                  llm_device=DEVICE,
+                  num_gpus=torch.cuda.device_count(),
+                  device_map: Optional[Dict[str, int]] = None,
+                  **kwargs):
+         if 'chatglm' in self.model_name_or_path.lower():
+             self.tokenizer = AutoTokenizer.from_pretrained(self.model_name_or_path,
+                                                            trust_remote_code=True, cache_dir=os.path.join(MODEL_CACHE_PATH, self.model_name_or_path))
+             if torch.cuda.is_available() and llm_device.lower().startswith("cuda"):
+                 num_gpus = torch.cuda.device_count()
+                 if num_gpus < 2 and device_map is None:
+                     self.model = (AutoModel.from_pretrained(
+                         self.model_name_or_path, trust_remote_code=True, cache_dir=os.path.join(MODEL_CACHE_PATH, self.model_name_or_path),
+                         **kwargs).half().cuda())
+                 else:
+                     from accelerate import dispatch_model
+
+                     model = AutoModel.from_pretrained(self.model_name_or_path,
+                                                       trust_remote_code=True, cache_dir=os.path.join(MODEL_CACHE_PATH, self.model_name_or_path),
+                                                       **kwargs).half()
+
+                     if device_map is None:
+                         device_map = auto_configure_device_map(num_gpus)
+
+                     self.model = dispatch_model(model, device_map=device_map)
+             else:
+                 self.model = (AutoModel.from_pretrained(
+                     self.model_name_or_path,
+                     trust_remote_code=True, cache_dir=os.path.join(MODEL_CACHE_PATH, self.model_name_or_path)).float().to(llm_device))
+             self.model = self.model.eval()
+
+         else:
+             self.model, self.tokenizer = load_fastchat_model(
+                 model_path=self.model_name_or_path,
+                 device=llm_device,
+                 num_gpus=num_gpus
+             )
+
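
auto_configure_device_map places the embeddings, final layernorm, and lm_head on GPU 0 and then spreads the 28 ChatGLM transformer layers across the available GPUs; load_llm only uses it when more than one GPU is present and no explicit device_map is passed. A small sketch of the resulting map for a hypothetical two-GPU machine:

from loader import auto_configure_device_map

device_map = auto_configure_device_map(num_gpus=2)
print(device_map['transformer.word_embeddings'])   # 0
print(device_map['transformer.layers.0'])          # 0  (early layers stay on the first GPU)
print(device_map['transformer.layers.27'])         # 1  (later layers spill onto the second GPU)
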
singleton.py ADDED
@@ -0,0 +1,24 @@
+ #!/usr/bin/env python3
+ # -*- coding: utf-8 -*-
+
+ """The singleton metaclass for ensuring only one instance of a class."""
+ import abc
+ from typing import Any
+
+
+ class Singleton(abc.ABCMeta, type):
+     """Singleton metaclass for ensuring only one instance of a class"""
+
+     _instances = {}
+
+     def __call__(cls, *args: Any, **kwargs: Any) -> Any:
+         """Call method for the singleton metaclass"""
+         if cls not in cls._instances:
+             cls._instances[cls] = super(Singleton, cls).__call__(*args, **kwargs)
+         return cls._instances[cls]
+
+
+ class AbstractSingleton(abc.ABC, metaclass=Singleton):
+     """Abstract singleton class for ensuring only one instance of a class"""
+
+     pass
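
singleton.py is not referenced by the other uploaded files; it provides a metaclass so that any subclass of AbstractSingleton is constructed once and the same instance is returned on every later call. A short usage sketch with a hypothetical Config subclass:

from singleton import AbstractSingleton

class Config(AbstractSingleton):
    """Hypothetical settings holder; constructed only once."""
    def __init__(self):
        self.server_port = 7865

a = Config()
b = Config()
print(a is b)   # True: the metaclass returns the cached instance
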