OpenSourceRonin commited on
Commit
746ca46
·
1 Parent(s): f9e7dbf

build with model selection

Browse files
Files changed (2) hide show
  1. app.py +180 -54
  2. requirements.txt +3 -1
app.py CHANGED
@@ -1,39 +1,157 @@
1
  import spaces
2
- import gradio as gr
3
- from huggingface_hub import InferenceClient
 
4
 
5
- from vptq.app_utils import get_chat_loop_generator
 
6
 
7
- # Update model list with annotations
8
- model_list_with_annotations = {
9
- # "VPTQ-community/Meta-Llama-3.1-70B-Instruct-v8-k65536-65536-woft": "Llama 3.1 70B @ 4bit",
10
- # "VPTQ-community/Meta-Llama-3.1-70B-Instruct-v8-k65536-256-woft": "Llama 3.1 70B @ 3bit",
11
- # "VPTQ-community/Meta-Llama-3.1-70B-Instruct-v16-k65536-65536-woft": "Llama 3.1 70B @ 2bit",
12
- # "VPTQ-community/Qwen2.5-72B-Instruct-v8-k65536-65536-woft": "Qwen2.5 72B @ 4 bits",
13
- # "VPTQ-community/Qwen2.5-72B-Instruct-v8-k65536-256-woft": "Qwen2.5 72B @ 3 bits",
14
- # "VPTQ-community/Qwen2.5-72B-Instruct-v16-k65536-65536-woft": "Qwen2.5 72B @ 3 bits",
15
- # "VPTQ-community/Qwen2.5-32B-Instruct-v8-k65536-65536-woft": "Qwen2.5 32B @ 4 bits",
16
- "VPTQ-community/Qwen2.5-32B-Instruct-v8-k65536-256-woft": "Qwen2.5 32B @ 3 bits",
17
- "VPTQ-community/Qwen2.5-32B-Instruct-v16-k65536-0-woft": "Qwen2.5 32B @ 2 bits"
18
- }
19
-
20
- # Create a list of choices with annotations for the dropdown
21
- model_list_with_annotations_display = [f"{key} ({value})" for key, value in model_list_with_annotations.items()]
22
-
23
- model_keys = list(model_list_with_annotations.keys())
24
- current_model_g = model_keys[0]
25
- chat_completion = get_chat_loop_generator(current_model_g)
26
 
27
- @spaces.GPU
28
- def update_title_and_chatmodel(model):
29
- model = str(model)
30
- global chat_completion
31
- global current_model_g
32
- if model != current_model_g:
33
- current_model_g = model
34
- chat_completion = get_chat_loop_generator(current_model_g)
35
- return model
36
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
37
 
38
  @spaces.GPU
