Commit 066a302 (parent: 0a2a419), committed by guest
多用户使用同一个后端model ("multiple users share one backend model")

Files changed:
- ChuanhuChatbot.py  +27 -26
- modules/models.py  +43 -8
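The substantive change: `current_model` moves out of `gr.State` and up to module level. `gr.State` gives each browser session its own copy of its initial value, so every user previously got a private `ModelManager`; a module-level instance is created once per process and shared by all sessions, and its bound methods (`current_model.predict`, `current_model.billing_info`, ...) can be passed straight to Gradio event handlers. A minimal sketch of the two patterns, using a toy `Counter` as a stand-in for `ModelManager`:

import gradio as gr

class Counter:
    # stand-in for ModelManager: mutable backend state
    def __init__(self):
        self.n = 0

    def bump(self):
        self.n += 1
        return f"{self.n} clicks, counted across all users"

shared = Counter()  # module level: one instance for the whole process

with gr.Blocks() as demo:
    # old pattern (one private copy per session): counter = gr.State(Counter())
    out = gr.Textbox(label="count")
    gr.Button("bump").click(shared.bump, inputs=None, outputs=out)

demo.launch()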
ChuanhuChatbot.py CHANGED
@@ -15,6 +15,8 @@ from modules.models import ModelManager
 gr.Chatbot.postprocess = postprocess
 PromptHelper.compact_text_chunks = compact_text_chunks
 
+current_model = ModelManager(model_name = MODELS[DEFAULT_MODEL], access_key = my_api_key)
+
 with open("assets/custom.css", "r", encoding="utf-8") as f:
     customCSS = f.read()
 
@@ -22,7 +24,6 @@ with gr.Blocks(css=customCSS, theme=small_and_beautiful_theme) as demo:
     user_name = gr.State("")
     promptTemplates = gr.State(load_template(get_template_names(plain=True)[0], mode=2))
     user_question = gr.State("")
-    current_model = gr.State(ModelManager(model_name = MODELS[DEFAULT_MODEL], access_key = my_api_key))
 
     topic = gr.State("未命名对话历史记录")
 
@@ -264,7 +265,7 @@ with gr.Blocks(css=customCSS, theme=small_and_beautiful_theme) as demo:
     gr.Markdown(CHUANHU_DESCRIPTION)
     gr.HTML(FOOTER.format(versions=versions_html()), elem_id="footer")
     chatgpt_predict_args = dict(
-        fn=current_model.
+        fn=current_model.predict,
         inputs=[
             user_question,
             chatbot,
@@ -297,18 +298,18 @@ with gr.Blocks(css=customCSS, theme=small_and_beautiful_theme) as demo:
     )
 
     get_usage_args = dict(
-        fn=current_model.
+        fn=current_model.billing_info, inputs=None, outputs=[usageTxt], show_progress=False
     )
 
     load_history_from_file_args = dict(
-        fn=current_model.
+        fn=current_model.load_chat_history,
         inputs=[historyFileSelectDropdown, chatbot, user_name],
         outputs=[saveFileName, systemPromptTxt, chatbot]
     )
 
 
     # Chatbot
-    cancelBtn.click(current_model.
+    cancelBtn.click(current_model.interrupt, [], [])
 
     user_input.submit(**transfer_input_args).then(**chatgpt_predict_args).then(**end_outputing_args)
     user_input.submit(**get_usage_args)
@@ -317,14 +318,14 @@ with gr.Blocks(css=customCSS, theme=small_and_beautiful_theme) as demo:
     submitBtn.click(**get_usage_args)
 
     emptyBtn.click(
-        current_model.
+        current_model.reset,
         outputs=[chatbot, status_display],
         show_progress=True,
     )
     emptyBtn.click(**reset_textbox_args)
 
     retryBtn.click(**start_outputing_args).then(
-        current_model.
+        current_model.retry,
         [
             chatbot,
             use_streaming_checkbox,
@@ -338,13 +339,13 @@ with gr.Blocks(css=customCSS, theme=small_and_beautiful_theme) as demo:
     retryBtn.click(**get_usage_args)
 
     delFirstBtn.click(
-        current_model.
+        current_model.delete_first_conversation,
         None,
         [status_display],
     )
 
     delLastBtn.click(
-        current_model.
+        current_model.delete_last_conversation,
         [chatbot],
         [chatbot, status_display],
         show_progress=False
@@ -353,14 +354,14 @@ with gr.Blocks(css=customCSS, theme=small_and_beautiful_theme) as demo:
     two_column.change(update_doc_config, [two_column], None)
 
     # LLM Models
-    keyTxt.change(current_model.
+    keyTxt.change(current_model.set_key, keyTxt, [status_display]).then(**get_usage_args)
     keyTxt.submit(**get_usage_args)
-    single_turn_checkbox.change(current_model.
-    model_select_dropdown.change(current_model.
-    lora_select_dropdown.change(current_model.
+    single_turn_checkbox.change(current_model.set_single_turn, single_turn_checkbox, None)
+    model_select_dropdown.change(current_model.get_model, [model_select_dropdown, lora_select_dropdown, keyTxt, temperature_slider, top_p_slider, systemPromptTxt], [status_display, lora_select_dropdown], show_progress=True)
+    lora_select_dropdown.change(current_model.get_model, [model_select_dropdown, lora_select_dropdown, keyTxt, temperature_slider, top_p_slider, systemPromptTxt], [status_display], show_progress=True)
 
     # Template
-    systemPromptTxt.change(current_model.
+    systemPromptTxt.change(current_model.set_system_prompt, [systemPromptTxt], None)
     templateRefreshBtn.click(get_template_names, None, [templateFileSelectDropdown])
     templateFileSelectDropdown.change(
         load_template,
@@ -377,14 +378,14 @@ with gr.Blocks(css=customCSS, theme=small_and_beautiful_theme) as demo:
 
     # S&L
     saveHistoryBtn.click(
-        current_model.
+        current_model.save_chat_history,
         [saveFileName, chatbot, user_name],
         downloadFile,
         show_progress=True,
     )
     saveHistoryBtn.click(get_history_names, [gr.State(False), user_name], [historyFileSelectDropdown])
    exportMarkdownBtn.click(
-        current_model.
+        current_model.export_markdown,
         [saveFileName, chatbot, user_name],
         downloadFile,
         show_progress=True,
@@ -394,16 +395,16 @@ with gr.Blocks(css=customCSS, theme=small_and_beautiful_theme) as demo:
     downloadFile.change(**load_history_from_file_args)
 
     # Advanced
-    max_context_length_slider.change(current_model.
-    temperature_slider.change(current_model.
-    top_p_slider.change(current_model.
-    n_choices_slider.change(current_model.
-    stop_sequence_txt.change(current_model.
-    max_generation_slider.change(current_model.
-    presence_penalty_slider.change(current_model.
-    frequency_penalty_slider.change(current_model.
-    logit_bias_txt.change(current_model.
-    user_identifier_txt.change(current_model.
+    max_context_length_slider.change(current_model.set_token_upper_limit, [max_context_length_slider], None)
+    temperature_slider.change(current_model.set_temperature, [temperature_slider], None)
+    top_p_slider.change(current_model.set_top_p, [top_p_slider], None)
+    n_choices_slider.change(current_model.set_n_choices, [n_choices_slider], None)
+    stop_sequence_txt.change(current_model.set_stop_sequence, [stop_sequence_txt], None)
+    max_generation_slider.change(current_model.set_max_tokens, [max_generation_slider], None)
+    presence_penalty_slider.change(current_model.set_presence_penalty, [presence_penalty_slider], None)
+    frequency_penalty_slider.change(current_model.set_frequency_penalty, [frequency_penalty_slider], None)
+    logit_bias_txt.change(current_model.set_logit_bias, [logit_bias_txt], None)
+    user_identifier_txt.change(current_model.set_user_identifier, [user_identifier_txt], None)
 
     default_btn.click(
         reset_default, [], [apihostTxt, proxyTxt, status_display], show_progress=True
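One idiom in the wiring above deserves a note: argument sets such as `get_usage_args` are plain dicts splatted into several triggers with `**`, so a single `fn`/`inputs`/`outputs` definition serves `user_input.submit`, `submitBtn.click`, `keyTxt.submit`, and `retryBtn.click` alike. A self-contained sketch, with a dummy `usage` function standing in for `current_model.billing_info`:

import gradio as gr

def usage():
    return "Usage this month: $0.42"  # stand-in for current_model.billing_info

with gr.Blocks() as demo:
    usageTxt = gr.Textbox(label="usage")
    user_input = gr.Textbox(label="message")
    submitBtn = gr.Button("send")

    # define the event arguments once, reuse them on every relevant trigger
    get_usage_args = dict(fn=usage, inputs=None, outputs=[usageTxt], show_progress=False)
    user_input.submit(**get_usage_args)
    submitBtn.click(**get_usage_args)

demo.launch()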
modules/models.py CHANGED
@@ -195,7 +195,7 @@ class OpenAIClient(BaseLLMModel):
                 chunk = json.loads(chunk[6:])
             except json.JSONDecodeError:
                 print(f"JSON解析错误,收到的内容: {chunk}")
-                error_msg+=chunk
+                error_msg += chunk
                 continue
             if chunk_length > 6 and "delta" in chunk["choices"][0]:
                 if chunk["choices"][0]["finish_reason"] == "stop":
@@ -216,7 +216,7 @@ class ChatGLM_Client(BaseLLMModel):
         import torch
 
         system_name = platform.system()
-        model_path=None
+        model_path = None
         if os.path.exists("models"):
             model_dirs = os.listdir("models")
             if model_name in model_dirs:
@@ -292,6 +292,7 @@ class LLaMA_Client(BaseLLMModel):
         from lmflow.pipeline.auto_pipeline import AutoPipeline
         from lmflow.models.auto_model import AutoModel
         from lmflow.args import ModelArguments, DatasetArguments, InferencerArguments
+
         model_path = None
         if os.path.exists("models"):
             model_dirs = os.listdir("models")
@@ -304,10 +305,33 @@ class LLaMA_Client(BaseLLMModel):
             # raise Exception(f"models目录下没有这个模型: {model_name}")
         if lora_path is not None:
             lora_path = f"lora/{lora_path}"
+        self.lora_path = lora_path
         self.max_generation_token = 1000
         pipeline_name = "inferencer"
-        model_args = ModelArguments(
-
+        model_args = ModelArguments(
+            model_name_or_path=model_source,
+            lora_model_path=lora_path,
+            model_type=None,
+            config_overrides=None,
+            config_name=None,
+            tokenizer_name=None,
+            cache_dir=None,
+            use_fast_tokenizer=True,
+            model_revision="main",
+            use_auth_token=False,
+            torch_dtype=None,
+            use_lora=False,
+            lora_r=8,
+            lora_alpha=32,
+            lora_dropout=0.1,
+            use_ram_optimized_load=True,
+        )
+        pipeline_args = InferencerArguments(
+            local_rank=0,
+            random_seed=1,
+            deepspeed="configs/ds_config_chatbot.json",
+            mixed_precision="bf16",
+        )
 
         with open(pipeline_args.deepspeed, "r") as f:
             ds_config = json.load(f)
@@ -374,7 +398,7 @@ class LLaMA_Client(BaseLLMModel):
         step = 1
         for _ in range(0, self.max_generation_token, step):
             input_dataset = self.dataset.from_dict(
-                {"type": "text_only", "instances": [{"text": context+partial_text}]}
+                {"type": "text_only", "instances": [{"text": context + partial_text}]}
             )
             output_dataset = self.inferencer.inference(
                 model=self.model,
@@ -404,6 +428,17 @@ class ModelManager:
         system_prompt=None,
     ) -> BaseLLMModel:
         msg = f"模型设置为了: {model_name}"
+        if self.model is not None and model_name == self.model.model_name:
+            # 如果模型名字一样,那么就不用重新加载模型
+            # if (
+            #     lora_model_path is not None
+            #     and hasattr(self.model, "lora_path")
+            #     and lora_model_path == self.model.lora_path
+            #     or lora_model_path is None
+            #     and not hasattr(self.model, "lora_path")
+            # ):
+            logging.info(f"模型 {model_name} 已经加载,不需要重新加载")
+            return msg
         model_type = ModelType.get_type(model_name)
         lora_selector_visibility = False
         lora_choices = []
@@ -451,7 +486,9 @@ class ModelManager:
         if dont_change_lora_selector:
             return msg
         else:
-            return msg, gr.Dropdown.update(
+            return msg, gr.Dropdown.update(
+                choices=lora_choices, visible=lora_selector_visibility
+            )
 
     def predict(self, *args):
         iter = self.model.predict(*args)
@@ -530,8 +567,6 @@ class ModelManager:
         self.model.set_single_turn(*args)
 
 
-
-
 if __name__ == "__main__":
     with open("config.json", "r") as f:
         openai_api_key = cjson.load(f)["openai_api_key"]
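The `get_model` guard is the other half of the multi-user change: because one `ModelManager` now serves every session, re-selecting the model that is already loaded must not re-instantiate the backend. The added lines return early when the requested name matches the loaded one (the log message "模型 … 已经加载,不需要重新加载" reads "model … is already loaded, no reload needed"); a finer comparison on the LoRA path is left commented out. A condensed, self-contained sketch of the guard, with a `DummyModel` in place of the real client classes:

import logging
from dataclasses import dataclass

@dataclass
class DummyModel:
    # stand-in for OpenAIClient / ChatGLM_Client / LLaMA_Client
    model_name: str

class Manager:
    def __init__(self):
        self.model = None

    def get_model(self, model_name: str) -> str:
        msg = f"Model set to: {model_name}"
        if self.model is not None and model_name == self.model.model_name:
            # same name: reuse the loaded backend instead of rebuilding it
            logging.info("%s already loaded, no reload needed", model_name)
            return msg
        self.model = DummyModel(model_name)  # the real code dispatches on ModelType
        return msg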