sergey21000 committed
Commit a8e6cef · verified · 1 Parent(s): fd7e2c7

Upload 2 files

Files changed (2):
  1. app.py +259 -0
  2. requirements.txt +2 -0
app.py ADDED
@@ -0,0 +1,259 @@
+ from pathlib import Path
+ from shutil import rmtree
+ from typing import Union, List, Dict, Tuple, Optional
+ from tqdm import tqdm
+
+ import requests
+ import gradio as gr
+ from llama_cpp import Llama
+
+
+ # ================== ANNOTATIONS ========================
+
+ # history entries are [user_message, bot_message] lists (not tuples) so the
+ # streaming handler can write the bot response into them in place
+ CHAT_HISTORY = List[List[Optional[str]]]
+ MODEL_DICT = Dict[str, Llama]
+
+
+ # ================== FUNCS =============================
+
+ def download_file(file_url: str, file_path: Union[str, Path]) -> None:
+     response = requests.get(file_url, stream=True)
+     if response.status_code != 200:
+         raise Exception(f'The file is not available for download at: {file_url}')
+     total_size = int(response.headers.get('content-length', 0))
+     progress_tqdm = tqdm(desc='Loading GGUF file', total=total_size, unit='iB', unit_scale=True)
+     progress_gradio = gr.Progress()
+     completed_size = 0
+     # stream the file to disk, reporting progress to both the console and the Gradio UI
+     with open(file_path, 'wb') as file:
+         for data in response.iter_content(chunk_size=4096):
+             size = file.write(data)
+             progress_tqdm.update(size)
+             completed_size += size
+             desc = f'Loading GGUF file, {completed_size/1024**3:.3f}/{total_size/1024**3:.3f} GB'
+             progress_gradio(completed_size/total_size, desc=desc)
+
+
+
+ def download_gguf_and_init_model(gguf_url: str, model_dict: MODEL_DICT) -> Tuple[MODEL_DICT, bool, str]:
+     log = ''
+     if not gguf_url.endswith('.gguf'):
+         log += 'The link must be a direct link to the GGUF file\n'
+         return model_dict, False, log
+
+     gguf_filename = gguf_url.rsplit('/')[-1]
+     model_path = MODELS_PATH / gguf_filename
+     progress = gr.Progress()
+
+     if not model_path.is_file():
+         progress(0.3, desc='Step 1/2: Loading GGUF model file')
+         try:
+             download_file(gguf_url, model_path)
+             log += f'Model file {gguf_filename} successfully loaded\n'
+         except Exception as ex:
+             log += f'Error loading model from link {gguf_url}, error code:\n{ex}\n'
+             curr_model = model_dict.get('model')
+             if curr_model is None:
+                 log += 'Model is missing from dictionary "model_dict"\n'
+                 return model_dict, False, log
+             # keep the already initialized model and report its system role support
+             curr_model_filename = Path(curr_model.model_path).name
+             log += f'Current initialized model: {curr_model_filename}\n'
+             support_system_role = 'System role not supported' not in curr_model.metadata['tokenizer.chat_template']
+             return model_dict, support_system_role, log
+     else:
+         log += f'Model file {gguf_filename} already downloaded, initializing model...\n'
+
+     progress(0.7, desc='Step 2/2: Model initialization')
+     model = Llama(model_path=str(model_path), n_gpu_layers=-1, verbose=True)
+     model_dict = {'model': model}
+     support_system_role = 'System role not supported' not in model.metadata['tokenizer.chat_template']
+     log += f'Model {gguf_filename} initialized\n'
+     return model_dict, support_system_role, log
+
+
+ def user_message_to_chatbot(user_message: str, chatbot: CHAT_HISTORY) -> Tuple[str, CHAT_HISTORY]:
+     if user_message:
+         # append as a mutable list so bot_response_to_chatbot can fill in the reply
+         chatbot.append([user_message, None])
+     return '', chatbot
+
+
+ def bot_response_to_chatbot(
+     chatbot: CHAT_HISTORY,
+     model_dict: MODEL_DICT,
+     system_prompt: str,
+     support_system_role: bool,
+     history_len: int,
+     do_sample: bool,
+     *generate_args,
+ ):
+
+     model = model_dict.get('model')
+     user_message = chatbot[-1][0]
+     messages = []
+
+     # slider values arrive positionally, in the same order as GENERATE_KWARGS
+     gen_kwargs = dict(zip(GENERATE_KWARGS.keys(), generate_args))
+     gen_kwargs['top_k'] = int(gen_kwargs['top_k'])
+
+     # greedy decoding: effectively disable sampling and the repetition penalty
+     if not do_sample:
+         gen_kwargs['top_p'] = 0.0
+         gen_kwargs['top_k'] = 1
+         gen_kwargs['repeat_penalty'] = 1.0
+
+     if support_system_role and system_prompt:
+         messages.append({'role': 'system', 'content': system_prompt})
+
+     # include up to history_len previous (user, bot) pairs in the context
+     if history_len != 0:
+         for user_msg, bot_msg in chatbot[:-1][-history_len:]:
+             messages.append({'role': 'user', 'content': user_msg})
+             messages.append({'role': 'assistant', 'content': bot_msg})
+
+     messages.append({'role': 'user', 'content': user_message})
+     stream_response = model.create_chat_completion(
+         messages=messages,
+         stream=True,
+         **gen_kwargs,
+     )
+
+     # stream tokens into the last history entry as they arrive
+     chatbot[-1][1] = ''
+     for chunk in stream_response:
+         token = chunk['choices'][0]['delta'].get('content')
+         if token is not None:
+             chatbot[-1][1] += token
+             yield chatbot
+
+
+ def get_system_prompt_component(interactive: bool) -> gr.Textbox:
+     value = '' if interactive else 'System prompt is not supported by this model'
+     return gr.Textbox(value=value, label='System prompt', interactive=interactive)
+
+
+ def get_generate_args(do_sample: bool) -> List[gr.components.Component]:
+     visible = do_sample
+     generate_args = [
+         gr.Slider(label='temperature', value=GENERATE_KWARGS['temperature'], minimum=0.1, maximum=3, step=0.1, visible=visible),
+         gr.Slider(label='top_p', value=GENERATE_KWARGS['top_p'], minimum=0.1, maximum=1, step=0.1, visible=visible),
+         gr.Slider(label='top_k', value=GENERATE_KWARGS['top_k'], minimum=1, maximum=50, step=5, visible=visible),
+         gr.Slider(label='repeat_penalty', value=GENERATE_KWARGS['repeat_penalty'], minimum=1, maximum=5, step=0.1, visible=visible),
+     ]
+     return generate_args
+
+
+ # ================== VARIABLES =============================
+
+ MODELS_PATH = Path('models')
+ MODELS_PATH.mkdir(exist_ok=True)
+ DEFAULT_GGUF_URL = 'https://huggingface.co/bartowski/gemma-2-2b-it-GGUF/resolve/main/gemma-2-2b-it-Q8_0.gguf'
+
+ start_model_dict, start_support_system_role, start_load_log = download_gguf_and_init_model(
+     gguf_url=DEFAULT_GGUF_URL, model_dict={},
+ )
+
+ GENERATE_KWARGS = dict(
+     temperature=0.2,
+     top_p=0.95,
+     top_k=40,
+     repeat_penalty=1.0,
+ )
+
+ theme = gr.themes.Base(primary_hue='green', secondary_hue='yellow', neutral_hue='zinc').set(
+     loader_color='rgb(0, 255, 0)',
+     slider_color='rgb(0, 200, 0)',
+     body_text_color_dark='rgb(0, 200, 0)',
+     button_secondary_background_fill_dark='green',
+ )
+ css = '''.gradio-container {width: 60% !important}'''
+
+
+ # ================== INTERFACE =============================
+
+ with gr.Blocks(theme=theme, css=css) as interface:
+     model_dict = gr.State(start_model_dict)
+     support_system_role = gr.State(start_support_system_role)
+
+     # ================= CHAT BOT PAGE ======================
+     with gr.Tab('Chat bot'):
+         with gr.Row():
+             with gr.Column(scale=3):
+                 chatbot = gr.Chatbot(show_copy_button=True, bubble_full_width=False, height=480)
+                 user_message = gr.Textbox(label='User')
+
+                 with gr.Row():
+                     user_message_btn = gr.Button('Send')
+                     stop_btn = gr.Button('Stop')
+                     clear_btn = gr.Button('Clear')
+
+                 system_prompt = get_system_prompt_component(interactive=support_system_role.value)
+
+             with gr.Column(scale=1, min_width=80):
+                 with gr.Group():
+                     gr.Markdown('Length of message history')
+                     history_len = gr.Slider(
+                         minimum=0,
+                         maximum=10,
+                         value=0,
+                         step=1,
+                         info='Number of previous messages taken into account in history',
+                         label='history_len',
+                         show_label=False,
+                     )
+
+                 with gr.Group():
+                     gr.Markdown('Generation parameters')
+                     do_sample = gr.Checkbox(
+                         value=False,
+                         label='do_sample',
+                         info='Activate random sampling',
+                     )
+                     generate_args = get_generate_args(do_sample.value)
+                     do_sample.change(
+                         fn=get_generate_args,
+                         inputs=do_sample,
+                         outputs=generate_args,
+                         show_progress=False,
+                     )
+
+         generate_event = gr.on(
+             triggers=[user_message.submit, user_message_btn.click],
+             fn=user_message_to_chatbot,
+             inputs=[user_message, chatbot],
+             outputs=[user_message, chatbot],
+         ).then(
+             fn=bot_response_to_chatbot,
+             inputs=[chatbot, model_dict, system_prompt, support_system_role, history_len, do_sample, *generate_args],
+             outputs=[chatbot],
+         )
+         stop_btn.click(
+             fn=None,
+             inputs=None,
+             outputs=None,
+             cancels=generate_event,
+         )
+         clear_btn.click(
+             fn=lambda: None,
+             inputs=None,
+             outputs=[chatbot],
+         )
+
+     # ================= LOAD MODELS PAGE ======================
+     with gr.Tab('Load model'):
+         gguf_url = gr.Textbox(
+             value='',
+             label='Link to GGUF',
+             placeholder='URL link to the model in GGUF format',
+         )
+         load_model_btn = gr.Button('Download GGUF and initialize the model')
+         load_log = gr.Textbox(
+             value=start_load_log,
+             label='Model loading status',
+             lines=3,
+         )
+
+         load_model_btn.click(
+             fn=download_gguf_and_init_model,
+             inputs=[gguf_url, model_dict],
+             outputs=[model_dict, support_system_role, load_log],
+         ).success(
+             fn=get_system_prompt_component,
+             inputs=[support_system_role],
+             outputs=[system_prompt],
+         )
+
+ interface.launch(server_name='0.0.0.0', server_port=7860)
requirements.txt ADDED
@@ -0,0 +1,2 @@
+ llama_cpp_python==0.2.88
+ gradio>4
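
Usage note: a minimal sketch for running this app locally (the entry point and port come from the interface.launch call above; the default gemma-2-2b-it GGUF is downloaded automatically on first start):

    pip install -r requirements.txt
    python app.py   # serves the Gradio UI on port 7860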