Feliciano Long commited on
Commit
cef64b2
·
unverified ·
1 Parent(s): 0aaa1a4

feat: Add Minimax model (#774)

Browse files

* add minimax model and implement both get answer at once/stream, also the role play setting and token count

* add configurable minimax api_key and group id, add explaination to config example

config_example.json CHANGED
@@ -5,6 +5,9 @@
5
  "usage_limit": 120, // API Key的当月限额,单位:美元
6
  // 你的xmchat API Key,与OpenAI API Key不同
7
  "xmchat_api_key": "",
 
 
 
8
  "language": "auto",
9
  // 如果使用代理,请取消注释下面的两行,并替换代理URL
10
  // "https_proxy": "http://127.0.0.1:1079",
 
5
  "usage_limit": 120, // API Key的当月限额,单位:美元
6
  // 你的xmchat API Key,与OpenAI API Key不同
7
  "xmchat_api_key": "",
8
+ // MiniMax的APIKey(见账户管理页面 https://api.minimax.chat/basic-information)和Group ID,用于MiniMax对话模型
9
+ "minimax_api_key": "",
10
+ "minimax_group_id": "",
11
  "language": "auto",
12
  // 如果使用代理,请取消注释下面的两行,并替换代理URL
13
  // "https_proxy": "http://127.0.0.1:1079",
modules/config.py CHANGED
@@ -76,6 +76,11 @@ my_api_key = os.environ.get("OPENAI_API_KEY", my_api_key)
76
  xmchat_api_key = config.get("xmchat_api_key", "")
77
  os.environ["XMCHAT_API_KEY"] = xmchat_api_key
78
 
 
 
 
 
 
79
  render_latex = config.get("render_latex", True)
80
 
81
  if render_latex:
 
76
  xmchat_api_key = config.get("xmchat_api_key", "")
77
  os.environ["XMCHAT_API_KEY"] = xmchat_api_key
78
 
79
+ minimax_api_key = config.get("minimax_api_key", "")
80
+ os.environ["MINIMAX_API_KEY"] = minimax_api_key
81
+ minimax_group_id = config.get("minimax_group_id", "")
82
+ os.environ["MINIMAX_GROUP_ID"] = minimax_group_id
83
+
84
  render_latex = config.get("render_latex", True)
85
 
86
  if render_latex:
modules/models/base_model.py CHANGED
@@ -34,6 +34,7 @@ class ModelType(Enum):
34
  StableLM = 4
35
  MOSS = 5
36
  YuanAI = 6
 
37
 
38
  @classmethod
39
  def get_type(cls, model_name: str):
@@ -53,6 +54,8 @@ class ModelType(Enum):
53
  model_type = ModelType.MOSS
54
  elif "yuanai" in model_name_lower:
55
  model_type = ModelType.YuanAI
 
 
56
  else:
57
  model_type = ModelType.Unknown
58
  return model_type
 
34
  StableLM = 4
35
  MOSS = 5
36
  YuanAI = 6
37
+ Minimax = 7
38
 
39
  @classmethod
40
  def get_type(cls, model_name: str):
 
54
  model_type = ModelType.MOSS
55
  elif "yuanai" in model_name_lower:
56
  model_type = ModelType.YuanAI
57
+ elif "minimax" in model_name_lower:
58
+ model_type = ModelType.Minimax
59
  else:
60
  model_type = ModelType.Unknown
61
  return model_type
modules/models/minimax.py ADDED
@@ -0,0 +1,161 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import os
3
+
4
+ import colorama
5
+ import requests
6
+ import logging
7
+
8
+ from modules.models.base_model import BaseLLMModel
9
+ from modules.presets import STANDARD_ERROR_MSG, GENERAL_ERROR_MSG, TIMEOUT_STREAMING, TIMEOUT_ALL, i18n
10
+
11
+ group_id = os.environ.get("MINIMAX_GROUP_ID", "")
12
+
13
+
14
+ class MiniMax_Client(BaseLLMModel):
15
+ """
16
+ MiniMax Client
17
+ 接口文档见 https://api.minimax.chat/document/guides/chat
18
+ """
19
+
20
+ def __init__(self, model_name, api_key, user_name="", system_prompt=None):
21
+ super().__init__(model_name=model_name, user=user_name)
22
+ self.url = f'https://api.minimax.chat/v1/text/chatcompletion?GroupId={group_id}'
23
+ self.history = []
24
+ self.api_key = api_key
25
+ self.system_prompt = system_prompt
26
+ self.headers = {
27
+ "Authorization": f"Bearer {api_key}",
28
+ "Content-Type": "application/json"
29
+ }
30
+
31
+ def get_answer_at_once(self):
32
+ # minimax temperature is (0,1] and base model temperature is [0,2], and yuan 0.9 == base 1 so need to convert
33
+ temperature = self.temperature * 0.9 if self.temperature <= 1 else 0.9 + (self.temperature - 1) / 10
34
+
35
+ request_body = {
36
+ "model": self.model_name.replace('minimax-', ''),
37
+ "temperature": temperature,
38
+ "skip_info_mask": True,
39
+ 'messages': [{"sender_type": "USER", "text": self.history[-1]['content']}]
40
+ }
41
+ if self.n_choices:
42
+ request_body['beam_width'] = self.n_choices
43
+ if self.system_prompt:
44
+ request_body['prompt'] = self.system_prompt
45
+ if self.max_generation_token:
46
+ request_body['tokens_to_generate'] = self.max_generation_token
47
+ if self.top_p:
48
+ request_body['top_p'] = self.top_p
49
+
50
+ response = requests.post(self.url, headers=self.headers, json=request_body)
51
+
52
+ res = response.json()
53
+ answer = res['reply']
54
+ total_token_count = res["usage"]["total_tokens"]
55
+ return answer, total_token_count
56
+
57
+ def get_answer_stream_iter(self):
58
+ response = self._get_response(stream=True)
59
+ if response is not None:
60
+ iter = self._decode_chat_response(response)
61
+ partial_text = ""
62
+ for i in iter:
63
+ partial_text += i
64
+ yield partial_text
65
+ else:
66
+ yield STANDARD_ERROR_MSG + GENERAL_ERROR_MSG
67
+
68
+ def _get_response(self, stream=False):
69
+ minimax_api_key = self.api_key
70
+ history = self.history
71
+ logging.debug(colorama.Fore.YELLOW +
72
+ f"{history}" + colorama.Fore.RESET)
73
+ headers = {
74
+ "Content-Type": "application/json",
75
+ "Authorization": f"Bearer {minimax_api_key}",
76
+ }
77
+
78
+ temperature = self.temperature * 0.9 if self.temperature <= 1 else 0.9 + (self.temperature - 1) / 10
79
+
80
+ messages = []
81
+ for msg in self.history:
82
+ if msg['role'] == 'user':
83
+ messages.append({"sender_type": "USER", "text": msg['content']})
84
+ else:
85
+ messages.append({"sender_type": "BOT", "text": msg['content']})
86
+
87
+ request_body = {
88
+ "model": self.model_name.replace('minimax-', ''),
89
+ "temperature": temperature,
90
+ "skip_info_mask": True,
91
+ 'messages': messages
92
+ }
93
+ if self.n_choices:
94
+ request_body['beam_width'] = self.n_choices
95
+ if self.system_prompt:
96
+ lines = self.system_prompt.splitlines()
97
+ if lines[0].find(":") != -1 and len(lines[0]) < 20:
98
+ request_body["role_meta"] = {
99
+ "user_name": lines[0].split(":")[0],
100
+ "bot_name": lines[0].split(":")[1]
101
+ }
102
+ lines.pop()
103
+ request_body["prompt"] = "\n".join(lines)
104
+ if self.max_generation_token:
105
+ request_body['tokens_to_generate'] = self.max_generation_token
106
+ else:
107
+ request_body['tokens_to_generate'] = 512
108
+ if self.top_p:
109
+ request_body['top_p'] = self.top_p
110
+
111
+ if stream:
112
+ timeout = TIMEOUT_STREAMING
113
+ request_body['stream'] = True
114
+ request_body['use_standard_sse'] = True
115
+ else:
116
+ timeout = TIMEOUT_ALL
117
+ try:
118
+ response = requests.post(
119
+ self.url,
120
+ headers=headers,
121
+ json=request_body,
122
+ stream=stream,
123
+ timeout=timeout,
124
+ )
125
+ except:
126
+ return None
127
+
128
+ return response
129
+
130
+ def _decode_chat_response(self, response):
131
+ error_msg = ""
132
+ for chunk in response.iter_lines():
133
+ if chunk:
134
+ chunk = chunk.decode()
135
+ chunk_length = len(chunk)
136
+ print(chunk)
137
+ try:
138
+ chunk = json.loads(chunk[6:])
139
+ except json.JSONDecodeError:
140
+ print(i18n("JSON解析错误,��到的内容: ") + f"{chunk}")
141
+ error_msg += chunk
142
+ continue
143
+ if chunk_length > 6 and "delta" in chunk["choices"][0]:
144
+ if "finish_reason" in chunk["choices"][0] and chunk["choices"][0]["finish_reason"] == "stop":
145
+ self.all_token_counts.append(chunk["usage"]["total_tokens"] - sum(self.all_token_counts))
146
+ break
147
+ try:
148
+ yield chunk["choices"][0]["delta"]
149
+ except Exception as e:
150
+ logging.error(f"Error: {e}")
151
+ continue
152
+ if error_msg:
153
+ try:
154
+ error_msg = json.loads(error_msg)
155
+ if 'base_resp' in error_msg:
156
+ status_code = error_msg['base_resp']['status_code']
157
+ status_msg = error_msg['base_resp']['status_msg']
158
+ raise Exception(f"{status_code} - {status_msg}")
159
+ except json.JSONDecodeError:
160
+ pass
161
+ raise Exception(error_msg)
modules/models/models.py CHANGED
@@ -602,6 +602,11 @@ def get_model(
602
  elif model_type == ModelType.YuanAI:
603
  from .inspurai import Yuan_Client
604
  model = Yuan_Client(model_name, api_key=access_key, user_name=user_name, system_prompt=system_prompt)
 
 
 
 
 
605
  elif model_type == ModelType.Unknown:
606
  raise ValueError(f"未知模型: {model_name}")
607
  logging.info(msg)
 
602
  elif model_type == ModelType.YuanAI:
603
  from .inspurai import Yuan_Client
604
  model = Yuan_Client(model_name, api_key=access_key, user_name=user_name, system_prompt=system_prompt)
605
+ elif model_type == ModelType.Minimax:
606
+ from .minimax import MiniMax_Client
607
+ if os.environ.get("MINIMAX_API_KEY") != "":
608
+ access_key = os.environ.get("MINIMAX_API_KEY")
609
+ model = MiniMax_Client(model_name, api_key=access_key, user_name=user_name, system_prompt=system_prompt)
610
  elif model_type == ModelType.Unknown:
611
  raise ValueError(f"未知模型: {model_name}")
612
  logging.info(msg)
modules/presets.py CHANGED
@@ -72,6 +72,8 @@ ONLINE_MODELS = [
72
  "yuanai-1.0-translate",
73
  "yuanai-1.0-dialog",
74
  "yuanai-1.0-rhythm_poems",
 
 
75
  ]
76
 
77
  LOCAL_MODELS = [
 
72
  "yuanai-1.0-translate",
73
  "yuanai-1.0-dialog",
74
  "yuanai-1.0-rhythm_poems",
75
+ "minimax-abab4-chat",
76
+ "minimax-abab5-chat",
77
  ]
78
 
79
  LOCAL_MODELS = [