Spaces:
Sleeping
Sleeping
Tuchuanhuhuhu
commited on
Commit
·
12eb16f
1
Parent(s):
c407bb3
feat: 加入讯飞星火大模型支持 #877
Browse files- config_example.json +3 -0
- modules/config.py +7 -0
- modules/models/base_model.py +7 -1
- modules/models/models.py +3 -0
- modules/presets.py +3 -1
- requirements.txt +1 -0
config_example.json
CHANGED
@@ -11,6 +11,9 @@
|
|
11 |
"midjourney_proxy_api_secret": "", // 你的 MidJourney Proxy API Secret,用于鉴权访问 api,可选
|
12 |
"midjourney_discord_proxy_url": "", // 你的 MidJourney Discord Proxy URL,用于对生成对图进行反代,可选
|
13 |
"midjourney_temp_folder": "./tmp", // 你的 MidJourney 临时文件夹,用于存放生成的图片,填空则关闭自动下载切图(直接显示MJ的四宫格图)
|
|
|
|
|
|
|
14 |
|
15 |
|
16 |
//== Azure ==
|
|
|
11 |
"midjourney_proxy_api_secret": "", // 你的 MidJourney Proxy API Secret,用于鉴权访问 api,可选
|
12 |
"midjourney_discord_proxy_url": "", // 你的 MidJourney Discord Proxy URL,用于对生成对图进行反代,可选
|
13 |
"midjourney_temp_folder": "./tmp", // 你的 MidJourney 临时文件夹,用于存放生成的图片,填空则关闭自动下载切图(直接显示MJ的四宫格图)
|
14 |
+
"spark_appid": "", // 你的 讯飞星火大模型 API AppID,用于讯飞星火大模型对话模型
|
15 |
+
"spark_api_key": "", // 你的 讯飞星火大模型 API Key,用于讯飞星火大模型对话模型
|
16 |
+
"spark_api_secret": "", // 你的 讯飞星火大模型 API Secret,用于讯飞星火大模型对话模型
|
17 |
|
18 |
|
19 |
//== Azure ==
|
modules/config.py
CHANGED
@@ -123,6 +123,13 @@ os.environ["MIDJOURNEY_DISCORD_PROXY_URL"] = midjourney_discord_proxy_url
|
|
123 |
midjourney_temp_folder = config.get("midjourney_temp_folder", "")
|
124 |
os.environ["MIDJOURNEY_TEMP_FOLDER"] = midjourney_temp_folder
|
125 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
126 |
load_config_to_environ(["openai_api_type", "azure_openai_api_key", "azure_openai_api_base_url",
|
127 |
"azure_openai_api_version", "azure_deployment_name", "azure_embedding_deployment_name", "azure_embedding_model_name"])
|
128 |
|
|
|
123 |
midjourney_temp_folder = config.get("midjourney_temp_folder", "")
|
124 |
os.environ["MIDJOURNEY_TEMP_FOLDER"] = midjourney_temp_folder
|
125 |
|
126 |
+
spark_api_key = config.get("spark_api_key", "")
|
127 |
+
os.environ["SPARK_API_KEY"] = spark_api_key
|
128 |
+
spark_appid = config.get("spark_appid", "")
|
129 |
+
os.environ["SPARK_APPID"] = spark_appid
|
130 |
+
spark_api_secret = config.get("spark_api_secret", "")
|
131 |
+
os.environ["SPARK_API_SECRET"] = spark_api_secret
|
132 |
+
|
133 |
load_config_to_environ(["openai_api_type", "azure_openai_api_key", "azure_openai_api_base_url",
|
134 |
"azure_openai_api_version", "azure_deployment_name", "azure_embedding_deployment_name", "azure_embedding_model_name"])
|
135 |
|
modules/models/base_model.py
CHANGED
@@ -142,6 +142,7 @@ class ModelType(Enum):
|
|
142 |
GooglePaLM = 9
|
143 |
LangchainChat = 10
|
144 |
Midjourney = 11
|
|
|
145 |
|
146 |
@classmethod
|
147 |
def get_type(cls, model_name: str):
|
@@ -171,6 +172,8 @@ class ModelType(Enum):
|
|
171 |
model_type = ModelType.Midjourney
|
172 |
elif "azure" in model_name_lower or "api" in model_name_lower:
|
173 |
model_type = ModelType.LangchainChat
|
|
|
|
|
174 |
else:
|
175 |
model_type = ModelType.Unknown
|
176 |
return model_type
|
@@ -269,9 +272,12 @@ class BaseLLMModel:
|
|
269 |
if display_append:
|
270 |
display_append = '\n\n<hr class="append-display no-in-raw" />' + display_append
|
271 |
partial_text = ""
|
|
|
272 |
for partial_text in stream_iter:
|
|
|
|
|
273 |
chatbot[-1] = (chatbot[-1][0], partial_text + display_append)
|
274 |
-
self.all_token_counts[-1] +=
|
275 |
status_text = self.token_message()
|
276 |
yield get_return_value()
|
277 |
if self.interrupted:
|
|
|
142 |
GooglePaLM = 9
|
143 |
LangchainChat = 10
|
144 |
Midjourney = 11
|
145 |
+
Spark = 12
|
146 |
|
147 |
@classmethod
|
148 |
def get_type(cls, model_name: str):
|
|
|
172 |
model_type = ModelType.Midjourney
|
173 |
elif "azure" in model_name_lower or "api" in model_name_lower:
|
174 |
model_type = ModelType.LangchainChat
|
175 |
+
elif "星火大模型" in model_name_lower:
|
176 |
+
model_type = ModelType.Spark
|
177 |
else:
|
178 |
model_type = ModelType.Unknown
|
179 |
return model_type
|
|
|
272 |
if display_append:
|
273 |
display_append = '\n\n<hr class="append-display no-in-raw" />' + display_append
|
274 |
partial_text = ""
|
275 |
+
token_increment = 1
|
276 |
for partial_text in stream_iter:
|
277 |
+
if type(partial_text) == tuple:
|
278 |
+
partial_text, token_increment = partial_text
|
279 |
chatbot[-1] = (chatbot[-1][0], partial_text + display_append)
|
280 |
+
self.all_token_counts[-1] += token_increment
|
281 |
status_text = self.token_message()
|
282 |
yield get_return_value()
|
283 |
if self.interrupted:
|
modules/models/models.py
CHANGED
@@ -625,6 +625,9 @@ def get_model(
|
|
625 |
from .midjourney import Midjourney_Client
|
626 |
mj_proxy_api_secret = os.getenv("MIDJOURNEY_PROXY_API_SECRET")
|
627 |
model = Midjourney_Client(model_name, mj_proxy_api_secret, user_name=user_name)
|
|
|
|
|
|
|
628 |
elif model_type == ModelType.Unknown:
|
629 |
raise ValueError(f"未知模型: {model_name}")
|
630 |
logging.info(msg)
|
|
|
625 |
from .midjourney import Midjourney_Client
|
626 |
mj_proxy_api_secret = os.getenv("MIDJOURNEY_PROXY_API_SECRET")
|
627 |
model = Midjourney_Client(model_name, mj_proxy_api_secret, user_name=user_name)
|
628 |
+
elif model_type == ModelType.Spark:
|
629 |
+
from .spark import Spark_Client
|
630 |
+
model = Spark_Client(model_name, os.getenv("SPARK_APPID"), os.getenv("SPARK_API_KEY"), os.getenv("SPARK_API_SECRET"), user_name=user_name)
|
631 |
elif model_type == ModelType.Unknown:
|
632 |
raise ValueError(f"未知模型: {model_name}")
|
633 |
logging.info(msg)
|
modules/presets.py
CHANGED
@@ -69,7 +69,9 @@ ONLINE_MODELS = [
|
|
69 |
"yuanai-1.0-rhythm_poems",
|
70 |
"minimax-abab4-chat",
|
71 |
"minimax-abab5-chat",
|
72 |
-
"midjourney"
|
|
|
|
|
73 |
]
|
74 |
|
75 |
LOCAL_MODELS = [
|
|
|
69 |
"yuanai-1.0-rhythm_poems",
|
70 |
"minimax-abab4-chat",
|
71 |
"minimax-abab5-chat",
|
72 |
+
"midjourney",
|
73 |
+
"讯飞星火大模型V2.0",
|
74 |
+
"讯飞星火大模型V1.5"
|
75 |
]
|
76 |
|
77 |
LOCAL_MODELS = [
|
requirements.txt
CHANGED
@@ -27,3 +27,4 @@ google-api-python-client
|
|
27 |
tabulate
|
28 |
ujson
|
29 |
python-docx
|
|
|
|
27 |
tabulate
|
28 |
ujson
|
29 |
python-docx
|
30 |
+
websocket_client
|