# -*- coding: utf-8 -*-
"""
Get model client from model name
"""

import json

import colorama
import gradio as gr
from loguru import logger

from src import config
from src.base_model import ModelType
from src.utils import (
    hide_middle_chars,
    i18n,
)


def get_model(
        model_name,
        lora_model_path=None,
        access_key=None,
        temperature=None,
        top_p=None,
        system_prompt=None,
        user_name="",
        original_model=None,
):
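    """Build a model client for the given model name.

    Args:
        model_name: display name of the model; its type is resolved via ModelType.get_type.
        lora_model_path: optional LoRA weights path (not used by the loaders below).
        access_key: API key for hosted models (OpenAI / ZhipuAI).
        temperature / top_p: accepted for interface compatibility; not consumed here.
        system_prompt / user_name: forwarded to clients that accept them.
        original_model: previously active client; its history is carried over to the new client.

    Returns:
        A tuple of (model, status message, Chatbot update, LoRA Dropdown update,
        access_key, masked access_key).
    """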
    msg = i18n("模型设置为了:") + f" {model_name}"
    model_type = ModelType.get_type(model_name)
    lora_choices = ["No LoRA"]
    if model_type != ModelType.OpenAI:
        # Non-OpenAI models fall back to local embeddings.
        config.local_embedding = True
    model = original_model
    chatbot = gr.Chatbot.update(label=model_name)
    try:
        if model_type == ModelType.OpenAI:
            logger.info(f"正在加载OpenAI模型: {model_name}")
            from src.openai_client import OpenAIClient
            model = OpenAIClient(
                model_name=model_name,
                api_key=access_key,
                system_prompt=system_prompt,
                user_name=user_name,
            )
        elif model_type == ModelType.OpenAIVision:
            logger.info(f"正在加载OpenAI Vision模型: {model_name}")
            from src.openai_client import OpenAIVisionClient
            model = OpenAIVisionClient(model_name, api_key=access_key, user_name=user_name)
        elif model_type == ModelType.ChatGLM:
            logger.info(f"正在加载ChatGLM模型: {model_name}")
            from src.chatglm import ChatGLMClient
            model = ChatGLMClient(model_name, user_name=user_name)
        elif model_type == ModelType.LLaMA:
            logger.info(f"正在加载LLaMA模型: {model_name}")
            from src.llama import LLaMAClient
            model = LLaMAClient(model_name, user_name=user_name)
        elif model_type == ModelType.ZhipuAI:  # TODO: fix zhipu bug
            logger.info(f"Loading ZhipuAI model: {model_name}")
            from src.zhipu_client import ZhipuAIClient
            model = ZhipuAIClient(
                model_name=model_name,
                api_key=access_key,
                system_prompt=system_prompt,
                user_name=user_name,
            )
        elif model_type == ModelType.Unknown:
            raise ValueError(f"未知模型: {model_name}")
    except Exception as e:
        logger.error(f"Failed to load model {model_name}: {e}")
    logger.info(msg)
    pseudo_key = hide_middle_chars(access_key)
    if original_model is not None and model is not None:
        model.history = original_model.history
        model.history_file_path = original_model.history_file_path
    return model, msg, chatbot, gr.Dropdown.update(choices=lora_choices, visible=False), access_key, pseudo_key


if __name__ == "__main__":
    with open("../config.json", "r") as f:
        openai_api_key = json.load(f)["openai_api_key"]
    print('key:', openai_api_key)
    client = get_model(model_name="gpt-3.5-turbo", access_key=openai_api_key)[0]
    chatbot = []
    stream = False
    # Test billing info
    logger.info(colorama.Back.GREEN + "Testing billing info" + colorama.Back.RESET)
    logger.info(client.billing_info())
    # Test question answering
    logger.info(colorama.Back.GREEN + "Testing question answering" + colorama.Back.RESET)
    question = "Is Paris the capital of China?"
    for i in client.predict(inputs=question, chatbot=chatbot, stream=stream):
        logger.info(i)
    logger.info(f"History after Q&A test: {client.history}")
    # Test conversation memory
    logger.info(colorama.Back.GREEN + "Testing conversation memory" + colorama.Back.RESET)
    question = "What question did I just ask you?"
    for i in client.predict(inputs=question, chatbot=chatbot, stream=stream):
        logger.info(i)
    logger.info(f"History after memory test: {client.history}")
    # Test retry
    logger.info(colorama.Back.GREEN + "Testing retry" + colorama.Back.RESET)
    for i in client.retry(chatbot=chatbot, stream=stream):
        logger.info(i)
    logger.info(f"History after retry: {client.history}")