naveensp committed on
Commit 98cfb8f · verified · 1 parent: fbfa1a7

Delete folder serve with huggingface_hub

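For reference, the deletion recorded in this commit can be reproduced from Python with the `huggingface_hub` client. A minimal sketch, assuming write access to the repo; the repo id below is a placeholder, not taken from this commit:

```python
from huggingface_hub import HfApi

api = HfApi()  # picks up the cached login or the HF_TOKEN environment variable

# Remove the whole serve/ folder from the repo in a single commit.
api.delete_folder(
    path_in_repo="serve",
    repo_id="your-username/your-repo",   # placeholder
    repo_type="model",
    commit_message="Delete folder serve with huggingface_hub",
)
```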
serve/__init__.py DELETED
File without changes
serve/cli.py DELETED
@@ -1,126 +0,0 @@
1
- import argparse
2
- import torch
3
-
4
- from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
5
- from llava.conversation import conv_templates, SeparatorStyle
6
- from llava.model.builder import load_pretrained_model
7
- from llava.utils import disable_torch_init
8
- from llava.mm_utils import process_images, tokenizer_image_token, get_model_name_from_path
9
-
10
- from PIL import Image
11
-
12
- import requests
13
- from PIL import Image
14
- from io import BytesIO
15
- from transformers import TextStreamer
16
-
17
-
18
- def load_image(image_file):
19
- if image_file.startswith('http://') or image_file.startswith('https://'):
20
- response = requests.get(image_file)
21
- image = Image.open(BytesIO(response.content)).convert('RGB')
22
- else:
23
- image = Image.open(image_file).convert('RGB')
24
- return image
25
-
26
-
27
- def main(args):
28
- # Model
29
- disable_torch_init()
30
-
31
- model_name = get_model_name_from_path(args.model_path)
32
- tokenizer, model, image_processor, context_len = load_pretrained_model(args.model_path, args.model_base, model_name, args.load_8bit, args.load_4bit, device=args.device)
33
-
34
- if "llama-2" in model_name.lower():
35
- conv_mode = "llava_llama_2"
36
- elif "mistral" in model_name.lower():
37
- conv_mode = "mistral_instruct"
38
- elif "v1.6-34b" in model_name.lower():
39
- conv_mode = "chatml_direct"
40
- elif "v1" in model_name.lower():
41
- conv_mode = "llava_v1"
42
- elif "mpt" in model_name.lower():
43
- conv_mode = "mpt"
44
- else:
45
- conv_mode = "llava_v0"
46
-
47
- if args.conv_mode is not None and conv_mode != args.conv_mode:
48
- print('[WARNING] the auto inferred conversation mode is {}, while `--conv-mode` is {}, using {}'.format(conv_mode, args.conv_mode, args.conv_mode))
49
- else:
50
- args.conv_mode = conv_mode
51
-
52
- conv = conv_templates[args.conv_mode].copy()
53
- if "mpt" in model_name.lower():
54
- roles = ('user', 'assistant')
55
- else:
56
- roles = conv.roles
57
-
58
- image = load_image(args.image_file)
59
- image_size = image.size
60
- # Similar operation in model_worker.py
61
- image_tensor = process_images([image], image_processor, model.config)
62
- if type(image_tensor) is list:
63
- image_tensor = [image.to(model.device, dtype=torch.float16) for image in image_tensor]
64
- else:
65
- image_tensor = image_tensor.to(model.device, dtype=torch.float16)
66
-
67
- while True:
68
- try:
69
- inp = input(f"{roles[0]}: ")
70
- except EOFError:
71
- inp = ""
72
- if not inp:
73
- print("exit...")
74
- break
75
-
76
- print(f"{roles[1]}: ", end="")
77
-
78
- if image is not None:
79
- # first message
80
- if model.config.mm_use_im_start_end:
81
- inp = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + inp
82
- else:
83
- inp = DEFAULT_IMAGE_TOKEN + '\n' + inp
84
- image = None
85
-
86
- conv.append_message(conv.roles[0], inp)
87
- conv.append_message(conv.roles[1], None)
88
- prompt = conv.get_prompt()
89
-
90
- input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(model.device)
91
- stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
92
- keywords = [stop_str]
93
- streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
94
-
95
- with torch.inference_mode():
96
- output_ids = model.generate(
97
- input_ids,
98
- images=image_tensor,
99
- image_sizes=[image_size],
100
- do_sample=True if args.temperature > 0 else False,
101
- temperature=args.temperature,
102
- max_new_tokens=args.max_new_tokens,
103
- streamer=streamer,
104
- use_cache=True)
105
-
106
- outputs = tokenizer.decode(output_ids[0]).strip()
107
- conv.messages[-1][-1] = outputs
108
-
109
- if args.debug:
110
- print("\n", {"prompt": prompt, "outputs": outputs}, "\n")
111
-
112
-
113
- if __name__ == "__main__":
114
- parser = argparse.ArgumentParser()
115
- parser.add_argument("--model-path", type=str, default="facebook/opt-350m")
116
- parser.add_argument("--model-base", type=str, default=None)
117
- parser.add_argument("--image-file", type=str, required=True)
118
- parser.add_argument("--device", type=str, default="cuda")
119
- parser.add_argument("--conv-mode", type=str, default=None)
120
- parser.add_argument("--temperature", type=float, default=0.2)
121
- parser.add_argument("--max-new-tokens", type=int, default=512)
122
- parser.add_argument("--load-8bit", action="store_true")
123
- parser.add_argument("--load-4bit", action="store_true")
124
- parser.add_argument("--debug", action="store_true")
125
- args = parser.parse_args()
126
- main(args)
serve/controller.py DELETED
@@ -1,298 +0,0 @@
1
- """
2
- A controller manages distributed workers.
3
- It sends worker addresses to clients.
4
- """
5
- import argparse
6
- import asyncio
7
- import dataclasses
8
- from enum import Enum, auto
9
- import json
10
- import logging
11
- import time
12
- from typing import List, Union
13
- import threading
14
-
15
- from fastapi import FastAPI, Request
16
- from fastapi.responses import StreamingResponse
17
- import numpy as np
18
- import requests
19
- import uvicorn
20
-
21
- from llava.constants import CONTROLLER_HEART_BEAT_EXPIRATION
22
- from llava.utils import build_logger, server_error_msg
23
-
24
-
25
- logger = build_logger("controller", "controller.log")
26
-
27
-
28
- class DispatchMethod(Enum):
29
- LOTTERY = auto()
30
- SHORTEST_QUEUE = auto()
31
-
32
- @classmethod
33
- def from_str(cls, name):
34
- if name == "lottery":
35
- return cls.LOTTERY
36
- elif name == "shortest_queue":
37
- return cls.SHORTEST_QUEUE
38
- else:
39
- raise ValueError(f"Invalid dispatch method")
40
-
41
-
42
- @dataclasses.dataclass
43
- class WorkerInfo:
44
- model_names: List[str]
45
- speed: int
46
- queue_length: int
47
- check_heart_beat: bool
48
- last_heart_beat: str
49
-
50
-
51
- def heart_beat_controller(controller):
52
- while True:
53
- time.sleep(CONTROLLER_HEART_BEAT_EXPIRATION)
54
- controller.remove_stable_workers_by_expiration()
55
-
56
-
57
- class Controller:
58
- def __init__(self, dispatch_method: str):
59
- # Dict[str -> WorkerInfo]
60
- self.worker_info = {}
61
- self.dispatch_method = DispatchMethod.from_str(dispatch_method)
62
-
63
- self.heart_beat_thread = threading.Thread(
64
- target=heart_beat_controller, args=(self,), daemon=True)
65
- self.heart_beat_thread.start()
66
-
67
- logger.info("Init controller")
68
-
69
- def register_worker(self, worker_name: str, check_heart_beat: bool,
70
- worker_status: dict):
71
- if worker_name not in self.worker_info:
72
- logger.info(f"Register a new worker: {worker_name}")
73
- else:
74
- logger.info(f"Register an existing worker: {worker_name}")
75
-
76
- if not worker_status:
77
- worker_status = self.get_worker_status(worker_name)
78
- if not worker_status:
79
- return False
80
-
81
- self.worker_info[worker_name] = WorkerInfo(
82
- worker_status["model_names"], worker_status["speed"], worker_status["queue_length"],
83
- check_heart_beat, time.time())
84
-
85
- logger.info(f"Register done: {worker_name}, {worker_status}")
86
- return True
87
-
88
- def get_worker_status(self, worker_name: str):
89
- try:
90
- r = requests.post(worker_name + "/worker_get_status", timeout=5)
91
- except requests.exceptions.RequestException as e:
92
- logger.error(f"Get status fails: {worker_name}, {e}")
93
- return None
94
-
95
- if r.status_code != 200:
96
- logger.error(f"Get status fails: {worker_name}, {r}")
97
- return None
98
-
99
- return r.json()
100
-
101
- def remove_worker(self, worker_name: str):
102
- del self.worker_info[worker_name]
103
-
104
- def refresh_all_workers(self):
105
- old_info = dict(self.worker_info)
106
- self.worker_info = {}
107
-
108
- for w_name, w_info in old_info.items():
109
- if not self.register_worker(w_name, w_info.check_heart_beat, None):
110
- logger.info(f"Remove stale worker: {w_name}")
111
-
112
- def list_models(self):
113
- model_names = set()
114
-
115
- for w_name, w_info in self.worker_info.items():
116
- model_names.update(w_info.model_names)
117
-
118
- return list(model_names)
119
-
120
- def get_worker_address(self, model_name: str):
121
- if self.dispatch_method == DispatchMethod.LOTTERY:
122
- worker_names = []
123
- worker_speeds = []
124
- for w_name, w_info in self.worker_info.items():
125
- if model_name in w_info.model_names:
126
- worker_names.append(w_name)
127
- worker_speeds.append(w_info.speed)
128
- worker_speeds = np.array(worker_speeds, dtype=np.float32)
129
- norm = np.sum(worker_speeds)
130
- if norm < 1e-4:
131
- return ""
132
- worker_speeds = worker_speeds / norm
133
- if True: # Directly return address
134
- pt = np.random.choice(np.arange(len(worker_names)),
135
- p=worker_speeds)
136
- worker_name = worker_names[pt]
137
- return worker_name
138
-
139
- # Check status before returning
140
- while True:
141
- pt = np.random.choice(np.arange(len(worker_names)),
142
- p=worker_speeds)
143
- worker_name = worker_names[pt]
144
-
145
- if self.get_worker_status(worker_name):
146
- break
147
- else:
148
- self.remove_worker(worker_name)
149
- worker_speeds[pt] = 0
150
- norm = np.sum(worker_speeds)
151
- if norm < 1e-4:
152
- return ""
153
- worker_speeds = worker_speeds / norm
154
- continue
155
- return worker_name
156
- elif self.dispatch_method == DispatchMethod.SHORTEST_QUEUE:
157
- worker_names = []
158
- worker_qlen = []
159
- for w_name, w_info in self.worker_info.items():
160
- if model_name in w_info.model_names:
161
- worker_names.append(w_name)
162
- worker_qlen.append(w_info.queue_length / w_info.speed)
163
- if len(worker_names) == 0:
164
- return ""
165
- min_index = np.argmin(worker_qlen)
166
- w_name = worker_names[min_index]
167
- self.worker_info[w_name].queue_length += 1
168
- logger.info(f"names: {worker_names}, queue_lens: {worker_qlen}, ret: {w_name}")
169
- return w_name
170
- else:
171
- raise ValueError(f"Invalid dispatch method: {self.dispatch_method}")
172
-
173
- def receive_heart_beat(self, worker_name: str, queue_length: int):
174
- if worker_name not in self.worker_info:
175
- logger.info(f"Receive unknown heart beat. {worker_name}")
176
- return False
177
-
178
- self.worker_info[worker_name].queue_length = queue_length
179
- self.worker_info[worker_name].last_heart_beat = time.time()
180
- logger.info(f"Receive heart beat. {worker_name}")
181
- return True
182
-
183
- def remove_stable_workers_by_expiration(self):
184
- expire = time.time() - CONTROLLER_HEART_BEAT_EXPIRATION
185
- to_delete = []
186
- for worker_name, w_info in self.worker_info.items():
187
- if w_info.check_heart_beat and w_info.last_heart_beat < expire:
188
- to_delete.append(worker_name)
189
-
190
- for worker_name in to_delete:
191
- self.remove_worker(worker_name)
192
-
193
- def worker_api_generate_stream(self, params):
194
- worker_addr = self.get_worker_address(params["model"])
195
- if not worker_addr:
196
- logger.info(f"no worker: {params['model']}")
197
- ret = {
198
- "text": server_error_msg,
199
- "error_code": 2,
200
- }
201
- yield json.dumps(ret).encode() + b"\0"
202
-
203
- try:
204
- response = requests.post(worker_addr + "/worker_generate_stream",
205
- json=params, stream=True, timeout=5)
206
- for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
207
- if chunk:
208
- yield chunk + b"\0"
209
- except requests.exceptions.RequestException as e:
210
- logger.info(f"worker timeout: {worker_addr}")
211
- ret = {
212
- "text": server_error_msg,
213
- "error_code": 3,
214
- }
215
- yield json.dumps(ret).encode() + b"\0"
216
-
217
-
218
- # Let the controller act as a worker to achieve hierarchical
219
- # management. This can be used to connect isolated sub networks.
220
- def worker_api_get_status(self):
221
- model_names = set()
222
- speed = 0
223
- queue_length = 0
224
-
225
- for w_name in self.worker_info:
226
- worker_status = self.get_worker_status(w_name)
227
- if worker_status is not None:
228
- model_names.update(worker_status["model_names"])
229
- speed += worker_status["speed"]
230
- queue_length += worker_status["queue_length"]
231
-
232
- return {
233
- "model_names": list(model_names),
234
- "speed": speed,
235
- "queue_length": queue_length,
236
- }
237
-
238
-
239
- app = FastAPI()
240
-
241
-
242
- @app.post("/register_worker")
243
- async def register_worker(request: Request):
244
- data = await request.json()
245
- controller.register_worker(
246
- data["worker_name"], data["check_heart_beat"],
247
- data.get("worker_status", None))
248
-
249
-
250
- @app.post("/refresh_all_workers")
251
- async def refresh_all_workers():
252
- models = controller.refresh_all_workers()
253
-
254
-
255
- @app.post("/list_models")
256
- async def list_models():
257
- models = controller.list_models()
258
- return {"models": models}
259
-
260
-
261
- @app.post("/get_worker_address")
262
- async def get_worker_address(request: Request):
263
- data = await request.json()
264
- addr = controller.get_worker_address(data["model"])
265
- return {"address": addr}
266
-
267
-
268
- @app.post("/receive_heart_beat")
269
- async def receive_heart_beat(request: Request):
270
- data = await request.json()
271
- exist = controller.receive_heart_beat(
272
- data["worker_name"], data["queue_length"])
273
- return {"exist": exist}
274
-
275
-
276
- @app.post("/worker_generate_stream")
277
- async def worker_api_generate_stream(request: Request):
278
- params = await request.json()
279
- generator = controller.worker_api_generate_stream(params)
280
- return StreamingResponse(generator)
281
-
282
-
283
- @app.post("/worker_get_status")
284
- async def worker_api_get_status(request: Request):
285
- return controller.worker_api_get_status()
286
-
287
-
288
- if __name__ == "__main__":
289
- parser = argparse.ArgumentParser()
290
- parser.add_argument("--host", type=str, default="localhost")
291
- parser.add_argument("--port", type=int, default=21001)
292
- parser.add_argument("--dispatch-method", type=str, choices=[
293
- "lottery", "shortest_queue"], default="shortest_queue")
294
- args = parser.parse_args()
295
- logger.info(f"args: {args}")
296
-
297
- controller = Controller(args.dispatch_method)
298
- uvicorn.run(app, host=args.host, port=args.port, log_level="info")
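The routes defined above are plain HTTP POST endpoints, so a client needs nothing beyond `requests`. A minimal sketch of how a caller would list the available models and resolve a worker address, assuming the controller is running on its default `localhost:21001`:

```python
import requests

controller = "http://localhost:21001"  # default --host/--port from the argparse block above

# Re-query every registered worker, then list the model names they serve.
requests.post(controller + "/refresh_all_workers")
models = requests.post(controller + "/list_models").json()["models"]
print("available models:", models)

if models:
    # Resolve a worker for one model; an empty address means no worker is available.
    ret = requests.post(controller + "/get_worker_address", json={"model": models[0]})
    print("worker address:", ret.json()["address"])
```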
serve/examples/extreme_ironing.jpg DELETED
Binary file (62.6 kB)
 
serve/examples/waterview.jpg DELETED
Binary file (95.5 kB)
 
serve/gradio_web_server.py DELETED
@@ -1,479 +0,0 @@
1
- import argparse
2
- import datetime
3
- import json
4
- import os
5
- import time
6
-
7
- import gradio as gr
8
- import requests
9
-
10
- from llava.conversation import (default_conversation, conv_templates,
11
- SeparatorStyle)
12
- from llava.constants import LOGDIR
13
- from llava.utils import (build_logger, server_error_msg,
14
- violates_moderation, moderation_msg)
15
- import hashlib
16
-
17
-
18
- logger = build_logger("gradio_web_server", "gradio_web_server.log")
19
-
20
- headers = {"User-Agent": "LLaVA Client"}
21
-
22
- no_change_btn = gr.Button()
23
- enable_btn = gr.Button(interactive=True)
24
- disable_btn = gr.Button(interactive=False)
25
-
26
- priority = {
27
- "vicuna-13b": "aaaaaaa",
28
- "koala-13b": "aaaaaab",
29
- }
30
-
31
-
32
- def get_conv_log_filename():
33
- t = datetime.datetime.now()
34
- name = os.path.join(LOGDIR, f"{t.year}-{t.month:02d}-{t.day:02d}-conv.json")
35
- return name
36
-
37
-
38
- def get_model_list():
39
- ret = requests.post(args.controller_url + "/refresh_all_workers")
40
- assert ret.status_code == 200
41
- ret = requests.post(args.controller_url + "/list_models")
42
- models = ret.json()["models"]
43
- models.sort(key=lambda x: priority.get(x, x))
44
- logger.info(f"Models: {models}")
45
- return models
46
-
47
-
48
- get_window_url_params = """
49
- function() {
50
- const params = new URLSearchParams(window.location.search);
51
- url_params = Object.fromEntries(params);
52
- console.log(url_params);
53
- return url_params;
54
- }
55
- """
56
-
57
-
58
- def load_demo(url_params, request: gr.Request):
59
- logger.info(f"load_demo. ip: {request.client.host}. params: {url_params}")
60
-
61
- dropdown_update = gr.Dropdown(visible=True)
62
- if "model" in url_params:
63
- model = url_params["model"]
64
- if model in models:
65
- dropdown_update = gr.Dropdown(value=model, visible=True)
66
-
67
- state = default_conversation.copy()
68
- return state, dropdown_update
69
-
70
-
71
- def load_demo_refresh_model_list(request: gr.Request):
72
- logger.info(f"load_demo. ip: {request.client.host}")
73
- models = get_model_list()
74
- state = default_conversation.copy()
75
- dropdown_update = gr.Dropdown(
76
- choices=models,
77
- value=models[0] if len(models) > 0 else ""
78
- )
79
- return state, dropdown_update
80
-
81
-
82
- def vote_last_response(state, vote_type, model_selector, request: gr.Request):
83
- with open(get_conv_log_filename(), "a") as fout:
84
- data = {
85
- "tstamp": round(time.time(), 4),
86
- "type": vote_type,
87
- "model": model_selector,
88
- "state": state.dict(),
89
- "ip": request.client.host,
90
- }
91
- fout.write(json.dumps(data) + "\n")
92
-
93
-
94
- def upvote_last_response(state, model_selector, request: gr.Request):
95
- logger.info(f"upvote. ip: {request.client.host}")
96
- vote_last_response(state, "upvote", model_selector, request)
97
- return ("",) + (disable_btn,) * 3
98
-
99
-
100
- def downvote_last_response(state, model_selector, request: gr.Request):
101
- logger.info(f"downvote. ip: {request.client.host}")
102
- vote_last_response(state, "downvote", model_selector, request)
103
- return ("",) + (disable_btn,) * 3
104
-
105
-
106
- def flag_last_response(state, model_selector, request: gr.Request):
107
- logger.info(f"flag. ip: {request.client.host}")
108
- vote_last_response(state, "flag", model_selector, request)
109
- return ("",) + (disable_btn,) * 3
110
-
111
-
112
- def regenerate(state, image_process_mode, request: gr.Request):
113
- logger.info(f"regenerate. ip: {request.client.host}")
114
- state.messages[-1][-1] = None
115
- prev_human_msg = state.messages[-2]
116
- if type(prev_human_msg[1]) in (tuple, list):
117
- prev_human_msg[1] = (*prev_human_msg[1][:2], image_process_mode)
118
- state.skip_next = False
119
- return (state, state.to_gradio_chatbot(), "", None) + (disable_btn,) * 5
120
-
121
-
122
- def clear_history(request: gr.Request):
123
- logger.info(f"clear_history. ip: {request.client.host}")
124
- state = default_conversation.copy()
125
- return (state, state.to_gradio_chatbot(), "", None) + (disable_btn,) * 5
126
-
127
-
128
- def add_text(state, text, image, image_process_mode, request: gr.Request):
129
- logger.info(f"add_text. ip: {request.client.host}. len: {len(text)}")
130
- if len(text) <= 0 and image is None:
131
- state.skip_next = True
132
- return (state, state.to_gradio_chatbot(), "", None) + (no_change_btn,) * 5
133
- if args.moderate:
134
- flagged = violates_moderation(text)
135
- if flagged:
136
- state.skip_next = True
137
- return (state, state.to_gradio_chatbot(), moderation_msg, None) + (
138
- no_change_btn,) * 5
139
-
140
- text = text[:1536] # Hard cut-off
141
- if image is not None:
142
- text = text[:1200] # Hard cut-off for images
143
- if '<image>' not in text:
144
- # text = '<Image><image></Image>' + text
145
- text = text + '\n<image>'
146
- text = (text, image, image_process_mode)
147
- state = default_conversation.copy()
148
- state.append_message(state.roles[0], text)
149
- state.append_message(state.roles[1], None)
150
- state.skip_next = False
151
- return (state, state.to_gradio_chatbot(), "", None) + (disable_btn,) * 5
152
-
153
-
154
- def http_bot(state, model_selector, temperature, top_p, max_new_tokens, request: gr.Request):
155
- logger.info(f"http_bot. ip: {request.client.host}")
156
- start_tstamp = time.time()
157
- model_name = model_selector
158
-
159
- if state.skip_next:
160
- # This generate call is skipped due to invalid inputs
161
- yield (state, state.to_gradio_chatbot()) + (no_change_btn,) * 5
162
- return
163
-
164
- if len(state.messages) == state.offset + 2:
165
- # First round of conversation
166
- if "llava" in model_name.lower():
167
- if 'llama-2' in model_name.lower():
168
- template_name = "llava_llama_2"
169
- elif "mistral" in model_name.lower() or "mixtral" in model_name.lower():
170
- if 'orca' in model_name.lower():
171
- template_name = "mistral_orca"
172
- elif 'hermes' in model_name.lower():
173
- template_name = "chatml_direct"
174
- else:
175
- template_name = "mistral_instruct"
176
- elif 'llava-v1.6-34b' in model_name.lower():
177
- template_name = "chatml_direct"
178
- elif "v1" in model_name.lower():
179
- if 'mmtag' in model_name.lower():
180
- template_name = "v1_mmtag"
181
- elif 'plain' in model_name.lower() and 'finetune' not in model_name.lower():
182
- template_name = "v1_mmtag"
183
- else:
184
- template_name = "llava_v1"
185
- elif "mpt" in model_name.lower():
186
- template_name = "mpt"
187
- else:
188
- if 'mmtag' in model_name.lower():
189
- template_name = "v0_mmtag"
190
- elif 'plain' in model_name.lower() and 'finetune' not in model_name.lower():
191
- template_name = "v0_mmtag"
192
- else:
193
- template_name = "llava_v0"
194
- elif "mpt" in model_name:
195
- template_name = "mpt_text"
196
- elif "llama-2" in model_name:
197
- template_name = "llama_2"
198
- else:
199
- template_name = "vicuna_v1"
200
- new_state = conv_templates[template_name].copy()
201
- new_state.append_message(new_state.roles[0], state.messages[-2][1])
202
- new_state.append_message(new_state.roles[1], None)
203
- state = new_state
204
-
205
- # Query worker address
206
- controller_url = args.controller_url
207
- ret = requests.post(controller_url + "/get_worker_address",
208
- json={"model": model_name})
209
- worker_addr = ret.json()["address"]
210
- logger.info(f"model_name: {model_name}, worker_addr: {worker_addr}")
211
-
212
- # No available worker
213
- if worker_addr == "":
214
- state.messages[-1][-1] = server_error_msg
215
- yield (state, state.to_gradio_chatbot(), disable_btn, disable_btn, disable_btn, enable_btn, enable_btn)
216
- return
217
-
218
- # Construct prompt
219
- prompt = state.get_prompt()
220
-
221
- all_images = state.get_images(return_pil=True)
222
- all_image_hash = [hashlib.md5(image.tobytes()).hexdigest() for image in all_images]
223
- for image, hash in zip(all_images, all_image_hash):
224
- t = datetime.datetime.now()
225
- filename = os.path.join(LOGDIR, "serve_images", f"{t.year}-{t.month:02d}-{t.day:02d}", f"{hash}.jpg")
226
- if not os.path.isfile(filename):
227
- os.makedirs(os.path.dirname(filename), exist_ok=True)
228
- image.save(filename)
229
-
230
- # Make requests
231
- pload = {
232
- "model": model_name,
233
- "prompt": prompt,
234
- "temperature": float(temperature),
235
- "top_p": float(top_p),
236
- "max_new_tokens": min(int(max_new_tokens), 1536),
237
- "stop": state.sep if state.sep_style in [SeparatorStyle.SINGLE, SeparatorStyle.MPT] else state.sep2,
238
- "images": f'List of {len(state.get_images())} images: {all_image_hash}',
239
- }
240
- logger.info(f"==== request ====\n{pload}")
241
-
242
- pload['images'] = state.get_images()
243
-
244
- state.messages[-1][-1] = "▌"
245
- yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 5
246
-
247
- try:
248
- # Stream output
249
- response = requests.post(worker_addr + "/worker_generate_stream",
250
- headers=headers, json=pload, stream=True, timeout=10)
251
- for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
252
- if chunk:
253
- data = json.loads(chunk.decode())
254
- if data["error_code"] == 0:
255
- output = data["text"][len(prompt):].strip()
256
- state.messages[-1][-1] = output + "▌"
257
- yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 5
258
- else:
259
- output = data["text"] + f" (error_code: {data['error_code']})"
260
- state.messages[-1][-1] = output
261
- yield (state, state.to_gradio_chatbot()) + (disable_btn, disable_btn, disable_btn, enable_btn, enable_btn)
262
- return
263
- time.sleep(0.03)
264
- except requests.exceptions.RequestException as e:
265
- state.messages[-1][-1] = server_error_msg
266
- yield (state, state.to_gradio_chatbot()) + (disable_btn, disable_btn, disable_btn, enable_btn, enable_btn)
267
- return
268
-
269
- state.messages[-1][-1] = state.messages[-1][-1][:-1]
270
- yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 5
271
-
272
- finish_tstamp = time.time()
273
- logger.info(f"{output}")
274
-
275
- with open(get_conv_log_filename(), "a") as fout:
276
- data = {
277
- "tstamp": round(finish_tstamp, 4),
278
- "type": "chat",
279
- "model": model_name,
280
- "start": round(start_tstamp, 4),
281
- "finish": round(finish_tstamp, 4),
282
- "state": state.dict(),
283
- "images": all_image_hash,
284
- "ip": request.client.host,
285
- }
286
- fout.write(json.dumps(data) + "\n")
287
-
288
- title_markdown = ("""
289
- # 🌋 LLaVA: Large Language and Vision Assistant
290
- [[Project Page](https://llava-vl.github.io)] [[Code](https://github.com/haotian-liu/LLaVA)] [[Model](https://github.com/haotian-liu/LLaVA/blob/main/docs/MODEL_ZOO.md)] | 📚 [[LLaVA](https://arxiv.org/abs/2304.08485)] [[LLaVA-v1.5](https://arxiv.org/abs/2310.03744)] [[LLaVA-v1.6](https://llava-vl.github.io/blog/2024-01-30-llava-1-6/)]
291
- """)
292
-
293
- tos_markdown = ("""
294
- ### Terms of use
295
- By using this service, users are required to agree to the following terms:
296
- The service is a research preview intended for non-commercial use only. It only provides limited safety measures and may generate offensive content. It must not be used for any illegal, harmful, violent, racist, or sexual purposes. The service may collect user dialogue data for future research.
297
- Please click the "Flag" button if you get any inappropriate answer! We will collect those to keep improving our moderator.
298
- For an optimal experience, please use desktop computers for this demo, as mobile devices may compromise its quality.
299
- """)
300
-
301
-
302
- learn_more_markdown = ("""
303
- ### License
304
- The service is a research preview intended for non-commercial use only, subject to the model [License](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) of LLaMA, [Terms of Use](https://openai.com/policies/terms-of-use) of the data generated by OpenAI, and [Privacy Practices](https://chrome.google.com/webstore/detail/sharegpt-share-your-chatg/daiacboceoaocpibfodeljbdfacokfjb) of ShareGPT. Please contact us if you find any potential violation.
305
- """)
306
-
307
- block_css = """
308
-
309
- #buttons button {
310
- min-width: min(120px,100%);
311
- }
312
-
313
- """
314
-
315
- def build_demo(embed_mode, cur_dir=None, concurrency_count=10):
316
- textbox = gr.Textbox(show_label=False, placeholder="Enter text and press ENTER", container=False)
317
- with gr.Blocks(title="LLaVA", theme=gr.themes.Default(), css=block_css) as demo:
318
- state = gr.State()
319
-
320
- if not embed_mode:
321
- gr.Markdown(title_markdown)
322
-
323
- with gr.Row():
324
- with gr.Column(scale=3):
325
- with gr.Row(elem_id="model_selector_row"):
326
- model_selector = gr.Dropdown(
327
- choices=models,
328
- value=models[0] if len(models) > 0 else "",
329
- interactive=True,
330
- show_label=False,
331
- container=False)
332
-
333
- imagebox = gr.Image(type="pil")
334
- image_process_mode = gr.Radio(
335
- ["Crop", "Resize", "Pad", "Default"],
336
- value="Default",
337
- label="Preprocess for non-square image", visible=False)
338
-
339
- if cur_dir is None:
340
- cur_dir = os.path.dirname(os.path.abspath(__file__))
341
- gr.Examples(examples=[
342
- [f"{cur_dir}/examples/extreme_ironing.jpg", "What is unusual about this image?"],
343
- [f"{cur_dir}/examples/waterview.jpg", "What are the things I should be cautious about when I visit here?"],
344
- ], inputs=[imagebox, textbox])
345
-
346
- with gr.Accordion("Parameters", open=False) as parameter_row:
347
- temperature = gr.Slider(minimum=0.0, maximum=1.0, value=0.2, step=0.1, interactive=True, label="Temperature",)
348
- top_p = gr.Slider(minimum=0.0, maximum=1.0, value=0.7, step=0.1, interactive=True, label="Top P",)
349
- max_output_tokens = gr.Slider(minimum=0, maximum=1024, value=512, step=64, interactive=True, label="Max output tokens",)
350
-
351
- with gr.Column(scale=8):
352
- chatbot = gr.Chatbot(
353
- elem_id="chatbot",
354
- label="LLaVA Chatbot",
355
- height=650,
356
- layout="panel",
357
- )
358
- with gr.Row():
359
- with gr.Column(scale=8):
360
- textbox.render()
361
- with gr.Column(scale=1, min_width=50):
362
- submit_btn = gr.Button(value="Send", variant="primary")
363
- with gr.Row(elem_id="buttons") as button_row:
364
- upvote_btn = gr.Button(value="👍 Upvote", interactive=False)
365
- downvote_btn = gr.Button(value="👎 Downvote", interactive=False)
366
- flag_btn = gr.Button(value="⚠️ Flag", interactive=False)
367
- #stop_btn = gr.Button(value="⏹️ Stop Generation", interactive=False)
368
- regenerate_btn = gr.Button(value="🔄 Regenerate", interactive=False)
369
- clear_btn = gr.Button(value="🗑️ Clear", interactive=False)
370
-
371
- if not embed_mode:
372
- gr.Markdown(tos_markdown)
373
- gr.Markdown(learn_more_markdown)
374
- url_params = gr.JSON(visible=False)
375
-
376
- # Register listeners
377
- btn_list = [upvote_btn, downvote_btn, flag_btn, regenerate_btn, clear_btn]
378
- upvote_btn.click(
379
- upvote_last_response,
380
- [state, model_selector],
381
- [textbox, upvote_btn, downvote_btn, flag_btn]
382
- )
383
- downvote_btn.click(
384
- downvote_last_response,
385
- [state, model_selector],
386
- [textbox, upvote_btn, downvote_btn, flag_btn]
387
- )
388
- flag_btn.click(
389
- flag_last_response,
390
- [state, model_selector],
391
- [textbox, upvote_btn, downvote_btn, flag_btn]
392
- )
393
-
394
- regenerate_btn.click(
395
- regenerate,
396
- [state, image_process_mode],
397
- [state, chatbot, textbox, imagebox] + btn_list
398
- ).then(
399
- http_bot,
400
- [state, model_selector, temperature, top_p, max_output_tokens],
401
- [state, chatbot] + btn_list,
402
- concurrency_limit=concurrency_count
403
- )
404
-
405
- clear_btn.click(
406
- clear_history,
407
- None,
408
- [state, chatbot, textbox, imagebox] + btn_list,
409
- queue=False
410
- )
411
-
412
- textbox.submit(
413
- add_text,
414
- [state, textbox, imagebox, image_process_mode],
415
- [state, chatbot, textbox, imagebox] + btn_list,
416
- queue=False
417
- ).then(
418
- http_bot,
419
- [state, model_selector, temperature, top_p, max_output_tokens],
420
- [state, chatbot] + btn_list,
421
- concurrency_limit=concurrency_count
422
- )
423
-
424
- submit_btn.click(
425
- add_text,
426
- [state, textbox, imagebox, image_process_mode],
427
- [state, chatbot, textbox, imagebox] + btn_list
428
- ).then(
429
- http_bot,
430
- [state, model_selector, temperature, top_p, max_output_tokens],
431
- [state, chatbot] + btn_list,
432
- concurrency_limit=concurrency_count
433
- )
434
-
435
- if args.model_list_mode == "once":
436
- demo.load(
437
- load_demo,
438
- [url_params],
439
- [state, model_selector],
440
- js=get_window_url_params
441
- )
442
- elif args.model_list_mode == "reload":
443
- demo.load(
444
- load_demo_refresh_model_list,
445
- None,
446
- [state, model_selector],
447
- queue=False
448
- )
449
- else:
450
- raise ValueError(f"Unknown model list mode: {args.model_list_mode}")
451
-
452
- return demo
453
-
454
-
455
- if __name__ == "__main__":
456
- parser = argparse.ArgumentParser()
457
- parser.add_argument("--host", type=str, default="0.0.0.0")
458
- parser.add_argument("--port", type=int)
459
- parser.add_argument("--controller-url", type=str, default="http://localhost:21001")
460
- parser.add_argument("--concurrency-count", type=int, default=16)
461
- parser.add_argument("--model-list-mode", type=str, default="once",
462
- choices=["once", "reload"])
463
- parser.add_argument("--share", action="store_true")
464
- parser.add_argument("--moderate", action="store_true")
465
- parser.add_argument("--embed", action="store_true")
466
- args = parser.parse_args()
467
- logger.info(f"args: {args}")
468
-
469
- models = get_model_list()
470
-
471
- logger.info(args)
472
- demo = build_demo(args.embed, concurrency_count=args.concurrency_count)
473
- demo.queue(
474
- api_open=False
475
- ).launch(
476
- server_name=args.host,
477
- server_port=args.port,
478
- share=args.share
479
- )
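The `http_bot` handler above also documents the wire format used by the workers: a streamed HTTP response whose chunks are JSON objects separated by NUL bytes (`b"\0"`), each carrying the full text generated so far plus an `error_code`. A standalone sketch of the same consumption loop, with the worker address, model name, and prompt as placeholders:

```python
import json
import requests

worker_addr = "http://localhost:40000"   # placeholder worker address
payload = {
    "model": "llava-v1.5-7b",            # placeholder model name
    "prompt": "USER: <image>\nDescribe the image. ASSISTANT:",
    "temperature": 0.2,
    "top_p": 0.7,
    "max_new_tokens": 256,
    "stop": "</s>",
    "images": [],                        # base64-encoded images would go here
}

response = requests.post(worker_addr + "/worker_generate_stream",
                         json=payload, stream=True, timeout=10)
for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
    if not chunk:
        continue
    data = json.loads(chunk.decode())
    if data["error_code"] != 0:
        raise RuntimeError(data["text"])
    print(data["text"], end="\r")        # each chunk holds the full text so far
print()
```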
serve/model_worker.py DELETED
@@ -1,288 +0,0 @@
1
- """
2
- A model worker executes the model.
3
- """
4
- import argparse
5
- import asyncio
6
- import json
7
- import time
8
- import threading
9
- import uuid
10
-
11
- from fastapi import FastAPI, Request, BackgroundTasks
12
- from fastapi.responses import StreamingResponse
13
- import requests
14
- import torch
15
- import uvicorn
16
- from functools import partial
17
-
18
- from llava.constants import WORKER_HEART_BEAT_INTERVAL
19
- from llava.utils import (build_logger, server_error_msg,
20
- pretty_print_semaphore)
21
- from llava.model.builder import load_pretrained_model
22
- from llava.mm_utils import process_images, load_image_from_base64, tokenizer_image_token
23
- from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
24
- from transformers import TextIteratorStreamer
25
- from threading import Thread
26
-
27
-
28
- GB = 1 << 30
29
-
30
- worker_id = str(uuid.uuid4())[:6]
31
- logger = build_logger("model_worker", f"model_worker_{worker_id}.log")
32
- global_counter = 0
33
-
34
- model_semaphore = None
35
-
36
-
37
- def heart_beat_worker(controller):
38
-
39
- while True:
40
- time.sleep(WORKER_HEART_BEAT_INTERVAL)
41
- controller.send_heart_beat()
42
-
43
-
44
- class ModelWorker:
45
- def __init__(self, controller_addr, worker_addr,
46
- worker_id, no_register,
47
- model_path, model_base, model_name,
48
- load_8bit, load_4bit, device, use_flash_attn=False):
49
- self.controller_addr = controller_addr
50
- self.worker_addr = worker_addr
51
- self.worker_id = worker_id
52
- if model_path.endswith("/"):
53
- model_path = model_path[:-1]
54
- if model_name is None:
55
- model_paths = model_path.split("/")
56
- if model_paths[-1].startswith('checkpoint-'):
57
- self.model_name = model_paths[-2] + "_" + model_paths[-1]
58
- else:
59
- self.model_name = model_paths[-1]
60
- else:
61
- self.model_name = model_name
62
-
63
- self.device = device
64
- logger.info(f"Loading the model {self.model_name} on worker {worker_id} ...")
65
- self.tokenizer, self.model, self.image_processor, self.context_len = load_pretrained_model(
66
- model_path, model_base, self.model_name, load_8bit, load_4bit, device=self.device, use_flash_attn=use_flash_attn)
67
- self.is_multimodal = 'llava' in self.model_name.lower()
68
-
69
- if not no_register:
70
- self.register_to_controller()
71
- self.heart_beat_thread = threading.Thread(
72
- target=heart_beat_worker, args=(self,), daemon=True)
73
- self.heart_beat_thread.start()
74
-
75
- def register_to_controller(self):
76
- logger.info("Register to controller")
77
-
78
- url = self.controller_addr + "/register_worker"
79
- data = {
80
- "worker_name": self.worker_addr,
81
- "check_heart_beat": True,
82
- "worker_status": self.get_status()
83
- }
84
- r = requests.post(url, json=data)
85
- assert r.status_code == 200
86
-
87
- def send_heart_beat(self):
88
- logger.info(f"Send heart beat. Models: {[self.model_name]}. "
89
- f"Semaphore: {pretty_print_semaphore(model_semaphore)}. "
90
- f"global_counter: {global_counter}")
91
-
92
- url = self.controller_addr + "/receive_heart_beat"
93
-
94
- while True:
95
- try:
96
- ret = requests.post(url, json={
97
- "worker_name": self.worker_addr,
98
- "queue_length": self.get_queue_length()}, timeout=5)
99
- exist = ret.json()["exist"]
100
- break
101
- except requests.exceptions.RequestException as e:
102
- logger.error(f"heart beat error: {e}")
103
- time.sleep(5)
104
-
105
- if not exist:
106
- self.register_to_controller()
107
-
108
- def get_queue_length(self):
109
- if model_semaphore is None:
110
- return 0
111
- else:
112
- return args.limit_model_concurrency - model_semaphore._value + (len(
113
- model_semaphore._waiters) if model_semaphore._waiters is not None else 0)
114
-
115
- def get_status(self):
116
- return {
117
- "model_names": [self.model_name],
118
- "speed": 1,
119
- "queue_length": self.get_queue_length(),
120
- }
121
-
122
- @torch.inference_mode()
123
- def generate_stream(self, params):
124
- tokenizer, model, image_processor = self.tokenizer, self.model, self.image_processor
125
-
126
- prompt = params["prompt"]
127
- ori_prompt = prompt
128
- images = params.get("images", None)
129
- num_image_tokens = 0
130
- if images is not None and len(images) > 0 and self.is_multimodal:
131
- if len(images) > 0:
132
- if len(images) != prompt.count(DEFAULT_IMAGE_TOKEN):
133
- raise ValueError("Number of images does not match number of <image> tokens in prompt")
134
-
135
- images = [load_image_from_base64(image) for image in images]
136
- image_sizes = [image.size for image in images]
137
- images = process_images(images, image_processor, model.config)
138
-
139
- if type(images) is list:
140
- images = [image.to(self.model.device, dtype=torch.float16) for image in images]
141
- else:
142
- images = images.to(self.model.device, dtype=torch.float16)
143
-
144
- replace_token = DEFAULT_IMAGE_TOKEN
145
- if getattr(self.model.config, 'mm_use_im_start_end', False):
146
- replace_token = DEFAULT_IM_START_TOKEN + replace_token + DEFAULT_IM_END_TOKEN
147
- prompt = prompt.replace(DEFAULT_IMAGE_TOKEN, replace_token)
148
-
149
- num_image_tokens = prompt.count(replace_token) * model.get_vision_tower().num_patches
150
- else:
151
- images = None
152
- image_sizes = None
153
- image_args = {"images": images, "image_sizes": image_sizes}
154
- else:
155
- images = None
156
- image_args = {}
157
-
158
- temperature = float(params.get("temperature", 1.0))
159
- top_p = float(params.get("top_p", 1.0))
160
- max_context_length = getattr(model.config, 'max_position_embeddings', 2048)
161
- max_new_tokens = min(int(params.get("max_new_tokens", 256)), 1024)
162
- stop_str = params.get("stop", None)
163
- do_sample = True if temperature > 0.001 else False
164
-
165
- input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(self.device)
166
- keywords = [stop_str]
167
- # stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
168
- streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True, timeout=15)
169
-
170
- max_new_tokens = min(max_new_tokens, max_context_length - input_ids.shape[-1] - num_image_tokens)
171
-
172
- if max_new_tokens < 1:
173
- yield json.dumps({"text": ori_prompt + "Exceeds max token length. Please start a new conversation, thanks.", "error_code": 0}).encode() + b"\0"
174
- return
175
-
176
- thread = Thread(target=model.generate, kwargs=dict(
177
- inputs=input_ids,
178
- do_sample=do_sample,
179
- temperature=temperature,
180
- top_p=top_p,
181
- max_new_tokens=max_new_tokens,
182
- streamer=streamer,
183
- use_cache=True,
184
- **image_args
185
- ))
186
- thread.start()
187
-
188
- generated_text = ori_prompt
189
- for new_text in streamer:
190
- generated_text += new_text
191
- if generated_text.endswith(stop_str):
192
- generated_text = generated_text[:-len(stop_str)]
193
- yield json.dumps({"text": generated_text, "error_code": 0}).encode() + b"\0"
194
-
195
- def generate_stream_gate(self, params):
196
- try:
197
- for x in self.generate_stream(params):
198
- yield x
199
- except ValueError as e:
200
- print("Caught ValueError:", e)
201
- ret = {
202
- "text": server_error_msg,
203
- "error_code": 1,
204
- }
205
- yield json.dumps(ret).encode() + b"\0"
206
- except torch.cuda.CudaError as e:
207
- print("Caught torch.cuda.CudaError:", e)
208
- ret = {
209
- "text": server_error_msg,
210
- "error_code": 1,
211
- }
212
- yield json.dumps(ret).encode() + b"\0"
213
- except Exception as e:
214
- print("Caught Unknown Error", e)
215
- ret = {
216
- "text": server_error_msg,
217
- "error_code": 1,
218
- }
219
- yield json.dumps(ret).encode() + b"\0"
220
-
221
-
222
- app = FastAPI()
223
-
224
-
225
- def release_model_semaphore(fn=None):
226
- model_semaphore.release()
227
- if fn is not None:
228
- fn()
229
-
230
-
231
- @app.post("/worker_generate_stream")
232
- async def generate_stream(request: Request):
233
- global model_semaphore, global_counter
234
- global_counter += 1
235
- params = await request.json()
236
-
237
- if model_semaphore is None:
238
- model_semaphore = asyncio.Semaphore(args.limit_model_concurrency)
239
- await model_semaphore.acquire()
240
- worker.send_heart_beat()
241
- generator = worker.generate_stream_gate(params)
242
- background_tasks = BackgroundTasks()
243
- background_tasks.add_task(partial(release_model_semaphore, fn=worker.send_heart_beat))
244
- return StreamingResponse(generator, background=background_tasks)
245
-
246
-
247
- @app.post("/worker_get_status")
248
- async def get_status(request: Request):
249
- return worker.get_status()
250
-
251
-
252
- if __name__ == "__main__":
253
- parser = argparse.ArgumentParser()
254
- parser.add_argument("--host", type=str, default="localhost")
255
- parser.add_argument("--port", type=int, default=21002)
256
- parser.add_argument("--worker-address", type=str,
257
- default="http://localhost:21002")
258
- parser.add_argument("--controller-address", type=str,
259
- default="http://localhost:21001")
260
- parser.add_argument("--model-path", type=str, default="facebook/opt-350m")
261
- parser.add_argument("--model-base", type=str, default=None)
262
- parser.add_argument("--model-name", type=str)
263
- parser.add_argument("--device", type=str, default="cuda")
264
- parser.add_argument("--multi-modal", action="store_true", help="Multimodal mode is automatically detected with model name, please make sure `llava` is included in the model path.")
265
- parser.add_argument("--limit-model-concurrency", type=int, default=5)
266
- parser.add_argument("--stream-interval", type=int, default=1)
267
- parser.add_argument("--no-register", action="store_true")
268
- parser.add_argument("--load-8bit", action="store_true")
269
- parser.add_argument("--load-4bit", action="store_true")
270
- parser.add_argument("--use-flash-attn", action="store_true")
271
- args = parser.parse_args()
272
- logger.info(f"args: {args}")
273
-
274
- if args.multi_modal:
275
- logger.warning("Multimodal mode is automatically detected with model name, please make sure `llava` is included in the model path.")
276
-
277
- worker = ModelWorker(args.controller_address,
278
- args.worker_address,
279
- worker_id,
280
- args.no_register,
281
- args.model_path,
282
- args.model_base,
283
- args.model_name,
284
- args.load_8bit,
285
- args.load_4bit,
286
- args.device,
287
- use_flash_attn=args.use_flash_attn)
288
- uvicorn.run(app, host=args.host, port=args.port, log_level="info")
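Taken together, the deleted scripts implement the usual three-process LLaVA demo: a controller, one or more model workers that register with it, and the Gradio front end that routes requests through the controller. A rough launch sketch; the module paths, ports, and model path below are illustrative assumptions, not taken from this repo:

```python
import subprocess
import sys

def launch(module, *args):
    # Start one component as "python -m <module> ..." and return the process handle.
    return subprocess.Popen([sys.executable, "-m", module, *args])

procs = [
    launch("serve.controller", "--host", "0.0.0.0", "--port", "21001"),
    launch("serve.model_worker",
           "--controller-address", "http://localhost:21001",
           "--port", "40000", "--worker-address", "http://localhost:40000",
           "--model-path", "liuhaotian/llava-v1.5-7b"),
    launch("serve.gradio_web_server",
           "--controller-url", "http://localhost:21001",
           "--model-list-mode", "reload"),
]

for p in procs:
    p.wait()
```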
serve/register_worker.py DELETED
@@ -1,26 +0,0 @@
1
- """
2
- Manually register workers.
3
-
4
- Usage:
5
- python3 -m fastchat.serve.register_worker --controller http://localhost:21001 --worker-name http://localhost:21002
6
- """
7
-
8
- import argparse
9
-
10
- import requests
11
-
12
- if __name__ == "__main__":
13
- parser = argparse.ArgumentParser()
14
- parser.add_argument("--controller-address", type=str)
15
- parser.add_argument("--worker-name", type=str)
16
- parser.add_argument("--check-heart-beat", action="store_true")
17
- args = parser.parse_args()
18
-
19
- url = args.controller_address + "/register_worker"
20
- data = {
21
- "worker_name": args.worker_name,
22
- "check_heart_beat": args.check_heart_beat,
23
- "worker_status": None,
24
- }
25
- r = requests.post(url, json=data)
26
- assert r.status_code == 200
serve/sglang_worker.py DELETED
@@ -1,244 +0,0 @@
1
- """
2
- A model worker executes the model.
3
- """
4
- import argparse
5
- import asyncio
6
- from concurrent.futures import ThreadPoolExecutor
7
- import json
8
- import time
9
- import threading
10
- import uuid
11
-
12
- from fastapi import FastAPI, Request, BackgroundTasks
13
- from fastapi.responses import StreamingResponse
14
- import requests
15
- import re
16
- import uvicorn
17
- from functools import partial
18
-
19
- from llava.constants import WORKER_HEART_BEAT_INTERVAL
20
- from llava.utils import (build_logger, server_error_msg,
21
- pretty_print_semaphore)
22
- from llava.mm_utils import process_images, load_image_from_base64, tokenizer_image_token, expand2square
23
- from llava.constants import DEFAULT_IMAGE_TOKEN
24
-
25
- import sglang as sgl
26
- from sglang.backend.runtime_endpoint import RuntimeEndpoint
27
-
28
-
29
- GB = 1 << 30
30
-
31
- worker_id = str(uuid.uuid4())[:6]
32
- logger = build_logger("model_worker", f"model_worker_{worker_id}.log")
33
- global_counter = 0
34
-
35
- model_semaphore = None
36
-
37
-
38
- def heart_beat_worker(controller):
39
- while True:
40
- time.sleep(WORKER_HEART_BEAT_INTERVAL)
41
- controller.send_heart_beat()
42
-
43
-
44
- @sgl.function
45
- def pipeline(s, prompt, max_tokens):
46
- for p in prompt:
47
- if type(p) is str:
48
- s += p
49
- else:
50
- s += sgl.image(p)
51
- s += sgl.gen("response", max_tokens=max_tokens)
52
-
53
-
54
- class ModelWorker:
55
- def __init__(self, controller_addr, worker_addr, sgl_endpoint,
56
- worker_id, no_register, model_name):
57
- self.controller_addr = controller_addr
58
- self.worker_addr = worker_addr
59
- self.worker_id = worker_id
60
-
61
- # Select backend
62
- backend = RuntimeEndpoint(sgl_endpoint)
63
- sgl.set_default_backend(backend)
64
- model_path = backend.model_info["model_path"]
65
-
66
- if model_path.endswith("/"):
67
- model_path = model_path[:-1]
68
- if model_name is None:
69
- model_paths = model_path.split("/")
70
- if model_paths[-1].startswith('checkpoint-'):
71
- self.model_name = model_paths[-2] + "_" + model_paths[-1]
72
- else:
73
- self.model_name = model_paths[-1]
74
- else:
75
- self.model_name = model_name
76
-
77
- logger.info(f"Loading the SGLANG model {self.model_name} on worker {worker_id} ...")
78
-
79
- if not no_register:
80
- self.register_to_controller()
81
- self.heart_beat_thread = threading.Thread(
82
- target=heart_beat_worker, args=(self,), daemon=True)
83
- self.heart_beat_thread.start()
84
-
85
- def register_to_controller(self):
86
- logger.info("Register to controller")
87
-
88
- url = self.controller_addr + "/register_worker"
89
- data = {
90
- "worker_name": self.worker_addr,
91
- "check_heart_beat": True,
92
- "worker_status": self.get_status()
93
- }
94
- r = requests.post(url, json=data)
95
- assert r.status_code == 200
96
-
97
- def send_heart_beat(self):
98
- logger.info(f"Send heart beat. Models: {[self.model_name]}. "
99
- f"Semaphore: {pretty_print_semaphore(model_semaphore)}. "
100
- f"global_counter: {global_counter}")
101
-
102
- url = self.controller_addr + "/receive_heart_beat"
103
-
104
- while True:
105
- try:
106
- ret = requests.post(url, json={
107
- "worker_name": self.worker_addr,
108
- "queue_length": self.get_queue_length()}, timeout=5)
109
- exist = ret.json()["exist"]
110
- break
111
- except requests.exceptions.RequestException as e:
112
- logger.error(f"heart beat error: {e}")
113
- time.sleep(5)
114
-
115
- if not exist:
116
- self.register_to_controller()
117
-
118
- def get_queue_length(self):
119
- if model_semaphore is None:
120
- return 0
121
- else:
122
- return args.limit_model_concurrency - model_semaphore._value + (len(
123
- model_semaphore._waiters) if model_semaphore._waiters is not None else 0)
124
-
125
- def get_status(self):
126
- return {
127
- "model_names": [self.model_name],
128
- "speed": 1,
129
- "queue_length": self.get_queue_length(),
130
- }
131
-
132
- async def generate_stream(self, params):
133
- ori_prompt = prompt = params["prompt"]
134
- images = params.get("images", None)
135
- if images is not None and len(images) > 0:
136
- if len(images) > 0:
137
- if len(images) != prompt.count(DEFAULT_IMAGE_TOKEN):
138
- raise ValueError("Number of images does not match number of <image> tokens in prompt")
139
-
140
- images = [load_image_from_base64(image) for image in images]
141
-
142
- # FIXME: for image-start/end token
143
- # replace_token = DEFAULT_IMAGE_TOKEN
144
- # if getattr(self.model.config, 'mm_use_im_start_end', False):
145
- # replace_token = DEFAULT_IM_START_TOKEN + replace_token + DEFAULT_IM_END_TOKEN
146
- # prompt = prompt.replace(DEFAULT_IMAGE_TOKEN, replace_token)
147
- prompt = prompt.replace(' ' + DEFAULT_IMAGE_TOKEN + '\n', DEFAULT_IMAGE_TOKEN)
148
- prompt_split = prompt.split(DEFAULT_IMAGE_TOKEN)
149
- prompt = []
150
- for i in range(len(prompt_split)):
151
- prompt.append(prompt_split[i])
152
- if i < len(images):
153
- prompt.append(images[i])
154
- else:
155
- prompt = [prompt]
156
-
157
- temperature = float(params.get("temperature", 1.0))
158
- top_p = float(params.get("top_p", 1.0))
159
- # max_context_length = getattr(model.config, 'max_position_embeddings', 2048)
160
- max_new_tokens = min(int(params.get("max_new_tokens", 256)), 1024)
161
- stop_str = params.get("stop", None)
162
- stop_str = [stop_str] if stop_str is not None else None
163
-
164
- print({'prompt': prompt, 'max_new_tokens': max_new_tokens, 'temperature': temperature, 'top_p': top_p})
165
- state = pipeline.run(prompt, max_new_tokens, temperature=temperature, top_p=top_p, stream=True)
166
-
167
- generated_text = ori_prompt
168
- async for text_outputs in state.text_async_iter(var_name="response"):
169
- generated_text += text_outputs
170
- yield json.dumps({"text": generated_text, "error_code": 0}).encode() + b"\0"
171
-
172
- async def generate_stream_gate(self, params):
173
- try:
174
- async for x in self.generate_stream(params):
175
- yield x
176
- except ValueError as e:
177
- print("Caught ValueError:", e)
178
- ret = {
179
- "text": server_error_msg,
180
- "error_code": 1,
181
- }
182
- yield json.dumps(ret).encode() + b"\0"
183
- except Exception as e:
184
- print("Caught Unknown Error", e)
185
- ret = {
186
- "text": server_error_msg,
187
- "error_code": 1,
188
- }
189
- yield json.dumps(ret).encode() + b"\0"
190
-
191
-
192
- app = FastAPI()
193
-
194
-
195
- def release_model_semaphore(fn=None):
196
- model_semaphore.release()
197
- if fn is not None:
198
- fn()
199
-
200
-
201
- @app.post("/worker_generate_stream")
202
- async def generate_stream(request: Request):
203
- global model_semaphore, global_counter
204
- global_counter += 1
205
- params = await request.json()
206
-
207
- if model_semaphore is None:
208
- model_semaphore = asyncio.Semaphore(args.limit_model_concurrency)
209
- await model_semaphore.acquire()
210
- worker.send_heart_beat()
211
- generator = worker.generate_stream_gate(params)
212
- background_tasks = BackgroundTasks()
213
- background_tasks.add_task(partial(release_model_semaphore, fn=worker.send_heart_beat))
214
- return StreamingResponse(generator, background=background_tasks)
215
-
216
-
217
- @app.post("/worker_get_status")
218
- async def get_status(request: Request):
219
- return worker.get_status()
220
-
221
-
222
- if __name__ == "__main__":
223
- parser = argparse.ArgumentParser()
224
- parser.add_argument("--host", type=str, default="localhost")
225
- parser.add_argument("--port", type=int, default=21002)
226
- parser.add_argument("--worker-address", type=str,
227
- default="http://localhost:21002")
228
- parser.add_argument("--controller-address", type=str,
229
- default="http://localhost:21001")
230
- parser.add_argument("--model-name", type=str)
231
- parser.add_argument("--sgl-endpoint", type=str)
232
- parser.add_argument("--limit-model-concurrency", type=int, default=5)
233
- parser.add_argument("--stream-interval", type=int, default=1)
234
- parser.add_argument("--no-register", action="store_true")
235
- args = parser.parse_args()
236
- logger.info(f"args: {args}")
237
-
238
- worker = ModelWorker(args.controller_address,
239
- args.worker_address,
240
- args.sgl_endpoint,
241
- worker_id,
242
- args.no_register,
243
- args.model_name)
244
- uvicorn.run(app, host=args.host, port=args.port, log_level="info")
serve/test_message.py DELETED
@@ -1,62 +0,0 @@
1
- import argparse
2
- import json
3
-
4
- import requests
5
-
6
- from llava.conversation import default_conversation
7
-
8
-
9
- def main():
10
- if args.worker_address:
11
- worker_addr = args.worker_address
12
- else:
13
- controller_addr = args.controller_address
14
- ret = requests.post(controller_addr + "/refresh_all_workers")
15
- ret = requests.post(controller_addr + "/list_models")
16
- models = ret.json()["models"]
17
- models.sort()
18
- print(f"Models: {models}")
19
-
20
- ret = requests.post(controller_addr + "/get_worker_address",
21
- json={"model": args.model_name})
22
- worker_addr = ret.json()["address"]
23
- print(f"worker_addr: {worker_addr}")
24
-
25
- if worker_addr == "":
26
- return
27
-
28
- conv = default_conversation.copy()
29
- conv.append_message(conv.roles[0], args.message)
30
- prompt = conv.get_prompt()
31
-
32
- headers = {"User-Agent": "LLaVA Client"}
33
- pload = {
34
- "model": args.model_name,
35
- "prompt": prompt,
36
- "max_new_tokens": args.max_new_tokens,
37
- "temperature": 0.7,
38
- "stop": conv.sep,
39
- }
40
- response = requests.post(worker_addr + "/worker_generate_stream", headers=headers,
41
- json=pload, stream=True)
42
-
43
- print(prompt.replace(conv.sep, "\n"), end="")
44
- for chunk in response.iter_lines(chunk_size=8192, decode_unicode=False, delimiter=b"\0"):
45
- if chunk:
46
- data = json.loads(chunk.decode("utf-8"))
47
- output = data["text"].split(conv.sep)[-1]
48
- print(output, end="\r")
49
- print("")
50
-
51
-
52
- if __name__ == "__main__":
53
- parser = argparse.ArgumentParser()
54
- parser.add_argument("--controller-address", type=str, default="http://localhost:21001")
55
- parser.add_argument("--worker-address", type=str)
56
- parser.add_argument("--model-name", type=str, default="facebook/opt-350m")
57
- parser.add_argument("--max-new-tokens", type=int, default=32)
58
- parser.add_argument("--message", type=str, default=
59
- "Tell me a story with more than 1000 words.")
60
- args = parser.parse_args()
61
-
62
- main()