johnsmith253325 commited on
Commit
c51b92e
·
2 Parent(s): 406ae44 8dbdf7a

Merge branch 'main' of https://github.com/GaiZhenbiao/ChuanhuChatGPT

Browse files
modules/models/OpenAI.py CHANGED
@@ -149,13 +149,13 @@ class OpenAIClient(BaseLLMModel):
149
  timeout = TIMEOUT_ALL
150
 
151
  # 如果有自定义的api-host,使用自定义host发送请求,否则使用默认设置发送请求
152
- if shared.state.completion_url != COMPLETION_URL:
153
- logging.debug(f"使用自定义API URL: {shared.state.completion_url}")
154
 
155
  with retrieve_proxy():
156
  try:
157
  response = requests.post(
158
- shared.state.completion_url,
159
  headers=headers,
160
  json=payload,
161
  stream=stream,
@@ -201,15 +201,23 @@ class OpenAIClient(BaseLLMModel):
201
  print(i18n("JSON解析错误,收到的内容: ") + f"{chunk}")
202
  error_msg += chunk
203
  continue
204
- if chunk_length > 6 and "delta" in chunk["choices"][0]:
205
- if chunk["choices"][0]["finish_reason"] == "stop":
206
- break
207
- try:
208
- yield chunk["choices"][0]["delta"]["content"]
209
- except Exception as e:
210
- # logging.error(f"Error: {e}")
211
- continue
212
- if error_msg:
 
 
 
 
 
 
 
 
213
  raise Exception(error_msg)
214
 
215
  def set_key(self, new_access_key):
@@ -229,12 +237,12 @@ class OpenAIClient(BaseLLMModel):
229
  "messages": history,
230
  }
231
  # 如果有自定义的api-host,使用自定义host发送请求,否则使用默认设置发送请求
232
- if shared.state.completion_url != COMPLETION_URL:
233
- logging.debug(f"使用自定义API URL: {shared.state.completion_url}")
234
 
235
  with retrieve_proxy():
