BenkHel committed on
Commit bcb804d · verified · Parent: d1f015b

Upload 12 files

cumo/serve/__init__.py ADDED
File without changes
cumo/serve/app.py ADDED
@@ -0,0 +1,104 @@
+import sys
+import os
+import argparse
+import time
+import subprocess
+
+import cumo.serve.gradio_web_server as gws
+
+# Execute the pip install command with additional options
+#subprocess.check_call([sys.executable, '-m', 'pip', 'install', 'flash-attn', '--no-build-isolation', '-U'])
+
+def start_controller():
+    print("Starting the controller")
+    controller_command = [
+        sys.executable,
+        "-m",
+        "cumo.serve.controller",
+        "--host",
+        "0.0.0.0",
+        "--port",
+        "10000",
+    ]
+    print(controller_command)
+    return subprocess.Popen(controller_command)
+
+def start_worker(model_path: str, bits=16):
+    print(f"Starting the model worker for the model {model_path}")
+    model_name = model_path.strip("/").split("/")[-1]
+    assert bits in [4, 8, 16], "It can only be loaded with 16-bit, 8-bit, or 4-bit."
+    if bits != 16:
+        model_name += f"-{bits}bit"
+    worker_command = [
+        sys.executable,
+        "-m",
+        "cumo.serve.model_worker",
+        "--host",
+        "0.0.0.0",
+        "--controller-address",
+        "http://localhost:10000",
+        "--model-path",
+        model_path,
+        "--model-name",
+        model_name,
+        "--use-flash-attn",
+    ]
+    if bits != 16:
+        worker_command += [f"--load-{bits}bit"]
+    print(worker_command)
+    return subprocess.Popen(worker_command)
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--host", type=str, default="0.0.0.0")
+    parser.add_argument("--port", type=int)
+    parser.add_argument("--model-path", type=str, default="facebook/opt-350m")
+    parser.add_argument("--model-base", type=str, default=None)
+    parser.add_argument("--controller-url", type=str, default="http://localhost:10000")
+    parser.add_argument("--concurrency-count", type=int, default=5)
+    parser.add_argument("--bits", type=int, default=16)
+    parser.add_argument("--model-list-mode", type=str, default="reload", choices=["once", "reload"])
+    parser.add_argument("--share", action="store_true")
+    parser.add_argument("--moderate", action="store_true")
+    parser.add_argument("--embed", action="store_true")
+    gws.args = parser.parse_args()
+    gws.models = []
+
+    gws.title_markdown += """
+ONLY WORKS WITH GPU! By default, we load the model with 4-bit quantization to make it fit on smaller hardware.
+"""
+
+    print(f"args: {gws.args}")
+    model_path = gws.args.model_path
+    print(model_path)
+    bits = gws.args.bits
+    print(bits)
+    concurrency_count = int(os.getenv("concurrency_count", 5))
+
+    controller_proc = start_controller()
+    worker_proc = start_worker(model_path, bits=bits)
+
+    # Wait for the worker and controller to start
+    time.sleep(10)
+
+    exit_status = 0
+    try:
+        demo = gws.build_demo(embed_mode=False, concurrency_count=concurrency_count)
+        demo.queue(
+            status_update_rate=10,
+            api_open=False
+        ).launch(
+            server_name="[::]",
+            server_port=gws.args.port,
+            share=gws.args.share
+        )
+    except Exception as e:
+        print(e)
+        exit_status = 1
+    finally:
+        worker_proc.kill()
+        controller_proc.kill()
+
+    sys.exit(exit_status)
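
Taken together, app.py is the single-process entry point for the Space: it forks the controller and one model worker as subprocesses, waits ten seconds for them to come up, and then serves the Gradio UI itself. A minimal launch sketch, assuming the jiachenl/CuMo-mistral-7b checkpoint linked from gradio_web_server.py (the model path and bit width here are illustrative):

python3 -m cumo.serve.app --model-path jiachenl/CuMo-mistral-7b --bits 4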
cumo/serve/cli.py ADDED
@@ -0,0 +1,112 @@
+import argparse
+import torch
+
+from cumo.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
+from cumo.conversation import conv_templates, SeparatorStyle
+from cumo.model.builder import load_pretrained_model
+from cumo.utils import disable_torch_init
+from cumo.mm_utils import process_images, tokenizer_image_token, get_model_name_from_path
+
+import requests
+from PIL import Image
+from io import BytesIO
+from transformers import TextStreamer
+
+
+def load_image(image_file):
+    if image_file.startswith('http://') or image_file.startswith('https://'):
+        response = requests.get(image_file)
+        image = Image.open(BytesIO(response.content)).convert('RGB')
+    else:
+        image = Image.open(image_file).convert('RGB')
+    return image
+
+
+def main(args):
+    # Model
+    disable_torch_init()
+
+    model_name = get_model_name_from_path(args.model_path)
+    tokenizer, model, image_processor, context_len = load_pretrained_model(args.model_path, args.model_base, model_name, args.load_8bit, args.load_4bit, device=args.device)
+    model.config.training = False
+
+    conv = conv_templates[args.conv_mode].copy()
+    if "mpt" in model_name.lower():
+        roles = ('user', 'assistant')
+    else:
+        roles = conv.roles
+
+    image = load_image(args.image_file)
+    image_size = image.size
+    # Similar operation in model_worker.py
+    image_tensor = process_images([image], image_processor, model.config)
+    if type(image_tensor) is list:
+        image_tensor = [image.to(model.device, dtype=torch.float16) for image in image_tensor]
+    else:
+        image_tensor = image_tensor.to(model.device, dtype=torch.float16)
+
+    while True:
+        try:
+            inp = input(f"{roles[0]}: ")
+        except EOFError:
+            inp = ""
+        if not inp:
+            print("exit...")
+            break
+
+        print(f"{roles[1]}: ", end="")
+
+        if image is not None:
+            # first message
+            if model.config.mm_use_im_start_end:
+                inp = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + inp
+            else:
+                inp = DEFAULT_IMAGE_TOKEN + '\n' + inp
+            conv.append_message(conv.roles[0], inp)
+            image = None
+        else:
+            # later messages
+            conv.append_message(conv.roles[0], inp)
+        conv.append_message(conv.roles[1], None)
+        prompt = conv.get_prompt()
+
+        input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(model.device)
+        stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
+        keywords = [stop_str]
+        streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
+
+        with torch.inference_mode():
+            output_ids = model.generate(
+                input_ids,
+                images=image_tensor,
+                image_sizes=[image_size],
+                do_sample=True if args.temperature > 0 else False,
+                temperature=args.temperature,
+                max_new_tokens=args.max_new_tokens,
+                streamer=streamer,
+                pad_token_id=tokenizer.eos_token_id,
+                use_cache=True)
+
+        outputs = tokenizer.decode(output_ids[0]).strip()
+        conv.messages[-1][-1] = outputs
+
+        if args.debug:
+            print("\n", {"prompt": prompt, "outputs": outputs}, "\n")
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--model-path", type=str, default="facebook/opt-350m")
+    parser.add_argument("--model-base", type=str, default=None)
+    parser.add_argument("--image-file", type=str, required=True)
+    parser.add_argument("--device", type=str, default="cuda")
+    parser.add_argument("--conv-mode", type=str, default="mistral_instruct_system")
+    parser.add_argument("--temperature", type=float, default=0.2)
+    parser.add_argument("--max-new-tokens", type=int, default=512)
+    parser.add_argument("--load-8bit", action="store_true")
+    parser.add_argument("--load-4bit", action="store_true")
+    parser.add_argument("--debug", action="store_true")
+    args = parser.parse_args()
+    main(args)
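
For a quick local test without the web stack, cli.py chats with a single image in the terminal. A usage sketch, assuming the same illustrative checkpoint and one of the example images uploaded in this commit:

python3 -m cumo.serve.cli --model-path jiachenl/CuMo-mistral-7b --image-file cumo/serve/examples/animal.webp --load-4bit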
cumo/serve/controller.py ADDED
@@ -0,0 +1,298 @@
+"""
+A controller manages distributed workers.
+It sends worker addresses to clients.
+"""
+import argparse
+import asyncio
+import dataclasses
+from enum import Enum, auto
+import json
+import logging
+import time
+from typing import List, Union
+import threading
+
+from fastapi import FastAPI, Request
+from fastapi.responses import StreamingResponse
+import numpy as np
+import requests
+import uvicorn
+
+from cumo.constants import CONTROLLER_HEART_BEAT_EXPIRATION
+from cumo.utils import build_logger, server_error_msg
+
+
+logger = build_logger("controller", "controller.log")
+
+
+class DispatchMethod(Enum):
+    LOTTERY = auto()
+    SHORTEST_QUEUE = auto()
+
+    @classmethod
+    def from_str(cls, name):
+        if name == "lottery":
+            return cls.LOTTERY
+        elif name == "shortest_queue":
+            return cls.SHORTEST_QUEUE
+        else:
+            raise ValueError("Invalid dispatch method")
+
+
+@dataclasses.dataclass
+class WorkerInfo:
+    model_names: List[str]
+    speed: int
+    queue_length: int
+    check_heart_beat: bool
+    last_heart_beat: float
+
+
+def heart_beat_controller(controller):
+    while True:
+        time.sleep(CONTROLLER_HEART_BEAT_EXPIRATION)
+        controller.remove_stale_workers_by_expiration()
+
+
+class Controller:
+    def __init__(self, dispatch_method: str):
+        # Dict[str -> WorkerInfo]
+        self.worker_info = {}
+        self.dispatch_method = DispatchMethod.from_str(dispatch_method)
+
+        self.heart_beat_thread = threading.Thread(
+            target=heart_beat_controller, args=(self,), daemon=True)
+        self.heart_beat_thread.start()
+
+        logger.info("Init controller")
+
+    def register_worker(self, worker_name: str, check_heart_beat: bool,
+                        worker_status: dict):
+        if worker_name not in self.worker_info:
+            logger.info(f"Register a new worker: {worker_name}")
+        else:
+            logger.info(f"Register an existing worker: {worker_name}")
+
+        if not worker_status:
+            worker_status = self.get_worker_status(worker_name)
+        if not worker_status:
+            return False
+
+        self.worker_info[worker_name] = WorkerInfo(
+            worker_status["model_names"], worker_status["speed"], worker_status["queue_length"],
+            check_heart_beat, time.time())
+
+        logger.info(f"Register done: {worker_name}, {worker_status}")
+        return True
+
+    def get_worker_status(self, worker_name: str):
+        try:
+            r = requests.post(worker_name + "/worker_get_status", timeout=5)
+        except requests.exceptions.RequestException as e:
+            logger.error(f"Get status fails: {worker_name}, {e}")
+            return None
+
+        if r.status_code != 200:
+            logger.error(f"Get status fails: {worker_name}, {r}")
+            return None
+
+        return r.json()
+
+    def remove_worker(self, worker_name: str):
+        del self.worker_info[worker_name]
+
+    def refresh_all_workers(self):
+        old_info = dict(self.worker_info)
+        self.worker_info = {}
+
+        for w_name, w_info in old_info.items():
+            if not self.register_worker(w_name, w_info.check_heart_beat, None):
+                logger.info(f"Remove stale worker: {w_name}")
+
+    def list_models(self):
+        model_names = set()
+
+        for w_name, w_info in self.worker_info.items():
+            model_names.update(w_info.model_names)
+
+        return list(model_names)
+
+    def get_worker_address(self, model_name: str):
+        if self.dispatch_method == DispatchMethod.LOTTERY:
+            worker_names = []
+            worker_speeds = []
+            for w_name, w_info in self.worker_info.items():
+                if model_name in w_info.model_names:
+                    worker_names.append(w_name)
+                    worker_speeds.append(w_info.speed)
+            worker_speeds = np.array(worker_speeds, dtype=np.float32)
+            norm = np.sum(worker_speeds)
+            if norm < 1e-4:
+                return ""
+            worker_speeds = worker_speeds / norm
+            if True:  # Directly return address
+                pt = np.random.choice(np.arange(len(worker_names)),
+                                      p=worker_speeds)
+                worker_name = worker_names[pt]
+                return worker_name
+
+            # Check status before returning
+            while True:
+                pt = np.random.choice(np.arange(len(worker_names)),
+                                      p=worker_speeds)
+                worker_name = worker_names[pt]
+
+                if self.get_worker_status(worker_name):
+                    break
+                else:
+                    self.remove_worker(worker_name)
+                    worker_speeds[pt] = 0
+                    norm = np.sum(worker_speeds)
+                    if norm < 1e-4:
+                        return ""
+                    worker_speeds = worker_speeds / norm
+                    continue
+            return worker_name
+        elif self.dispatch_method == DispatchMethod.SHORTEST_QUEUE:
+            worker_names = []
+            worker_qlen = []
+            for w_name, w_info in self.worker_info.items():
+                if model_name in w_info.model_names:
+                    worker_names.append(w_name)
+                    worker_qlen.append(w_info.queue_length / w_info.speed)
+            if len(worker_names) == 0:
+                return ""
+            min_index = np.argmin(worker_qlen)
+            w_name = worker_names[min_index]
+            self.worker_info[w_name].queue_length += 1
+            logger.info(f"names: {worker_names}, queue_lens: {worker_qlen}, ret: {w_name}")
+            return w_name
+        else:
+            raise ValueError(f"Invalid dispatch method: {self.dispatch_method}")
+
+    def receive_heart_beat(self, worker_name: str, queue_length: int):
+        if worker_name not in self.worker_info:
+            logger.info(f"Receive unknown heart beat. {worker_name}")
+            return False
+
+        self.worker_info[worker_name].queue_length = queue_length
+        self.worker_info[worker_name].last_heart_beat = time.time()
+        logger.info(f"Receive heart beat. {worker_name}")
+        return True
+
+    def remove_stale_workers_by_expiration(self):
+        expire = time.time() - CONTROLLER_HEART_BEAT_EXPIRATION
+        to_delete = []
+        for worker_name, w_info in self.worker_info.items():
+            if w_info.check_heart_beat and w_info.last_heart_beat < expire:
+                to_delete.append(worker_name)
+
+        for worker_name in to_delete:
+            self.remove_worker(worker_name)
+
+    def worker_api_generate_stream(self, params):
+        worker_addr = self.get_worker_address(params["model"])
+        if not worker_addr:
+            logger.info(f"no worker: {params['model']}")
+            ret = {
+                "text": server_error_msg,
+                "error_code": 2,
+            }
+            yield json.dumps(ret).encode() + b"\0"
+            return
+
+        try:
+            response = requests.post(worker_addr + "/worker_generate_stream",
+                                     json=params, stream=True, timeout=5)
+            for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
+                if chunk:
+                    yield chunk + b"\0"
+        except requests.exceptions.RequestException as e:
+            logger.info(f"worker timeout: {worker_addr}")
+            ret = {
+                "text": server_error_msg,
+                "error_code": 3,
+            }
+            yield json.dumps(ret).encode() + b"\0"
+
+    # Let the controller act as a worker to achieve hierarchical
+    # management. This can be used to connect isolated sub networks.
+    def worker_api_get_status(self):
+        model_names = set()
+        speed = 0
+        queue_length = 0
+
+        for w_name in self.worker_info:
+            worker_status = self.get_worker_status(w_name)
+            if worker_status is not None:
+                model_names.update(worker_status["model_names"])
+                speed += worker_status["speed"]
+                queue_length += worker_status["queue_length"]
+
+        return {
+            "model_names": list(model_names),
+            "speed": speed,
+            "queue_length": queue_length,
+        }
+
+
+app = FastAPI()
+
+
+@app.post("/register_worker")
+async def register_worker(request: Request):
+    data = await request.json()
+    controller.register_worker(
+        data["worker_name"], data["check_heart_beat"],
+        data.get("worker_status", None))
+
+
+@app.post("/refresh_all_workers")
+async def refresh_all_workers():
+    models = controller.refresh_all_workers()
+
+
+@app.post("/list_models")
+async def list_models():
+    models = controller.list_models()
+    return {"models": models}
+
+
+@app.post("/get_worker_address")
+async def get_worker_address(request: Request):
+    data = await request.json()
+    addr = controller.get_worker_address(data["model"])
+    return {"address": addr}
+
+
+@app.post("/receive_heart_beat")
+async def receive_heart_beat(request: Request):
+    data = await request.json()
+    exist = controller.receive_heart_beat(
+        data["worker_name"], data["queue_length"])
+    return {"exist": exist}
+
+
+@app.post("/worker_generate_stream")
+async def worker_api_generate_stream(request: Request):
+    params = await request.json()
+    generator = controller.worker_api_generate_stream(params)
+    return StreamingResponse(generator)
+
+
+@app.post("/worker_get_status")
+async def worker_api_get_status(request: Request):
+    return controller.worker_api_get_status()
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--host", type=str, default="localhost")
+    parser.add_argument("--port", type=int, default=21001)
+    parser.add_argument("--dispatch-method", type=str, choices=[
+        "lottery", "shortest_queue"], default="shortest_queue")
+    args = parser.parse_args()
+    logger.info(f"args: {args}")
+
+    controller = Controller(args.dispatch_method)
+    uvicorn.run(app, host=args.host, port=args.port, log_level="info")
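
The controller is plain HTTP, so it can be exercised directly. A minimal sketch of the two calls the web server makes, assuming a controller on the default port 21001 (the variable names are illustrative):

import requests

controller = "http://localhost:21001"
models = requests.post(controller + "/list_models").json()["models"]
addr = requests.post(controller + "/get_worker_address", json={"model": models[0]}).json()["address"]
print(models, addr)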
cumo/serve/examples/animal.webp ADDED

Git LFS Details

  • SHA256: dc4ad630f92b8cf7f18bbd639e68b1f8d8004347d67397ba91a9f970509e04d8
  • Pointer size: 131 Bytes
  • Size of remote file: 325 kB
cumo/serve/examples/aveger.jpg ADDED

Git LFS Details

  • SHA256: b4d8cc906cef81beaa5e7cd3372ecbc2de35969a601f0c905d1d7d7426910f3d
  • Pointer size: 131 Bytes
  • Size of remote file: 599 kB
cumo/serve/examples/disney.jpeg ADDED

Git LFS Details

  • SHA256: ee40a1e732e009ebb3c4e7b103c5a9bec6d4b06b27813fa651955a9a3dc22449
  • Pointer size: 131 Bytes
  • Size of remote file: 106 kB
cumo/serve/examples/fridge.webp ADDED
cumo/serve/gradio_web_server.py ADDED
@@ -0,0 +1,451 @@
+import argparse
+import datetime
+import json
+import os
+import time
+
+import gradio as gr
+import requests
+
+from cumo.conversation import (default_conversation, conv_templates,
+                               SeparatorStyle)
+from cumo.constants import LOGDIR
+from cumo.utils import (build_logger, server_error_msg,
+                        violates_moderation, moderation_msg)
+import hashlib
+
+
+logger = build_logger("gradio_web_server", "gradio_web_server.log")
+
+headers = {"User-Agent": "CuMo"}
+
+no_change_btn = gr.Button()
+enable_btn = gr.Button(interactive=True)
+disable_btn = gr.Button(interactive=False)
+
+priority = {
+    "mistral-7b": "aaaaaaa",
+    "mistral-8x7b": "aaaaaab",
+}
+
+
+def get_conv_log_filename():
+    t = datetime.datetime.now()
+    name = os.path.join(LOGDIR, f"{t.year}-{t.month:02d}-{t.day:02d}-conv.json")
+    return name
+
+
+def get_model_list():
+    ret = requests.post(args.controller_url + "/refresh_all_workers")
+    assert ret.status_code == 200
+    ret = requests.post(args.controller_url + "/list_models")
+    models = ret.json()["models"]
+    models.sort(key=lambda x: priority.get(x, x))
+    logger.info(f"Models: {models}")
+    return models
+
+
+get_window_url_params = """
+function() {
+    const params = new URLSearchParams(window.location.search);
+    url_params = Object.fromEntries(params);
+    console.log(url_params);
+    return url_params;
+}
+"""
+
+
+def load_demo(url_params, request: gr.Request):
+    logger.info(f"load_demo. ip: {request.client.host}. params: {url_params}")
+
+    dropdown_update = gr.Dropdown(visible=True)
+    if "model" in url_params:
+        model = url_params["model"]
+        if model in models:
+            dropdown_update = gr.Dropdown(value=model, visible=True)
+
+    state = default_conversation.copy()
+    return state, dropdown_update
+
+
+def load_demo_refresh_model_list(request: gr.Request):
+    logger.info(f"load_demo. ip: {request.client.host}")
+    models = get_model_list()
+    state = default_conversation.copy()
+    dropdown_update = gr.Dropdown(
+        choices=models,
+        value=models[0] if len(models) > 0 else ""
+    )
+    return state, dropdown_update
+
+
+def vote_last_response(state, vote_type, model_selector, request: gr.Request):
+    with open(get_conv_log_filename(), "a") as fout:
+        data = {
+            "tstamp": round(time.time(), 4),
+            "type": vote_type,
+            "model": model_selector,
+            "state": state.dict(),
+            "ip": request.client.host,
+        }
+        fout.write(json.dumps(data) + "\n")
+
+
+def upvote_last_response(state, model_selector, request: gr.Request):
+    logger.info(f"upvote. ip: {request.client.host}")
+    vote_last_response(state, "upvote", model_selector, request)
+    return ("",) + (disable_btn,) * 3
+
+
+def downvote_last_response(state, model_selector, request: gr.Request):
+    logger.info(f"downvote. ip: {request.client.host}")
+    vote_last_response(state, "downvote", model_selector, request)
+    return ("",) + (disable_btn,) * 3
+
+
+def flag_last_response(state, model_selector, request: gr.Request):
+    logger.info(f"flag. ip: {request.client.host}")
+    vote_last_response(state, "flag", model_selector, request)
+    return ("",) + (disable_btn,) * 3
+
+
+def regenerate(state, image_process_mode, request: gr.Request):
+    logger.info(f"regenerate. ip: {request.client.host}")
+    state.messages[-1][-1] = None
+    prev_human_msg = state.messages[-2]
+    if type(prev_human_msg[1]) in (tuple, list):
+        prev_human_msg[1] = (*prev_human_msg[1][:2], image_process_mode)
+    state.skip_next = False
+    return (state, state.to_gradio_chatbot(), "", None) + (disable_btn,) * 5
+
+
+def clear_history(request: gr.Request):
+    logger.info(f"clear_history. ip: {request.client.host}")
+    state = default_conversation.copy()
+    return (state, state.to_gradio_chatbot(), "", None) + (disable_btn,) * 5
+
+
+def add_text(state, text, image, image_process_mode, request: gr.Request):
+    logger.info(f"add_text. ip: {request.client.host}. len: {len(text)}")
+    if len(text) <= 0 and image is None:
+        state.skip_next = True
+        return (state, state.to_gradio_chatbot(), "", None) + (no_change_btn,) * 5
+    if args.moderate:
+        flagged = violates_moderation(text)
+        if flagged:
+            state.skip_next = True
+            return (state, state.to_gradio_chatbot(), moderation_msg, None) + (
+                no_change_btn,) * 5
+
+    text = text[:1536]  # Hard cut-off
+    if image is not None:
+        text = text[:1200]  # Hard cut-off for images
+        if '<image>' not in text:
+            # text = '<Image><image></Image>' + text
+            text = text + '\n<image>'
+        text = (text, image, image_process_mode)
+        state = default_conversation.copy()
+    state.append_message(state.roles[0], text)
+    state.append_message(state.roles[1], None)
+    state.skip_next = False
+    return (state, state.to_gradio_chatbot(), "", None) + (disable_btn,) * 5
+
+
+def http_bot(state, model_selector, temperature, top_p, max_new_tokens, request: gr.Request):
+    logger.info(f"http_bot. ip: {request.client.host}")
+    start_tstamp = time.time()
+    model_name = model_selector
+
+    if state.skip_next:
+        # This generate call is skipped due to invalid inputs
+        yield (state, state.to_gradio_chatbot()) + (no_change_btn,) * 5
+        return
+
+    if len(state.messages) == state.offset + 2:
+        # First round of conversation
+        template_name = "mistral_instruct_system"
+        new_state = conv_templates[template_name].copy()
+        new_state.append_message(new_state.roles[0], state.messages[-2][1])
+        new_state.append_message(new_state.roles[1], None)
+        state = new_state
+
+    # Query worker address
+    controller_url = args.controller_url
+    ret = requests.post(controller_url + "/get_worker_address",
+                        json={"model": model_name})
+    worker_addr = ret.json()["address"]
+    logger.info(f"model_name: {model_name}, worker_addr: {worker_addr}")
+
+    # No available worker
+    if worker_addr == "":
+        state.messages[-1][-1] = server_error_msg
+        yield (state, state.to_gradio_chatbot(), disable_btn, disable_btn, disable_btn, enable_btn, enable_btn)
+        return
+
+    # Construct prompt
+    prompt = state.get_prompt()
+
+    all_images = state.get_images(return_pil=True)
+    all_image_hash = [hashlib.md5(image.tobytes()).hexdigest() for image in all_images]
+    for image, hash in zip(all_images, all_image_hash):
+        t = datetime.datetime.now()
+        filename = os.path.join(LOGDIR, "serve_images", f"{t.year}-{t.month:02d}-{t.day:02d}", f"{hash}.jpg")
+        if not os.path.isfile(filename):
+            os.makedirs(os.path.dirname(filename), exist_ok=True)
+            image.save(filename)
+
+    # Make requests
+    pload = {
+        "model": model_name,
+        "prompt": prompt,
+        "temperature": float(temperature),
+        "top_p": float(top_p),
+        "max_new_tokens": min(int(max_new_tokens), 1536),
+        "stop": state.sep if state.sep_style in [SeparatorStyle.SINGLE, SeparatorStyle.MPT] else state.sep2,
+        "images": f'List of {len(state.get_images())} images: {all_image_hash}',
+    }
+    logger.info(f"==== request ====\n{pload}")
+
+    pload['images'] = state.get_images()
+
+    state.messages[-1][-1] = "▌"
+    yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 5
+
+    try:
+        # Stream output
+        response = requests.post(worker_addr + "/worker_generate_stream",
+                                 headers=headers, json=pload, stream=True, timeout=10)
+        for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
+            if chunk:
+                data = json.loads(chunk.decode())
+                if data["error_code"] == 0:
+                    output = data["text"][len(prompt):].strip()
+                    state.messages[-1][-1] = output + "▌"
+                    yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 5
+                else:
+                    output = data["text"] + f" (error_code: {data['error_code']})"
+                    state.messages[-1][-1] = output
+                    yield (state, state.to_gradio_chatbot()) + (disable_btn, disable_btn, disable_btn, enable_btn, enable_btn)
+                    return
+                time.sleep(0.03)
+    except requests.exceptions.RequestException as e:
+        state.messages[-1][-1] = server_error_msg
+        yield (state, state.to_gradio_chatbot()) + (disable_btn, disable_btn, disable_btn, enable_btn, enable_btn)
+        return
+
+    state.messages[-1][-1] = state.messages[-1][-1][:-1]
+    yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 5
+
+    finish_tstamp = time.time()
+    logger.info(f"{output}")
+
+    with open(get_conv_log_filename(), "a") as fout:
+        data = {
+            "tstamp": round(finish_tstamp, 4),
+            "type": "chat",
+            "model": model_name,
+            "start": round(start_tstamp, 4),
+            "finish": round(finish_tstamp, 4),
+            "state": state.dict(),
+            "images": all_image_hash,
+            "ip": request.client.host,
+        }
+        fout.write(json.dumps(data) + "\n")
+
+
+title_markdown = ("""
+# CuMo: Scaling Multimodal LLM with Co-Upcycled Mixture-of-Experts
+[[Project Page](https://chrisjuniorli.github.io/project/CuMo/)] [[Code](https://github.com/SHI-Labs/CuMo)] [[Model](https://huggingface.co/jiachenl/CuMo-mistral-7b)] | 📚 [[Arxiv]()]
+""")
+
+tos_markdown = ("""
+### Terms of use
+By using this service, users are required to agree to the following terms:
+The service is a research preview intended for non-commercial use only. It provides only limited safety measures and may generate offensive content. It must not be used for any illegal, harmful, violent, racist, or sexual purposes. The service may collect user dialogue data for future research.
+Please click the "Flag" button if you get any inappropriate answer! We will collect those to keep improving our moderator.
+For an optimal experience, please use a desktop computer for this demo, as mobile devices may compromise its quality.
+""")
+
+
+learn_more_markdown = ("""
+### License
+The service is a research preview intended for non-commercial use only. Please contact us if you find any potential violations.
+""")
+
+block_css = """
+
+#buttons button {
+    min-width: min(120px,100%);
+}
+
+"""
+
+def build_demo(embed_mode, cur_dir=None, concurrency_count=10):
+    textbox = gr.Textbox(show_label=False, placeholder="Enter text and press ENTER", container=False)
+    with gr.Blocks(title="CuMo", theme=gr.themes.Default(), css=block_css) as demo:
+        state = gr.State()
+
+        if not embed_mode:
+            gr.Markdown(title_markdown)
+
+        with gr.Row():
+            with gr.Column(scale=3):
+                with gr.Row(elem_id="model_selector_row"):
+                    model_selector = gr.Dropdown(
+                        choices=models,
+                        value=models[0] if len(models) > 0 else "",
+                        interactive=True,
+                        show_label=False,
+                        container=False)
+
+                imagebox = gr.Image(type="pil")
+                image_process_mode = gr.Radio(
+                    ["Crop", "Resize", "Pad", "Default"],
+                    value="Default",
+                    label="Preprocess for non-square image", visible=False)
+
+                if cur_dir is None:
+                    cur_dir = os.path.dirname(os.path.abspath(__file__))
+                gr.Examples(examples=[
+                    [f"{cur_dir}/examples/aveger.jpg", "Can you introduce this movie based on the poster?"],
+                    [f"{cur_dir}/examples/fridge.webp", "Can you describe what groceries are presented in this fridge?"],
+                    [f"{cur_dir}/examples/su7_4.jpg", "What car is it in this image?"],
+                    [f"{cur_dir}/examples/nvidia.jpeg", "Can you tell me what happened in this image?"],
+                    [f"{cur_dir}/examples/animal.webp", "What animals are in this image?"],
+                    [f"{cur_dir}/examples/disney.jpeg", "How many characters are in this image?"],
+                    [f"{cur_dir}/examples/reka_6.jpeg", "What colour is my hat (I'm sitting on the bear)?"],
+                ], inputs=[imagebox, textbox])
+
+                with gr.Accordion("Parameters", open=False) as parameter_row:
+                    temperature = gr.Slider(minimum=0.0, maximum=1.0, value=0.2, step=0.1, interactive=True, label="Temperature",)
+                    top_p = gr.Slider(minimum=0.0, maximum=1.0, value=0.7, step=0.1, interactive=True, label="Top P",)
+                    max_output_tokens = gr.Slider(minimum=0, maximum=1024, value=512, step=64, interactive=True, label="Max output tokens",)
+
+            with gr.Column(scale=8):
+                chatbot = gr.Chatbot(
+                    elem_id="chatbot",
+                    label="CuMo Chatbot",
+                    height=650,
+                    layout="panel",
+                )
+                with gr.Row():
+                    with gr.Column(scale=8):
+                        textbox.render()
+                    with gr.Column(scale=1, min_width=50):
+                        submit_btn = gr.Button(value="Send", variant="primary")
+                with gr.Row(elem_id="buttons") as button_row:
+                    upvote_btn = gr.Button(value="👍 Upvote", interactive=False)
+                    downvote_btn = gr.Button(value="👎 Downvote", interactive=False)
+                    flag_btn = gr.Button(value="⚠️ Flag", interactive=False)
+                    #stop_btn = gr.Button(value="⏹️ Stop Generation", interactive=False)
+                    regenerate_btn = gr.Button(value="🔄 Regenerate", interactive=False)
+                    clear_btn = gr.Button(value="🗑️ Clear", interactive=False)
+
+        if not embed_mode:
+            gr.Markdown(tos_markdown)
+            gr.Markdown(learn_more_markdown)
+        url_params = gr.JSON(visible=False)
+
+        # Register listeners
+        btn_list = [upvote_btn, downvote_btn, flag_btn, regenerate_btn, clear_btn]
+        upvote_btn.click(
+            upvote_last_response,
+            [state, model_selector],
+            [textbox, upvote_btn, downvote_btn, flag_btn]
+        )
+        downvote_btn.click(
+            downvote_last_response,
+            [state, model_selector],
+            [textbox, upvote_btn, downvote_btn, flag_btn]
+        )
+        flag_btn.click(
+            flag_last_response,
+            [state, model_selector],
+            [textbox, upvote_btn, downvote_btn, flag_btn]
+        )
+
+        regenerate_btn.click(
+            regenerate,
+            [state, image_process_mode],
+            [state, chatbot, textbox, imagebox] + btn_list
+        ).then(
+            http_bot,
+            [state, model_selector, temperature, top_p, max_output_tokens],
+            [state, chatbot] + btn_list,
+            concurrency_limit=concurrency_count
+        )
+
+        clear_btn.click(
+            clear_history,
+            None,
+            [state, chatbot, textbox, imagebox] + btn_list,
+            queue=False
+        )
+
+        textbox.submit(
+            add_text,
+            [state, textbox, imagebox, image_process_mode],
+            [state, chatbot, textbox, imagebox] + btn_list,
+            queue=False
+        ).then(
+            http_bot,
+            [state, model_selector, temperature, top_p, max_output_tokens],
+            [state, chatbot] + btn_list,
+            concurrency_limit=concurrency_count
+        )
+
+        submit_btn.click(
+            add_text,
+            [state, textbox, imagebox, image_process_mode],
+            [state, chatbot, textbox, imagebox] + btn_list
+        ).then(
+            http_bot,
+            [state, model_selector, temperature, top_p, max_output_tokens],
+            [state, chatbot] + btn_list,
+            concurrency_limit=concurrency_count
+        )
+
+        if args.model_list_mode == "once":
+            demo.load(
+                load_demo,
+                [url_params],
+                [state, model_selector],
+                js=get_window_url_params
+            )
+        elif args.model_list_mode == "reload":
+            demo.load(
+                load_demo_refresh_model_list,
+                None,
+                [state, model_selector],
+                queue=False
+            )
+        else:
+            raise ValueError(f"Unknown model list mode: {args.model_list_mode}")
+
+    return demo
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--host", type=str, default="0.0.0.0")
+    parser.add_argument("--port", type=int)
+    parser.add_argument("--controller-url", type=str, default="http://localhost:21001")
+    parser.add_argument("--concurrency-count", type=int, default=16)
+    parser.add_argument("--model-list-mode", type=str, default="once",
+                        choices=["once", "reload"])
+    parser.add_argument("--share", action="store_true")
+    parser.add_argument("--moderate", action="store_true")
+    parser.add_argument("--embed", action="store_true")
+    args = parser.parse_args()
+    logger.info(f"args: {args}")
+
+    models = get_model_list()
+
+    logger.info(args)
+    demo = build_demo(args.embed, concurrency_count=args.concurrency_count)
+    demo.queue(
+        api_open=False
+    ).launch(
+        server_name=args.host,
+        server_port=args.port,
+        share=args.share
+    )
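
Run standalone (outside app.py), the three pieces are started in order: controller, worker, then web server. A launch sketch using the defaults defined in these files (the model path is illustrative):

python3 -m cumo.serve.controller --host 0.0.0.0 --port 21001
python3 -m cumo.serve.model_worker --controller-address http://localhost:21001 --worker-address http://localhost:21002 --model-path jiachenl/CuMo-mistral-7b
python3 -m cumo.serve.gradio_web_server --controller-url http://localhost:21001 --model-list-mode reload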
cumo/serve/model_worker.py ADDED
@@ -0,0 +1,293 @@
+"""
+A model worker executes the model.
+"""
+import argparse
+import asyncio
+import json
+import time
+import threading
+import uuid
+
+from fastapi import FastAPI, Request, BackgroundTasks
+from fastapi.responses import StreamingResponse
+import requests
+import torch
+import uvicorn
+from functools import partial
+import spaces  # required for the @spaces.GPU decorator below
+
+from cumo.constants import WORKER_HEART_BEAT_INTERVAL
+from cumo.utils import (build_logger, server_error_msg,
+                        pretty_print_semaphore)
+from cumo.model.builder import load_pretrained_model
+from cumo.mm_utils import process_images, load_image_from_base64, tokenizer_image_token
+from cumo.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
+from transformers import TextIteratorStreamer
+from threading import Thread
+
+
+GB = 1 << 30
+
+worker_id = str(uuid.uuid4())[:6]
+logger = build_logger("model_worker", f"model_worker_{worker_id}.log")
+global_counter = 0
+
+model_semaphore = None
+
+
+def heart_beat_worker(controller):
+    while True:
+        time.sleep(WORKER_HEART_BEAT_INTERVAL)
+        controller.send_heart_beat()
+
+
+class ModelWorker:
+    def __init__(self, controller_addr, worker_addr,
+                 worker_id, no_register,
+                 model_path, model_base, model_name,
+                 load_8bit, load_4bit, device, use_flash_attn=False):
+        self.controller_addr = controller_addr
+        self.worker_addr = worker_addr
+        self.worker_id = worker_id
+        if model_path.endswith("/"):
+            model_path = model_path[:-1]
+        if model_name is None:
+            model_paths = model_path.split("/")
+            if model_paths[-1].startswith('checkpoint-'):
+                self.model_name = model_paths[-2] + "_" + model_paths[-1]
+            else:
+                self.model_name = model_paths[-1]
+        else:
+            self.model_name = model_name
+
+        self.device = device
+        logger.info(f"Loading the model {self.model_name} on worker {worker_id} ...")
+        self.tokenizer, self.model, self.image_processor, self.context_len = load_pretrained_model(
+            model_path, model_base, self.model_name, load_8bit, load_4bit, device=self.device, use_flash_attn=use_flash_attn)
+
+        self.model.config.training = False
+        self.is_multimodal = 'llava' in self.model_name.lower() or 'cumo' in self.model_name.lower()
+
+        if not no_register:
+            self.register_to_controller()
+            self.heart_beat_thread = threading.Thread(
+                target=heart_beat_worker, args=(self,), daemon=True)
+            self.heart_beat_thread.start()
+
+    def register_to_controller(self):
+        logger.info("Register to controller")
+
+        url = self.controller_addr + "/register_worker"
+        data = {
+            "worker_name": self.worker_addr,
+            "check_heart_beat": True,
+            "worker_status": self.get_status()
+        }
+        r = requests.post(url, json=data)
+        assert r.status_code == 200
+
+    def send_heart_beat(self):
+        logger.info(f"Send heart beat. Models: {[self.model_name]}. "
+                    f"Semaphore: {pretty_print_semaphore(model_semaphore)}. "
+                    f"global_counter: {global_counter}")
+
+        url = self.controller_addr + "/receive_heart_beat"
+
+        while True:
+            try:
+                ret = requests.post(url, json={
+                    "worker_name": self.worker_addr,
+                    "queue_length": self.get_queue_length()}, timeout=5)
+                exist = ret.json()["exist"]
+                break
+            except requests.exceptions.RequestException as e:
+                logger.error(f"heart beat error: {e}")
+            time.sleep(5)
+
+        if not exist:
+            self.register_to_controller()
+
+    def get_queue_length(self):
+        if model_semaphore is None:
+            return 0
+        else:
+            return args.limit_model_concurrency - model_semaphore._value + (len(
+                model_semaphore._waiters) if model_semaphore._waiters is not None else 0)
+
+    def get_status(self):
+        return {
+            "model_names": [self.model_name],
+            "speed": 1,
+            "queue_length": self.get_queue_length(),
+        }
+
+    @spaces.GPU
+    @torch.inference_mode()
+    def generate_stream(self, params):
+        tokenizer, model, image_processor = self.tokenizer, self.model, self.image_processor
+
+        prompt = params["prompt"]
+        ori_prompt = prompt
+        images = params.get("images", None)
+        num_image_tokens = 0
+        if images is not None and len(images) > 0 and self.is_multimodal:
+            if len(images) > 0:
+                if len(images) != prompt.count(DEFAULT_IMAGE_TOKEN):
+                    raise ValueError("Number of images does not match number of <image> tokens in prompt")
+
+                images = [load_image_from_base64(image) for image in images]
+                image_sizes = [image.size for image in images]
+                images = process_images(images, image_processor, model.config)
+
+                if type(images) is list:
+                    images = [image.to(self.model.device, dtype=torch.float16) for image in images]
+                else:
+                    images = images.to(self.model.device, dtype=torch.float16)
+
+                replace_token = DEFAULT_IMAGE_TOKEN
+                if getattr(self.model.config, 'mm_use_im_start_end', False):
+                    replace_token = DEFAULT_IM_START_TOKEN + replace_token + DEFAULT_IM_END_TOKEN
+                prompt = prompt.replace(DEFAULT_IMAGE_TOKEN, replace_token)
+
+                num_image_tokens = prompt.count(replace_token) * model.get_vision_tower().num_patches
+            else:
+                images = None
+                image_sizes = None
+            image_args = {"images": images, "image_sizes": image_sizes}
+        else:
+            images = None
+            image_args = {}
+
+        temperature = float(params.get("temperature", 1.0))
+        top_p = float(params.get("top_p", 1.0))
+        max_context_length = getattr(model.config, 'max_position_embeddings', 2048)
+        max_new_tokens = min(int(params.get("max_new_tokens", 256)), 1024)
+        stop_str = params.get("stop", None)
+        do_sample = True if temperature > 0.001 else False
+
+        input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(self.device)
+        keywords = [stop_str]
+        # stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
+        streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True, timeout=15)
+
+        max_new_tokens = min(max_new_tokens, max_context_length - input_ids.shape[-1] - num_image_tokens)
+
+        if max_new_tokens < 1:
+            yield json.dumps({"text": ori_prompt + "Exceeds max token length. Please start a new conversation, thanks.", "error_code": 0}).encode() + b"\0"
+            return
+
+        thread = Thread(target=model.generate, kwargs=dict(
+            inputs=input_ids,
+            do_sample=do_sample,
+            temperature=temperature,
+            top_p=top_p,
+            max_new_tokens=max_new_tokens,
+            streamer=streamer,
+            use_cache=True,
+            pad_token_id=tokenizer.eos_token_id,
+            **image_args
+        ))
+        thread.start()
+
+        generated_text = ori_prompt
+        for new_text in streamer:
+            generated_text += new_text
+            if generated_text.endswith(stop_str):
+                generated_text = generated_text[:-len(stop_str)]
+            yield json.dumps({"text": generated_text, "error_code": 0}).encode() + b"\0"
+
+    def generate_stream_gate(self, params):
+        try:
+            for x in self.generate_stream(params):
+                yield x
+        except ValueError as e:
+            print("Caught ValueError:", e)
+            ret = {
+                "text": server_error_msg,
+                "error_code": 1,
+            }
+            yield json.dumps(ret).encode() + b"\0"
+        except torch.cuda.CudaError as e:
+            print("Caught torch.cuda.CudaError:", e)
+            ret = {
+                "text": server_error_msg,
+                "error_code": 1,
+            }
+            yield json.dumps(ret).encode() + b"\0"
+        except Exception as e:
+            print("Caught Unknown Error", e)
+            ret = {
+                "text": server_error_msg,
+                "error_code": 1,
+            }
+            yield json.dumps(ret).encode() + b"\0"
+
+
+app = FastAPI()
+
+
+def release_model_semaphore(fn=None):
+    model_semaphore.release()
+    if fn is not None:
+        fn()
+
+
+@app.post("/worker_generate_stream")
+async def generate_stream(request: Request):
+    global model_semaphore, global_counter
+    global_counter += 1
+    params = await request.json()
+
+    if model_semaphore is None:
+        model_semaphore = asyncio.Semaphore(args.limit_model_concurrency)
+    await model_semaphore.acquire()
+    worker.send_heart_beat()
+    generator = worker.generate_stream_gate(params)
+    background_tasks = BackgroundTasks()
+    background_tasks.add_task(partial(release_model_semaphore, fn=worker.send_heart_beat))
+    return StreamingResponse(generator, background=background_tasks)
+
+
+@app.post("/worker_get_status")
+async def get_status(request: Request):
+    return worker.get_status()
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--host", type=str, default="localhost")
+    parser.add_argument("--port", type=int, default=21002)
+    parser.add_argument("--worker-address", type=str,
+                        default="http://localhost:21002")
+    parser.add_argument("--controller-address", type=str,
+                        default="http://localhost:21001")
+    parser.add_argument("--model-path", type=str, default="facebook/opt-350m")
+    parser.add_argument("--model-base", type=str, default=None)
+    parser.add_argument("--model-name", type=str)
+    parser.add_argument("--device", type=str, default="cuda")
+    parser.add_argument("--multi-modal", action="store_true", help="Multimodal mode is automatically detected with model name, please make sure `cumo` is included in the model path.")
+    parser.add_argument("--limit-model-concurrency", type=int, default=5)
+    parser.add_argument("--stream-interval", type=int, default=1)
+    parser.add_argument("--no-register", action="store_true")
+    parser.add_argument("--load-8bit", action="store_true")
+    parser.add_argument("--load-4bit", action="store_true")
+    parser.add_argument("--use-flash-attn", action="store_true")
+    args = parser.parse_args()
+    logger.info(f"args: {args}")
+
+    if args.multi_modal:
+        logger.warning("Multimodal mode is automatically detected with model name, please make sure `cumo` is included in the model path.")
+
+    worker = ModelWorker(args.controller_address,
+                         args.worker_address,
+                         worker_id,
+                         args.no_register,
+                         args.model_path,
+                         args.model_base,
+                         args.model_name,
+                         args.load_8bit,
+                         args.load_4bit,
+                         args.device,
+                         use_flash_attn=args.use_flash_attn)
+    uvicorn.run(app, host=args.host, port=args.port, log_level="info")
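
A note on the load the worker reports: get_queue_length() computes limit_model_concurrency minus the semaphore's current value, plus any coroutines waiting to acquire it. For example, with --limit-model-concurrency 5, three requests generating and two queued leaves the semaphore value at 2 with two waiters, so the worker reports a queue length of 5 - 2 + 2 = 5.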
cumo/serve/register_worker.py ADDED
@@ -0,0 +1,26 @@
+"""
+Manually register workers.
+
+Usage:
+python3 -m cumo.serve.register_worker --controller-address http://localhost:21001 --worker-name http://localhost:21002
+"""
+
+import argparse
+
+import requests
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--controller-address", type=str)
+    parser.add_argument("--worker-name", type=str)
+    parser.add_argument("--check-heart-beat", action="store_true")
+    args = parser.parse_args()
+
+    url = args.controller_address + "/register_worker"
+    data = {
+        "worker_name": args.worker_name,
+        "check_heart_beat": args.check_heart_beat,
+        "worker_status": None,
+    }
+    r = requests.post(url, json=data)
+    assert r.status_code == 200
cumo/serve/sglang_worker.py ADDED
@@ -0,0 +1,244 @@
+"""
+A model worker executes the model.
+"""
+import argparse
+import asyncio
+from concurrent.futures import ThreadPoolExecutor
+import json
+import time
+import threading
+import uuid
+
+from fastapi import FastAPI, Request, BackgroundTasks
+from fastapi.responses import StreamingResponse
+import requests
+import re
+import uvicorn
+from functools import partial
+
+from cumo.constants import WORKER_HEART_BEAT_INTERVAL
+from cumo.utils import (build_logger, server_error_msg,
+                        pretty_print_semaphore)
+from cumo.mm_utils import process_images, load_image_from_base64, tokenizer_image_token, expand2square
+from cumo.constants import DEFAULT_IMAGE_TOKEN
+
+import sglang as sgl
+from sglang.backend.runtime_endpoint import RuntimeEndpoint
+
+
+GB = 1 << 30
+
+worker_id = str(uuid.uuid4())[:6]
+logger = build_logger("model_worker", f"model_worker_{worker_id}.log")
+global_counter = 0
+
+model_semaphore = None
+
+
+def heart_beat_worker(controller):
+    while True:
+        time.sleep(WORKER_HEART_BEAT_INTERVAL)
+        controller.send_heart_beat()
+
+
+@sgl.function
+def pipeline(s, prompt, max_tokens):
+    for p in prompt:
+        if type(p) is str:
+            s += p
+        else:
+            s += sgl.image(p)
+    s += sgl.gen("response", max_tokens=max_tokens)
+
+
+class ModelWorker:
+    def __init__(self, controller_addr, worker_addr, sgl_endpoint,
+                 worker_id, no_register, model_name):
+        self.controller_addr = controller_addr
+        self.worker_addr = worker_addr
+        self.worker_id = worker_id
+
+        # Select backend
+        backend = RuntimeEndpoint(sgl_endpoint)
+        sgl.set_default_backend(backend)
+        model_path = backend.model_info["model_path"]
+
+        if model_path.endswith("/"):
+            model_path = model_path[:-1]
+        if model_name is None:
+            model_paths = model_path.split("/")
+            if model_paths[-1].startswith('checkpoint-'):
+                self.model_name = model_paths[-2] + "_" + model_paths[-1]
+            else:
+                self.model_name = model_paths[-1]
+        else:
+            self.model_name = model_name
+
+        logger.info(f"Loading the SGLANG model {self.model_name} on worker {worker_id} ...")
+
+        if not no_register:
+            self.register_to_controller()
+            self.heart_beat_thread = threading.Thread(
+                target=heart_beat_worker, args=(self,), daemon=True)
+            self.heart_beat_thread.start()
+
+    def register_to_controller(self):
+        logger.info("Register to controller")
+
+        url = self.controller_addr + "/register_worker"
+        data = {
+            "worker_name": self.worker_addr,
+            "check_heart_beat": True,
+            "worker_status": self.get_status()
+        }
+        r = requests.post(url, json=data)
+        assert r.status_code == 200
+
+    def send_heart_beat(self):
+        logger.info(f"Send heart beat. Models: {[self.model_name]}. "
+                    f"Semaphore: {pretty_print_semaphore(model_semaphore)}. "
+                    f"global_counter: {global_counter}")
+
+        url = self.controller_addr + "/receive_heart_beat"
+
+        while True:
+            try:
+                ret = requests.post(url, json={
+                    "worker_name": self.worker_addr,
+                    "queue_length": self.get_queue_length()}, timeout=5)
+                exist = ret.json()["exist"]
+                break
+            except requests.exceptions.RequestException as e:
+                logger.error(f"heart beat error: {e}")
+            time.sleep(5)
+
+        if not exist:
+            self.register_to_controller()
+
+    def get_queue_length(self):
+        if model_semaphore is None:
+            return 0
+        else:
+            return args.limit_model_concurrency - model_semaphore._value + (len(
+                model_semaphore._waiters) if model_semaphore._waiters is not None else 0)
+
+    def get_status(self):
+        return {
+            "model_names": [self.model_name],
+            "speed": 1,
+            "queue_length": self.get_queue_length(),
+        }
+
+    async def generate_stream(self, params):
+        ori_prompt = prompt = params["prompt"]
+        images = params.get("images", None)
+        if images is not None and len(images) > 0:
+            if len(images) > 0:
+                if len(images) != prompt.count(DEFAULT_IMAGE_TOKEN):
+                    raise ValueError("Number of images does not match number of <image> tokens in prompt")
+
+                images = [load_image_from_base64(image) for image in images]
+
+                # FIXME: for image-start/end token
+                # replace_token = DEFAULT_IMAGE_TOKEN
+                # if getattr(self.model.config, 'mm_use_im_start_end', False):
+                #     replace_token = DEFAULT_IM_START_TOKEN + replace_token + DEFAULT_IM_END_TOKEN
+                # prompt = prompt.replace(DEFAULT_IMAGE_TOKEN, replace_token)
+                prompt = prompt.replace(' ' + DEFAULT_IMAGE_TOKEN + '\n', DEFAULT_IMAGE_TOKEN)
+                prompt_split = prompt.split(DEFAULT_IMAGE_TOKEN)
+                prompt = []
+                for i in range(len(prompt_split)):
+                    prompt.append(prompt_split[i])
+                    if i < len(images):
+                        prompt.append(images[i])
+        else:
+            prompt = [prompt]
+
+        temperature = float(params.get("temperature", 1.0))
+        top_p = float(params.get("top_p", 1.0))
+        # max_context_length = getattr(model.config, 'max_position_embeddings', 2048)
+        max_new_tokens = min(int(params.get("max_new_tokens", 256)), 1024)
+        stop_str = params.get("stop", None)
+        stop_str = [stop_str] if stop_str is not None else None
+
+        print({'prompt': prompt, 'max_new_tokens': max_new_tokens, 'temperature': temperature, 'top_p': top_p})
+        state = pipeline.run(prompt, max_new_tokens, temperature=temperature, top_p=top_p, stream=True)
+
+        generated_text = ori_prompt
+        async for text_outputs in state.text_async_iter(var_name="response"):
+            generated_text += text_outputs
+            yield json.dumps({"text": generated_text, "error_code": 0}).encode() + b"\0"
+
+    async def generate_stream_gate(self, params):
+        try:
+            async for x in self.generate_stream(params):
+                yield x
+        except ValueError as e:
+            print("Caught ValueError:", e)
+            ret = {
+                "text": server_error_msg,
+                "error_code": 1,
+            }
+            yield json.dumps(ret).encode() + b"\0"
+        except Exception as e:
+            print("Caught Unknown Error", e)
+            ret = {
+                "text": server_error_msg,
+                "error_code": 1,
+            }
+            yield json.dumps(ret).encode() + b"\0"
+
+
+app = FastAPI()
+
+
+def release_model_semaphore(fn=None):
+    model_semaphore.release()
+    if fn is not None:
+        fn()
+
+
+@app.post("/worker_generate_stream")
+async def generate_stream(request: Request):
+    global model_semaphore, global_counter
+    global_counter += 1
+    params = await request.json()
+
+    if model_semaphore is None:
+        model_semaphore = asyncio.Semaphore(args.limit_model_concurrency)
+    await model_semaphore.acquire()
+    worker.send_heart_beat()
+    generator = worker.generate_stream_gate(params)
+    background_tasks = BackgroundTasks()
+    background_tasks.add_task(partial(release_model_semaphore, fn=worker.send_heart_beat))
+    return StreamingResponse(generator, background=background_tasks)
+
+
+@app.post("/worker_get_status")
+async def get_status(request: Request):
+    return worker.get_status()
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--host", type=str, default="localhost")
+    parser.add_argument("--port", type=int, default=21002)
+    parser.add_argument("--worker-address", type=str,
+                        default="http://localhost:21002")
+    parser.add_argument("--controller-address", type=str,
+                        default="http://localhost:21001")
+    parser.add_argument("--model-name", type=str)
+    parser.add_argument("--sgl-endpoint", type=str)
+    parser.add_argument("--limit-model-concurrency", type=int, default=5)
+    parser.add_argument("--stream-interval", type=int, default=1)
+    parser.add_argument("--no-register", action="store_true")
+    args = parser.parse_args()
+    logger.info(f"args: {args}")
+
+    worker = ModelWorker(args.controller_address,
+                         args.worker_address,
+                         args.sgl_endpoint,
+                         worker_id,
+                         args.no_register,
+                         args.model_name)
+    uvicorn.run(app, host=args.host, port=args.port, log_level="info")
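
Unlike model_worker.py, this worker holds no weights; it proxies generation to a running SGLang server via the --sgl-endpoint URL. A launch sketch, assuming SGLang's standard launch_server entry point (an assumption about the SGLang version in use) and an illustrative model path:

python3 -m sglang.launch_server --model-path jiachenl/CuMo-mistral-7b --port 30000
python3 -m cumo.serve.sglang_worker --controller-address http://localhost:21001 --sgl-endpoint http://localhost:30000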