Tuchuanhuhuhu commited on
Commit
12eb16f
·
1 Parent(s): c407bb3

feat: 加入讯飞星火大模型支持 #877

Browse files
config_example.json CHANGED
@@ -11,6 +11,9 @@
11
  "midjourney_proxy_api_secret": "", // 你的 MidJourney Proxy API Secret,用于鉴权访问 api,可选
12
  "midjourney_discord_proxy_url": "", // 你的 MidJourney Discord Proxy URL,用于对生成对图进行反代,可选
13
  "midjourney_temp_folder": "./tmp", // 你的 MidJourney 临时文件夹,用于存放生成的图片,填空则关闭自动下载切图(直接显示MJ的四宫格图)
 
 
 
14
 
15
 
16
  //== Azure ==
 
11
  "midjourney_proxy_api_secret": "", // 你的 MidJourney Proxy API Secret,用于鉴权访问 api,可选
12
  "midjourney_discord_proxy_url": "", // 你的 MidJourney Discord Proxy URL,用于对生成对图进行反代,可选
13
  "midjourney_temp_folder": "./tmp", // 你的 MidJourney 临时文件夹,用于存放生成的图片,填空则关闭自动下载切图(直接显示MJ的四宫格图)
14
+ "spark_appid": "", // 你的 讯飞星火大模型 API AppID,用于讯飞星火大模型对话模型
15
+ "spark_api_key": "", // 你的 讯飞星火大模型 API Key,用于讯飞星火大模型对话模型
16
+ "spark_api_secret": "", // 你的 讯飞星火大模型 API Secret,用于讯飞星火大模型对话模型
17
 
18
 
19
  //== Azure ==
modules/config.py CHANGED
@@ -123,6 +123,13 @@ os.environ["MIDJOURNEY_DISCORD_PROXY_URL"] = midjourney_discord_proxy_url
123
  midjourney_temp_folder = config.get("midjourney_temp_folder", "")
124
  os.environ["MIDJOURNEY_TEMP_FOLDER"] = midjourney_temp_folder
125
 
 
 
 
 
 
 
 
126
  load_config_to_environ(["openai_api_type", "azure_openai_api_key", "azure_openai_api_base_url",
127
  "azure_openai_api_version", "azure_deployment_name", "azure_embedding_deployment_name", "azure_embedding_model_name"])
128
 
 
123
  midjourney_temp_folder = config.get("midjourney_temp_folder", "")
124
  os.environ["MIDJOURNEY_TEMP_FOLDER"] = midjourney_temp_folder
125
 
126
+ spark_api_key = config.get("spark_api_key", "")
127
+ os.environ["SPARK_API_KEY"] = spark_api_key
128
+ spark_appid = config.get("spark_appid", "")
129
+ os.environ["SPARK_APPID"] = spark_appid
130
+ spark_api_secret = config.get("spark_api_secret", "")
131
+ os.environ["SPARK_API_SECRET"] = spark_api_secret
132
+
133
  load_config_to_environ(["openai_api_type", "azure_openai_api_key", "azure_openai_api_base_url",
134
  "azure_openai_api_version", "azure_deployment_name", "azure_embedding_deployment_name", "azure_embedding_model_name"])
135
 
modules/models/base_model.py CHANGED
@@ -142,6 +142,7 @@ class ModelType(Enum):
142
  GooglePaLM = 9
143
  LangchainChat = 10
144
  Midjourney = 11
 
145
 
146
  @classmethod
147
  def get_type(cls, model_name: str):
@@ -171,6 +172,8 @@ class ModelType(Enum):
171
  model_type = ModelType.Midjourney
172
  elif "azure" in model_name_lower or "api" in model_name_lower:
173
  model_type = ModelType.LangchainChat
 
 
174
  else:
175
  model_type = ModelType.Unknown
176
  return model_type
@@ -269,9 +272,12 @@ class BaseLLMModel:
269
  if display_append:
270
  display_append = '\n\n<hr class="append-display no-in-raw" />' + display_append
271
  partial_text = ""
 
272
  for partial_text in stream_iter:
 
 
273
  chatbot[-1] = (chatbot[-1][0], partial_text + display_append)
274
- self.all_token_counts[-1] += 1
275
  status_text = self.token_message()
276
  yield get_return_value()
277
  if self.interrupted:
 
142
  GooglePaLM = 9
143
  LangchainChat = 10
144
  Midjourney = 11
145
+ Spark = 12
146
 
147
  @classmethod
148
  def get_type(cls, model_name: str):
 
172
  model_type = ModelType.Midjourney
173
  elif "azure" in model_name_lower or "api" in model_name_lower:
174
  model_type = ModelType.LangchainChat
175
+ elif "星火大模型" in model_name_lower:
176
+ model_type = ModelType.Spark
177
  else:
178
  model_type = ModelType.Unknown
179
  return model_type
 
272
  if display_append:
273
  display_append = '\n\n<hr class="append-display no-in-raw" />' + display_append
274
  partial_text = ""
275
+ token_increment = 1
276
  for partial_text in stream_iter:
277
+ if type(partial_text) == tuple:
278
+ partial_text, token_increment = partial_text
279
  chatbot[-1] = (chatbot[-1][0], partial_text + display_append)
280
+ self.all_token_counts[-1] += token_increment
281
  status_text = self.token_message()
282
  yield get_return_value()
283
  if self.interrupted:
modules/models/models.py CHANGED
@@ -625,6 +625,9 @@ def get_model(
625
  from .midjourney import Midjourney_Client
626
  mj_proxy_api_secret = os.getenv("MIDJOURNEY_PROXY_API_SECRET")
627
  model = Midjourney_Client(model_name, mj_proxy_api_secret, user_name=user_name)
 
 
 
628
  elif model_type == ModelType.Unknown:
629
  raise ValueError(f"未知模型: {model_name}")
630
  logging.info(msg)
 
625
  from .midjourney import Midjourney_Client
626
  mj_proxy_api_secret = os.getenv("MIDJOURNEY_PROXY_API_SECRET")
627
  model = Midjourney_Client(model_name, mj_proxy_api_secret, user_name=user_name)
628
+ elif model_type == ModelType.Spark:
629
+ from .spark import Spark_Client
630
+ model = Spark_Client(model_name, os.getenv("SPARK_APPID"), os.getenv("SPARK_API_KEY"), os.getenv("SPARK_API_SECRET"), user_name=user_name)
631
  elif model_type == ModelType.Unknown:
632
  raise ValueError(f"未知模型: {model_name}")
633
  logging.info(msg)
modules/presets.py CHANGED
@@ -69,7 +69,9 @@ ONLINE_MODELS = [
69
  "yuanai-1.0-rhythm_poems",
70
  "minimax-abab4-chat",
71
  "minimax-abab5-chat",
72
- "midjourney"
 
 
73
  ]
74
 
75
  LOCAL_MODELS = [
 
69
  "yuanai-1.0-rhythm_poems",
70
  "minimax-abab4-chat",
71
  "minimax-abab5-chat",
72
+ "midjourney",
73
+ "讯飞星火大模型V2.0",
74
+ "讯飞星火大模型V1.5"
75
  ]
76
 
77
  LOCAL_MODELS = [
requirements.txt CHANGED
@@ -27,3 +27,4 @@ google-api-python-client
27
  tabulate
28
  ujson
29
  python-docx
 
 
27
  tabulate
28
  ujson
29
  python-docx
30
+ websocket_client