Tuchuanhuhuhu committed
Commit 0127941 · 1 Parent(s): 9a2b13d
Added the ability to switch models at runtime (加入了运行时切换模型的功能)
Files changed:
- ChuanhuChatbot.py +5 -5
- modules/base_model.py +7 -8
- modules/models.py +92 -6
ChuanhuChatbot.py
CHANGED
@@ -10,7 +10,7 @@ from modules.config import *
 from modules.utils import *
 from modules.presets import *
 from modules.overwrites import *
-from modules.models import
+from modules.models import ModelManager
 
 gr.Chatbot.postprocess = postprocess
 PromptHelper.compact_text_chunks = compact_text_chunks
@@ -22,7 +22,7 @@ with gr.Blocks(css=customCSS, theme=small_and_beautiful_theme) as demo:
     user_name = gr.State("")
     promptTemplates = gr.State(load_template(get_template_names(plain=True)[0], mode=2))
     user_question = gr.State("")
-    current_model = gr.State(
+    current_model = gr.State(ModelManager(model_name = MODELS[DEFAULT_MODEL], access_key = my_api_key))
 
     topic = gr.State("未命名对话历史记录")
 
@@ -197,7 +197,7 @@ with gr.Blocks(css=customCSS, theme=small_and_beautiful_theme) as demo:
                         interactive=True,
                         label="max context",
                     )
-
+                    max_generation_slider = gr.Slider(
                         minimum=1,
                         maximum=32768,
                         value=1000,
@@ -350,7 +350,7 @@ with gr.Blocks(css=customCSS, theme=small_and_beautiful_theme) as demo:
     # LLM Models
     keyTxt.change(current_model.value.set_key, keyTxt, [status_display]).then(**get_usage_args)
     keyTxt.submit(**get_usage_args)
-    model_select_dropdown.change(get_model, [model_select_dropdown, keyTxt, temperature_slider, top_p_slider, systemPromptTxt], [
+    model_select_dropdown.change(current_model.value.get_model, [model_select_dropdown, keyTxt, temperature_slider, top_p_slider, systemPromptTxt], [status_display], show_progress=True)
 
     # Template
     systemPromptTxt.change(current_model.value.set_system_prompt, [systemPromptTxt], None)
@@ -392,7 +392,7 @@ with gr.Blocks(css=customCSS, theme=small_and_beautiful_theme) as demo:
     top_p_slider.change(current_model.value.set_top_p, [top_p_slider], None)
     n_choices_slider.change(current_model.value.set_n_choices, [n_choices_slider], None)
     stop_sequence_txt.change(current_model.value.set_stop_sequence, [stop_sequence_txt], None)
-
+    max_generation_slider.change(current_model.value.set_max_tokens, [max_generation_slider], None)
     presence_penalty_slider.change(current_model.value.set_presence_penalty, [presence_penalty_slider], None)
     frequency_penalty_slider.change(current_model.value.set_frequency_penalty, [frequency_penalty_slider], None)
     logit_bias_txt.change(current_model.value.set_logit_bias, [logit_bias_txt], None)
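Note: the app now keeps a single ModelManager inside gr.State and registers its bound methods directly as Gradio event handlers, so switching the backend only mutates that one object. A minimal standalone sketch of the same wiring, assuming Gradio 3.x; DummyManager and the component names below are illustrative placeholders, not code from the repo:

import gradio as gr

class DummyManager:
    """Stand-in for the real ModelManager: holds the currently active model name."""
    def __init__(self, model_name):
        self.model_name = model_name

    def get_model(self, model_name):
        # Swap the underlying model at runtime and return a status string for the UI.
        self.model_name = model_name
        return f"model set to: {model_name}"

with gr.Blocks() as demo:
    current_model = gr.State(DummyManager("gpt-3.5-turbo"))
    status_display = gr.Markdown()
    model_select_dropdown = gr.Dropdown(choices=["gpt-3.5-turbo", "chatglm-6b-int4"])
    # The State's value is a live object, so its bound method can serve as the handler,
    # mirroring model_select_dropdown.change(current_model.value.get_model, ...).
    model_select_dropdown.change(
        current_model.value.get_model,
        [model_select_dropdown],
        [status_display],
    )

# demo.launch()

Because the handler is bound to the initial object at build time, the manager has to swap models by mutating itself in place, which is why get_model stores the new client on self instead of returning a new manager.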
modules/base_model.py
CHANGED
@@ -126,13 +126,15 @@ class BaseLLMModel:
 
         stream_iter = self.get_answer_stream_iter()
 
-        self.history.append(construct_assistant(""))
         for partial_text in stream_iter:
-            self.history[-1] = construct_assistant(partial_text)
             chatbot[-1] = (chatbot[-1][0], partial_text + display_append)
             self.all_token_counts[-1] += 1
             status_text = self.token_message()
             yield get_return_value()
+            if self.interrupted:
+                self.recover()
+                break
+        self.history.append(construct_assistant(partial_text))
 
     def next_chatbot_at_once(self, inputs, chatbot, fake_input=None, display_append=""):
         if fake_input:
@@ -277,9 +279,6 @@ class BaseLLMModel:
             )
             for chatbot, status_text in iter:
                 yield chatbot, status_text
-                if self.interrupted:
-                    self.recover()
-                    break
         else:
             logging.debug("不使用流式传输")
             chatbot, status_text = self.next_chatbot_at_once(
@@ -326,13 +325,13 @@ class BaseLLMModel:
         files=None,
         reply_language="中文",
     ):
-        logging.
+        logging.debug("重试中……")
         if len(self.history) == 0:
             yield chatbot, f"{STANDARD_ERROR_MSG}上下文是空的"
             return
 
+        inputs = self.history[-2]["content"]
         del self.history[-2:]
-        inputs = chatbot[-1][0]
         self.all_token_counts.pop()
         iter = self.predict(
             inputs,
@@ -344,7 +343,7 @@ class BaseLLMModel:
         )
         for x in iter:
             yield x
-        logging.
+        logging.debug("重试完毕")
 
     # def reduce_token_size(self, chatbot):
     #     logging.info("开始减少token数量……")
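Note: the base_model.py change moves the interrupt check into the streaming loop itself and appends the assistant message to history once, after generation stops, instead of rewriting the last history entry on every chunk. A self-contained sketch of that control flow, with illustrative names rather than the repo's classes:

class StreamingChat:
    """Toy model client whose streaming answer can be interrupted."""
    def __init__(self):
        self.history = []
        self.interrupted = False

    def interrupt(self):
        self.interrupted = True

    def recover(self):
        self.interrupted = False

    def stream_answer(self, chunks):
        partial_text = ""
        for chunk in chunks:
            partial_text += chunk
            yield partial_text          # push the partial answer to the UI
            if self.interrupted:        # user pressed stop: clear the flag and bail out
                self.recover()
                break
        # history is written once, with whatever text was produced before any break
        self.history.append({"role": "assistant", "content": partial_text})

chat = StreamingChat()
for shown in chat.stream_answer(["Hel", "lo, ", "world"]):
    print(shown)
print(chat.history)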
modules/models.py
CHANGED
@@ -247,7 +247,7 @@ class ChatGLM_Client(BaseLLMModel):
     def _get_glm_style_input(self):
         history = [x["content"] for x in self.history]
         query = history.pop()
-        logging.
+        logging.debug(colorama.Fore.YELLOW + f"{history}" + colorama.Fore.RESET)
         assert (
             len(history) % 2 == 0
         ), f"History should be even length. current history is: {history}"
@@ -365,11 +365,12 @@ class LLaMA_Client(BaseLLMModel):
 
 class ModelManager:
     def __init__(self, **kwargs) -> None:
-        self.
+        self.get_model(**kwargs)
 
     def get_model(
         self,
         model_name,
+        lora_model_path=None,
         access_key=None,
         temperature=None,
         top_p=None,
@@ -378,7 +379,6 @@ class ModelManager:
         msg = f"模型设置为了: {model_name}"
         logging.info(msg)
         model_type = ModelType.get_type(model_name)
-        print(model_type.name)
         if model_type == ModelType.OpenAI:
             model = OpenAIClient(
                 model_name=model_name,
@@ -389,7 +389,93 @@ class ModelManager:
             )
         elif model_type == ModelType.ChatGLM:
             model = ChatGLM_Client(model_name)
-
+        self.model = model
+        return msg
+
+    def predict(self, *args):
+        iter = self.model.predict(*args)
+        for i in iter:
+            yield i
+
+    def billing_info(self):
+        return self.model.billing_info()
+
+    def set_key(self, *args):
+        return self.model.set_key(*args)
+
+    def load_chat_history(self, *args):
+        return self.model.load_chat_history(*args)
+
+    def interrupt(self, *args):
+        return self.model.interrupt(*args)
+
+    def reset(self, *args):
+        return self.model.reset(*args)
+
+    def retry(self, *args):
+        iter = self.model.retry(*args)
+        for i in iter:
+            yield i
+
+    def delete_first_conversation(self, *args):
+        return self.model.delete_first_conversation(*args)
+
+    def delete_last_conversation(self, *args):
+        return self.model.delete_last_conversation(*args)
+
+    def set_system_prompt(self, *args):
+        return self.model.set_system_prompt(*args)
+
+    def save_chat_history(self, *args):
+        return self.model.save_chat_history(*args)
+
+    def export_markdown(self, *args):
+        return self.model.export_markdown(*args)
+
+    def load_chat_history(self, *args):
+        return self.model.load_chat_history(*args)
+
+    def set_token_upper_limit(self, *args):
+        return self.model.set_token_upper_limit(*args)
+
+    # temperature_slider.change(current_model.value.set_temperature, [temperature_slider], None)
+    # top_p_slider.change(current_model.value.set_top_p, [top_p_slider], None)
+    # n_choices_slider.change(current_model.value.set_n_choices, [n_choices_slider], None)
+    # stop_sequence_txt.change(current_model.value.set_stop_sequence, [stop_sequence_txt], None)
+    # max_tokens_slider.change(current_model.value.set_max_tokens, [max_tokens_slider], None)
+    # presence_penalty_slider.change(current_model.value.set_presence_penalty, [presence_penalty_slider], None)
+    # frequency_penalty_slider.change(current_model.value.set_frequency_penalty, [frequency_penalty_slider], None)
+    # logit_bias_txt.change(current_model.value.set_logit_bias, [logit_bias_txt], None)
+    # user_identifier_txt.change(current_model.value.set_user_identifier, [user_identifier_txt], None)
+
+    def set_temperature(self, *args):
+        self.model.set_temperature(*args)
+
+    def set_top_p(self, *args):
+        self.model.set_top_p(*args)
+
+    def set_n_choices(self, *args):
+        self.model.set_n_choices(*args)
+
+    def set_stop_sequence(self, *args):
+        self.model.set_stop_sequence(*args)
+
+    def set_max_tokens(self, *args):
+        self.model.set_max_tokens(*args)
+
+    def set_presence_penalty(self, *args):
+        self.model.set_presence_penalty(*args)
+
+    def set_frequency_penalty(self, *args):
+        self.model.set_frequency_penalty(*args)
+
+    def set_logit_bias(self, *args):
+        self.model.set_logit_bias(*args)
+
+    def set_user_identifier(self, *args):
+        self.model.set_user_identifier(*args)
+
+
 
 
 if __name__ == "__main__":
@@ -397,8 +483,8 @@ if __name__ == "__main__":
         openai_api_key = cjson.load(f)["openai_api_key"]
     # set logging level to debug
     logging.basicConfig(level=logging.DEBUG)
-    # client
-    client
+    # client = ModelManager(model_name="gpt-3.5-turbo", access_key=openai_api_key)
+    client = ModelManager(model_name="chatglm-6b-int4")
    chatbot = []
    stream = False
    # 测试账单功能
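Note: ModelManager is a thin wrapper that forwards every call to whichever backend client get_model last created, which is what makes runtime switching possible without rewiring the Gradio callbacks. A minimal sketch of the same delegation idea, using toy class names (EchoClient and Manager are illustrative, not the repo's clients):

class EchoClient:
    """Toy backend standing in for OpenAIClient / ChatGLM_Client."""
    def __init__(self, model_name):
        self.model_name = model_name

    def predict(self, prompt):
        # Stream a single chunk that names the backend that produced it.
        yield f"[{self.model_name}] {prompt}"


class Manager:
    """Holds the active client and forwards every call to it."""
    def __init__(self, model_name):
        self.get_model(model_name)

    def get_model(self, model_name):
        # Replace the backend in place; callers keep using the same Manager object.
        self.model = EchoClient(model_name)
        return f"model set to: {model_name}"

    def predict(self, *args):
        yield from self.model.predict(*args)


m = Manager("gpt-3.5-turbo")
print(next(m.predict("hi")))           # [gpt-3.5-turbo] hi
print(m.get_model("chatglm-6b-int4"))  # model set to: chatglm-6b-int4
print(next(m.predict("hi again")))     # [chatglm-6b-int4] hi again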