39
  def respond(
@@ -43,7 +161,17 @@ def respond(
43
  max_tokens,
44
  temperature,
45
  top_p,
 
46
  ):
 
 
 
 
 
 
 
 
 
47
  messages = [{"role": "system", "content": system_message}]
48
 
49
  for val in history:
@@ -69,23 +197,21 @@ def respond(
69
  yield response
70
 
71
 
72
- css = """
73
- h1 {
74
- text-align: center;
75
- display: block;
76
- }
77
  """
 
 
 
 
 
 
 
 
78
 
79
- chatbot = gr.Chatbot(label="Gradio ChatInterface")
80
- with gr.Blocks() as demo:
81
- with gr.Column(scale=1):
82
- title_output = gr.Markdown("Please select a model to run")
83
- chat_demo = gr.ChatInterface(
84
  respond,
85
- additional_inputs_accordion=gr.Accordion(
86
- label="⚙️ Parameters", open=False, render=False
87
- ),
88
- fill_height=False,
89
  additional_inputs=[
90
  gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
91
  gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
@@ -95,17 +221,17 @@ with gr.Blocks() as demo:
95
  maximum=1.0,
96
  value=0.95,
97
  step=0.05,
98
- label="Top-p (nucleus sampling)"
 
 
 
 
 
99
  ),
100
  ],
101
  )
102
- model_select = gr.Dropdown(
103
- choices=model_list_with_annotations_display,
104
- label="Models",
105
- value=model_list_with_annotations_display[0],
106
- info="Model & Estimated Quantized Bitwidth"
107
- )
108
- model_select.change(update_title_and_chatmodel, inputs=[model_select], outputs=title_output)
109
 
110
  if __name__ == "__main__":
111
- demo.launch()
 
 
 
1
  import spaces
2
+ import os
3
+ import threading
4
+ from collections import deque
5
 
6
+ import plotly.graph_objs as go
7
+ import pynvml
8
 
9
+ import gradio as gr
10
+ from huggingface_hub import snapshot_download
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
 
12
+ from vptq.app_utils import get_chat_loop_generator
 
 
 
 
 
 
 
 
13
 
14
+ models = [
15
+ {
16
+ "name": "VPTQ-community/Meta-Llama-3.1-8B-Instruct-v8-k65536-65536-woft",
17
+ "bits": "4 bits"
18
+ },
19
+ {
20
+ "name": "VPTQ-community/Meta-Llama-3.1-8B-Instruct-v8-k65536-256-woft",
21
+ "bits": "3 bits"
22
+ },
23
+ ]
24
+
25
+ # Queues for storing historical data (saving the last 100 GPU utilization and memory usage values)
26
+ gpu_util_history = deque(maxlen=100)
27
+ mem_usage_history = deque(maxlen=100)
28
+
29
+
30
+ def initialize_nvml():
31
+ """
32
+ Initialize NVML (NVIDIA Management Library).
33
+ """
34
+ pynvml.nvmlInit()
35
+
36
+
37
+ def get_gpu_info():
38
+ """
39
+ Get GPU utilization and memory usage information.
40
+
41
+ Returns:
42
+ dict: A dictionary containing GPU utilization and memory usage information.
43
+ """
44
+ handle = pynvml.nvmlDeviceGetHandleByIndex(0) # Assuming a single GPU setup
45
+ utilization = pynvml.nvmlDeviceGetUtilizationRates(handle)
46
+ memory = pynvml.nvmlDeviceGetMemoryInfo(handle)
47
+
48
+ gpu_info = {
49
+ 'gpu_util': utilization.gpu,
50
+ 'mem_used': memory.used / 1024**2, # Convert bytes to MiB
51
+ 'mem_total': memory.total / 1024**2, # Convert bytes to MiB
52
+ 'mem_percent': (memory.used / memory.total) * 100
53
+ }
54
+ return gpu_info
55
+
56
+
57
+ def update_charts(chart_height: int = 200) -> go.Figure:
58
+ """
59
+ Update the GPU utilization and memory usage charts.
60
+
61
+ Args:
62
+ chart_height (int, optional): used to set the height of the chart. Defaults to 200.
63
+
64
+ Returns:
65
+ plotly.graph_objs.Figure: The updated figure containing the GPU and memory usage charts.
66
+ """
67
+ # obtain GPU information
68
+ gpu_info = get_gpu_info()
69
+
70
+ # records the latest GPU utilization and memory usage values
71
+ gpu_util = round(gpu_info.get('gpu_util', 0), 1)
72
+ mem_used = round(gpu_info.get('mem_used', 0) / 1024, 2) # Convert MiB to GiB
73
+ gpu_util_history.append(gpu_util)
74
+ mem_usage_history.append(mem_used)
75
+
76
+ # create GPU utilization line chart
77
+ gpu_trace = go.Scatter(
78
+ y=list(gpu_util_history),
79
+ mode='lines+markers',
80
+ text=list(gpu_util_history),
81
+ line=dict(shape='spline', color='blue'), # Make the line smooth and set color
82
+ yaxis='y1' # Link to y-axis 1
83
+ )
84
+
85
+ # create memory usage line chart
86
+ mem_trace = go.Scatter(
87
+ y=list(mem_usage_history),
88
+ mode='lines+markers',
89
+ text=list(mem_usage_history),
90
+ line=dict(shape='spline', color='red'), # Make the line smooth and set color
91
+ yaxis='y2' # Link to y-axis 2
92
+ )
93
+
94
+ # set the layout of the chart
95
+ layout = go.Layout(
96
+ xaxis=dict(title=None, showticklabels=False, ticks=''),
97
+ yaxis=dict(
98
+ title='GPU Utilization (%)',
99
+ range=[-5, 110],
100
+ titlefont=dict(color='blue'),
101
+ tickfont=dict(color='blue'),
102
+ ),
103
+ yaxis2=dict(title='Memory Usage (GiB)',
104
+ range=[0, max(24,
105
+ max(mem_usage_history) + 1)],
106
+ titlefont=dict(color='red'),
107
+ tickfont=dict(color='red'),
108
+ overlaying='y',
109
+ side='right'),
110
+ height=chart_height, # set the height of the chart
111
+ margin=dict(l=10, r=10, t=0, b=0), # set the margin of the chart
112
+ showlegend=False # disable the legend
113
+ )
114
+
115
+ fig = go.Figure(data=[gpu_trace, mem_trace], layout=layout)
116
+ return fig
117
+
118
+
119
+ def initialize_history():
120
+ """
121
+ Initializes the GPU utilization and memory usage history.
122
+ """
123
+ for _ in range(100):
124
+ gpu_info = get_gpu_info()
125
+ gpu_util_history.append(round(gpu_info.get('gpu_util', 0), 1))
126
+ mem_usage_history.append(round(gpu_info.get('mem_percent', 0), 1))
127
+
128
+
129
+ def enable_gpu_info():
130
+ pynvml.nvmlInit()
131
+
132
+
133
+ def disable_gpu_info():
134
+ pynvml.nvmlShutdown()
135
+
136
+ model_choices = [f"{model['name']} ({model['bits']})" for model in models]
137
+ display_to_model = {f"{model['name']} ({model['bits']})": model['name'] for model in models}
138
+
139
+
140
+ def download_model(model):
141
+ print(f"Downloading {model['name']}...")
142
+ snapshot_download(repo_id=model['name'])
143
+
144
+
145
+ def download_models_in_background():
146
+ print('Downloading models for the first time...')
147
+ for model in models:
148
+ download_model(model)
149
+
150
+
151
+ download_thread = threading.Thread(target=download_models_in_background)
152
+ download_thread.start()
153
+
154
+ loaded_models = {}
155
 
156
  @spaces.GPU
157
  def respond(
 
161
  max_tokens,
162
  temperature,
163
  top_p,
164
+ selected_model_display_label,
165
  ):
166
+ model_name = display_to_model[selected_model_display_label]
167
+
168
+ # Check if the model is already loaded
169
+ if model_name not in loaded_models:
170
+ # Load and store the model in the cache
171
+ loaded_models[model_name] = get_chat_loop_generator(model_name)
172
+
173
+ chat_completion = loaded_models[model_name]
174
+
175
  messages = [{"role": "system", "content": system_message}]
176
 
177
  for val in history:
 
197
  yield response
198
 
199
 
 
 
 
 
 
200
  """
201
+ For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
202
+ """
203
+ # enable_gpu_info()
204
+ with gr.Blocks(fill_height=True) as demo:
205
+ with gr.Row():
206
+
207
+ def update_chart():
208
+ return _update_charts(chart_height=200)
209
 
210
+ gpu_chart = gr.Plot(update_chart, every=0.1) # update every 0.1 seconds
211
+
212
+ with gr.Column():
213
+ chat_interface = gr.ChatInterface(
 
214
  respond,
 
 
 
 
215
  additional_inputs=[
216
  gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
217
  gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
 
221
  maximum=1.0,
222
  value=0.95,
223
  step=0.05,
224
+ label="Top-p (nucleus sampling)",
225
+ ),
226
+ gr.Dropdown(
227
+ choices=model_choices,
228
+ value=model_choices[0],
229
+ label="Select Model",
230
  ),
231
  ],
232
  )
 
 
 
 
 
 
 
233
 
234
  if __name__ == "__main__":
235
+ share = os.getenv("SHARE_LINK", None) in ["1", "true", "True"]
236
+ demo.launch(share=share)
237
+ # disable_gpu_info()
requirements.txt CHANGED
@@ -1,2 +1,4 @@
1
  huggingface_hub>=0.22.2
2
- https://github.com/microsoft/VPTQ/releases/download/v0.0.1/vptq-0.0.1-cp310-cp310-manylinux1_x86_64.whl
 
 
 
1
  huggingface_hub>=0.22.2
2
+ https://github.com/microsoft/VPTQ/releases/download/v0.0.2.post1/vptq-0.0.2.post1-cp310-cp310-manylinux1_x86_64.whl
3
+ pynvml==11.5.3
4
+ plotly==5.24.1