236
  response = requests.post(
237
- shared.state.completion_url,
238
  headers=headers,
239
  json=payload,
240
  stream=False,
@@ -245,7 +253,7 @@ class OpenAIClient(BaseLLMModel):
245
 
246
 
247
  def auto_name_chat_history(self, name_chat_method, user_question, chatbot, user_name, single_turn_checkbox):
248
- if len(self.history) == 2 and not single_turn_checkbox:
249
  user_question = self.history[0]["content"]
250
  if name_chat_method == i18n("模型自动总结(消耗tokens)"):
251
  ai_answer = self.history[1]["content"]
 
149
  timeout = TIMEOUT_ALL
150
 
151
  # 如果有自定义的api-host,使用自定义host发送请求,否则使用默认设置发送请求
152
+ if shared.state.chat_completion_url != CHAT_COMPLETION_URL:
153
+ logging.debug(f"使用自定义API URL: {shared.state.chat_completion_url}")
154
 
155
  with retrieve_proxy():
156
  try:
157
  response = requests.post(
158
+ shared.state.chat_completion_url,
159
  headers=headers,
160
  json=payload,
161
  stream=stream,
 
201
  print(i18n("JSON解析错误,收到的内容: ") + f"{chunk}")
202
  error_msg += chunk
203
  continue
204
+ try:
205
+ if chunk_length > 6 and "delta" in chunk["choices"][0]:
206
+ if "finish_reason" in chunk["choices"][0]:
207
+ finish_reason = chunk["choices"][0]["finish_reason"]
208
+ else:
209
+ finish_reason = chunk["finish_reason"]
210
+ if finish_reason == "stop":
211
+ break
212
+ try:
213
+ yield chunk["choices"][0]["delta"]["content"]
214
+ except Exception as e:
215
+ # logging.error(f"Error: {e}")
216
+ continue
217
+ except:
218
+ print(f"ERROR: {chunk}")
219
+ continue
220
+ if error_msg and not error_msg=="data: [DONE]":
221
  raise Exception(error_msg)
222
 
223
  def set_key(self, new_access_key):
 
237
  "messages": history,
238
  }
239
  # 如果有自定义的api-host,使用自定义host发送请求,否则使用默认设置发送请求
240
+ if shared.state.chat_completion_url != CHAT_COMPLETION_URL:
241
+ logging.debug(f"使用自定义API URL: {shared.state.chat_completion_url}")
242
 
243
  with retrieve_proxy():
244
  response = requests.post(
245
+ shared.state.chat_completion_url,
246
  headers=headers,
247
  json=payload,
248
  stream=False,
 
253
 
254
 
255
  def auto_name_chat_history(self, name_chat_method, user_question, chatbot, user_name, single_turn_checkbox):
256
+ if len(self.history) == 2 and not single_turn_checkbox and not hide_history_when_not_logged_in:
257
  user_question = self.history[0]["content"]
258
  if name_chat_method == i18n("模型自动总结(消耗tokens)"):
259
  ai_answer = self.history[1]["content"]
modules/models/OpenAIInstruct.py ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import openai
2
+ from .base_model import BaseLLMModel
3
+ from .. import shared
4
+ from ..config import retrieve_proxy
5
+
6
+
7
+ class OpenAI_Instruct_Client(BaseLLMModel):
8
+ def __init__(self, model_name, api_key, user_name="") -> None:
9
+ super().__init__(model_name=model_name, user=user_name)
10
+ self.api_key = api_key
11
+
12
+ def _get_instruct_style_input(self):
13
+ return "\n\n".join([item["content"] for item in self.history])
14
+
15
+ @shared.state.switching_api_key
16
+ def get_answer_at_once(self):
17
+ prompt = self._get_instruct_style_input()
18
+ with retrieve_proxy():
19
+ response = openai.Completion.create(
20
+ api_key=self.api_key,
21
+ api_base=shared.state.openai_api_base,
22
+ model=self.model_name,
23
+ prompt=prompt,
24
+ temperature=self.temperature,
25
+ top_p=self.top_p,
26
+ )
27
+ return response.choices[0].text.strip(), response.usage["total_tokens"]
modules/models/base_model.py CHANGED
@@ -144,13 +144,17 @@ class ModelType(Enum):
144
  LangchainChat = 10
145
  Midjourney = 11
146
  Spark = 12
 
147
 
148
  @classmethod
149
  def get_type(cls, model_name: str):
150
  model_type = None
151
  model_name_lower = model_name.lower()
152
  if "gpt" in model_name_lower:
153
- model_type = ModelType.OpenAI
 
 
 
154
  elif "chatglm" in model_name_lower:
155
  model_type = ModelType.ChatGLM
156
  elif "llama" in model_name_lower or "alpaca" in model_name_lower:
@@ -247,7 +251,7 @@ class BaseLLMModel:
247
 
248
  def billing_info(self):
249
  """get billing infomation, inplement if needed"""
250
- logging.warning("billing info not implemented, using default")
251
  return BILLING_NOT_APPLICABLE_MSG
252
 
253
  def count_token(self, user_input):
@@ -743,6 +747,10 @@ class BaseLLMModel:
743
  logging.info(new_history)
744
  except:
745
  pass
 
 
 
 
746
  logging.debug(f"{self.user_identifier} 加载对话历史完毕")
747
  self.history = json_s["history"]
748
  return os.path.basename(self.history_file_path), json_s["system"], json_s["chatbot"]
 
144
  LangchainChat = 10
145
  Midjourney = 11
146
  Spark = 12
147
+ OpenAIInstruct = 13
148
 
149
  @classmethod
150
  def get_type(cls, model_name: str):
151
  model_type = None
152
  model_name_lower = model_name.lower()
153
  if "gpt" in model_name_lower:
154
+ if "instruct" in model_name_lower:
155
+ model_type = ModelType.OpenAIInstruct
156
+ else:
157
+ model_type = ModelType.OpenAI
158
  elif "chatglm" in model_name_lower:
159
  model_type = ModelType.ChatGLM
160
  elif "llama" in model_name_lower or "alpaca" in model_name_lower:
 
251
 
252
  def billing_info(self):
253
  """get billing infomation, inplement if needed"""
254
+ # logging.warning("billing info not implemented, using default")
255
  return BILLING_NOT_APPLICABLE_MSG
256
 
257
  def count_token(self, user_input):
 
747
  logging.info(new_history)
748
  except:
749
  pass
750
+ if len(json_s["chatbot"]) < len(json_s["history"]):
751
+ logging.info("Trimming corrupted history...")
752
+ json_s["history"] = json_s["history"][-len(json_s["chatbot"]):]
753
+ logging.info(f"Trimmed history: {json_s['history']}")
754
  logging.debug(f"{self.user_identifier} 加载对话历史完毕")
755
  self.history = json_s["history"]
756
  return os.path.basename(self.history_file_path), json_s["system"], json_s["chatbot"]
modules/models/models.py CHANGED
@@ -47,6 +47,12 @@ def get_model(
47
  top_p=top_p,
48
  user_name=user_name,
49
  )
 
 
 
 
 
 
50
  elif model_type == ModelType.ChatGLM:
51
  logging.info(f"正在加载ChatGLM模型: {model_name}")
52
  from .ChatGLM import ChatGLM_Client
 
47
  top_p=top_p,
48
  user_name=user_name,
49
  )
50
+ elif model_type == ModelType.OpenAIInstruct:
51
+ logging.info(f"正在加载OpenAI Instruct模型: {model_name}")
52
+ from .OpenAIInstruct import OpenAI_Instruct_Client
53
+ access_key = os.environ.get("OPENAI_API_KEY", access_key)
54
+ model = OpenAI_Instruct_Client(
55
+ model_name, api_key=access_key, user_name=user_name)
56
  elif model_type == ModelType.ChatGLM:
57
  logging.info(f"正在加载ChatGLM模型: {model_name}")
58
  from .ChatGLM import ChatGLM_Client
modules/presets.py CHANGED
@@ -14,7 +14,9 @@ LLAMA_INFERENCER = None
14
  # ChatGPT 设置
15
  INITIAL_SYSTEM_PROMPT = "You are a helpful assistant."
16
  API_HOST = "api.openai.com"
17
- COMPLETION_URL = "https://api.openai.com/v1/chat/completions"
 
 
18
  BALANCE_API_URL="https://api.openai.com/dashboard/billing/credit_grants"
19
  USAGE_API_URL="https://api.openai.com/dashboard/billing/usage"
20
  HISTORY_DIR = Path("history")
@@ -50,10 +52,11 @@ CHUANHU_DESCRIPTION = i18n("由Bilibili [土川虎虎虎](https://space.bilibili
50
 
51
  ONLINE_MODELS = [
52
  "gpt-3.5-turbo",
 
53
  "gpt-3.5-turbo-16k",
 
54
  "gpt-3.5-turbo-0301",
55
  "gpt-3.5-turbo-0613",
56
- "gpt-4",
57
  "gpt-4-0314",
58
  "gpt-4-0613",
59
  "gpt-4-32k",
 
14
  # ChatGPT 设置
15
  INITIAL_SYSTEM_PROMPT = "You are a helpful assistant."
16
  API_HOST = "api.openai.com"
17
+ OPENAI_API_BASE = "https://api.openai.com/v1"
18
+ CHAT_COMPLETION_URL = "https://api.openai.com/v1/chat/completions"
19
+ COMPLETION_URL = "https://api.openai.com/v1/completions"
20
  BALANCE_API_URL="https://api.openai.com/dashboard/billing/credit_grants"
21
  USAGE_API_URL="https://api.openai.com/dashboard/billing/usage"
22
  HISTORY_DIR = Path("history")
 
52
 
53
  ONLINE_MODELS = [
54
  "gpt-3.5-turbo",
55
+ "gpt-3.5-turbo-instruct",
56
  "gpt-3.5-turbo-16k",
57
+ "gpt-4",
58
  "gpt-3.5-turbo-0301",
59
  "gpt-3.5-turbo-0613",
 
60
  "gpt-4-0314",
61
  "gpt-4-0613",
62
  "gpt-4-32k",
modules/shared.py CHANGED
@@ -1,4 +1,4 @@
1
- from modules.presets import COMPLETION_URL, BALANCE_API_URL, USAGE_API_URL, API_HOST
2
  import os
3
  import queue
4
  import openai
@@ -6,9 +6,10 @@ import openai
6
  class State:
7
  interrupted = False
8
  multi_api_key = False
9
- completion_url = COMPLETION_URL
10
  balance_api_url = BALANCE_API_URL
11
  usage_api_url = USAGE_API_URL
 
12
 
13
  def interrupt(self):
14
  self.interrupted = True
@@ -22,13 +23,14 @@ class State:
22
  api_host = f"https://{api_host}"
23
  if api_host.endswith("/v1"):
24
  api_host = api_host[:-3]
25
- self.completion_url = f"{api_host}/v1/chat/completions"
 
26
  self.balance_api_url = f"{api_host}/dashboard/billing/credit_grants"
27
  self.usage_api_url = f"{api_host}/dashboard/billing/usage"
28
  os.environ["OPENAI_API_BASE"] = api_host
29
 
30
  def reset_api_host(self):
31
- self.completion_url = COMPLETION_URL
32
  self.balance_api_url = BALANCE_API_URL
33
  self.usage_api_url = USAGE_API_URL
34
  os.environ["OPENAI_API_BASE"] = f"https://{API_HOST}"
@@ -36,7 +38,7 @@ class State:
36
 
37
  def reset_all(self):
38
  self.interrupted = False
39
- self.completion_url = COMPLETION_URL
40
 
41
  def set_api_key_queue(self, api_key_list):
42
  self.multi_api_key = True
 
1
+ from modules.presets import CHAT_COMPLETION_URL, BALANCE_API_URL, USAGE_API_URL, API_HOST, OPENAI_API_BASE
2
  import os
3
  import queue
4
  import openai
 
6
  class State:
7
  interrupted = False
8
  multi_api_key = False
9
+ chat_completion_url = CHAT_COMPLETION_URL
10
  balance_api_url = BALANCE_API_URL
11
  usage_api_url = USAGE_API_URL
12
+ openai_api_base = OPENAI_API_BASE
13
 
14
  def interrupt(self):
15
  self.interrupted = True
 
23
  api_host = f"https://{api_host}"
24
  if api_host.endswith("/v1"):
25
  api_host = api_host[:-3]
26
+ self.chat_completion_url = f"{api_host}/v1/chat/completions"
27
+ self.openai_api_base = f"{api_host}/v1"
28
  self.balance_api_url = f"{api_host}/dashboard/billing/credit_grants"
29
  self.usage_api_url = f"{api_host}/dashboard/billing/usage"
30
  os.environ["OPENAI_API_BASE"] = api_host
31
 
32
  def reset_api_host(self):
33
+ self.chat_completion_url = CHAT_COMPLETION_URL
34
  self.balance_api_url = BALANCE_API_URL
35
  self.usage_api_url = USAGE_API_URL
36
  os.environ["OPENAI_API_BASE"] = f"https://{API_HOST}"
 
38
 
39
  def reset_all(self):
40
  self.interrupted = False
41
+ self.chat_completion_url = CHAT_COMPLETION_URL
42
 
43
  def set_api_key_queue(self, api_key_list):
44
  self.multi_api_key = True
modules/utils.py CHANGED
@@ -77,9 +77,6 @@ def auto_name_chat_history(current_model, *args):
77
  def export_markdown(current_model, *args):
78
  return current_model.export_markdown(*args)
79
 
80
- def load_chat_history(current_model, *args):
81
- return current_model.load_chat_history(*args)
82
-
83
  def upload_chat_history(current_model, *args):
84
  return current_model.load_chat_history(*args)
85
 
@@ -335,6 +332,8 @@ def construct_assistant(text):
335
 
336
  def save_file(filename, system, history, chatbot, user_name):
337
  os.makedirs(os.path.join(HISTORY_DIR, user_name), exist_ok=True)
 
 
338
  if filename.endswith(".md"):
339
  filename = filename[:-3]
340
  if not filename.endswith(".json") and not filename.endswith(".md"):
 
77
  def export_markdown(current_model, *args):
78
  return current_model.export_markdown(*args)
79
 
 
 
 
80
  def upload_chat_history(current_model, *args):
81
  return current_model.load_chat_history(*args)
82
 
 
332
 
333
  def save_file(filename, system, history, chatbot, user_name):
334
  os.makedirs(os.path.join(HISTORY_DIR, user_name), exist_ok=True)
335
+ if filename is None:
336
+ filename = new_auto_history_filename(user_name)
337
  if filename.endswith(".md"):
338
  filename = filename[:-3]
339
  if not filename.endswith(".json") and not filename.endswith(".md"):