Tuchuanhuhuhu committed
Commit a6ebff0 · 1 Parent(s): a27db7d
feat: Add DALLE3 support
- modules/models/DALLE3.py +38 -0
- modules/models/base_model.py +3 -0
- modules/models/models.py +4 -0
- modules/presets.py +1 -0
modules/models/DALLE3.py
ADDED
@@ -0,0 +1,38 @@
+import re
+import json
+import openai
+from openai import OpenAI
+from .base_model import BaseLLMModel
+from .. import shared
+from ..config import retrieve_proxy
+
+
+class OpenAI_DALLE3_Client(BaseLLMModel):
+    def __init__(self, model_name, api_key, user_name="") -> None:
+        super().__init__(model_name=model_name, user=user_name)
+        self.api_key = api_key
+
+    def _get_dalle3_prompt(self):
+        prompt = self.history[-1]["content"]
+        if prompt.endswith("--raw"):
+            prompt = "I NEED to test how the tool works with extremely simple prompts. DO NOT add any detail, just use it AS-IS:" + prompt
+        return prompt
+
+    @shared.state.switching_api_key
+    def get_answer_at_once(self):
+        prompt = self._get_dalle3_prompt()
+        with retrieve_proxy():
+            client = OpenAI(api_key=openai.api_key)
+            try:
+                response = client.images.generate(
+                    model="dall-e-3",
+                    prompt=prompt,
+                    size="1024x1024",
+                    quality="standard",
+                    n=1,
+                )
+            except openai.BadRequestError as e:
+                msg = str(e)
+                match = re.search(r"'message': '([^']*)'", msg)
+                return match.group(1), 0
+        return f'<img src="{response.data[0].url}"> {response.data[0].revised_prompt}', 0
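For orientation, below is a minimal standalone sketch of the request the new client issues, written against the openai>=1.0 SDK that DALLE3.py imports. The example prompt and the use of an OPENAI_API_KEY environment variable are assumptions for illustration and are not part of the commit.

    # Standalone sketch of the DALL-E 3 call wrapped by OpenAI_DALLE3_Client.
    # Assumes the openai>=1.0 SDK and an OPENAI_API_KEY set in the environment.
    import os
    import openai
    from openai import OpenAI

    client = OpenAI(api_key=os.environ["OPENAI_API_KEY"])
    try:
        response = client.images.generate(
            model="dall-e-3",          # same model the new client hard-codes
            prompt="A watercolor lighthouse at dusk",  # example prompt (assumption)
            size="1024x1024",
            quality="standard",
            n=1,
        )
        # get_answer_at_once formats this pair into an <img> tag plus the revised prompt.
        print(response.data[0].url, response.data[0].revised_prompt)
    except openai.BadRequestError as e:
        # Content-policy or validation failures surface here, as in the client code above.
        print(str(e))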
modules/models/base_model.py
CHANGED
@@ -153,6 +153,7 @@ class ModelType(Enum):
     Qwen = 15
     OpenAIVision = 16
     ERNIE = 17
+    DALLE3 = 18
 
     @classmethod
     def get_type(cls, model_name: str):
@@ -195,6 +196,8 @@ class ModelType(Enum):
             model_type = ModelType.Qwen
         elif "ernie" in model_name_lower:
             model_type = ModelType.ERNIE
+        elif "dall" in model_name_lower:
+            model_type = ModelType.DALLE3
         else:
             model_type = ModelType.LLaMA
         return model_type
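The DALLE3 = 18 member and the "dall" substring check extend the existing name-based dispatch. Below is a self-contained sketch of just that dispatch, trimmed to three members for brevity; the real ModelType.get_type in base_model.py covers all the families listed above.

    # Illustrative sketch of the substring dispatch added in this hunk (not the full enum).
    from enum import Enum

    class ModelType(Enum):
        LLaMA = 1
        ERNIE = 17
        DALLE3 = 18

    def get_type(model_name: str) -> ModelType:
        name = model_name.lower()
        if "ernie" in name:
            return ModelType.ERNIE
        elif "dall" in name:          # matches "dall-e-3" and the "DALL-E 3" preset name
            return ModelType.DALLE3
        return ModelType.LLaMA        # fallback, as in the original else branch

    assert get_type("DALL-E 3") is ModelType.DALLE3
    assert get_type("dall-e-3") is ModelType.DALLE3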
modules/models/models.py
CHANGED
@@ -129,6 +129,10 @@ def get_model(
     elif model_type == ModelType.ERNIE:
         from .ERNIE import ERNIE_Client
         model = ERNIE_Client(model_name, api_key=os.getenv("ERNIE_APIKEY"),secret_key=os.getenv("ERNIE_SECRETKEY"))
+    elif model_type == ModelType.DALLE3:
+        from .DALLE3 import OpenAI_DALLE3_Client
+        access_key = os.environ.get("OPENAI_API_KEY", access_key)
+        model = OpenAI_DALLE3_Client(model_name, api_key=access_key, user_name=user_name)
     elif model_type == ModelType.Unknown:
         raise ValueError(f"未知模型: {model_name}")
     logging.info(msg)
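The new branch resolves the key with os.environ.get("OPENAI_API_KEY", access_key), so an exported environment variable takes precedence over the key passed into get_model. A tiny illustrative sketch of that fallback follows; the helper name resolve_openai_key and the example key are hypothetical.

    # Sketch of the API-key fallback used by the new DALLE3 branch (illustration only).
    import os

    def resolve_openai_key(ui_key):
        # An exported OPENAI_API_KEY wins; otherwise the caller-supplied key
        # (get_model's access_key argument) is used unchanged.
        return os.environ.get("OPENAI_API_KEY", ui_key)

    print(resolve_openai_key("sk-example-from-ui"))  # prints the env var when it is set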
modules/presets.py
CHANGED
@@ -62,6 +62,7 @@ ONLINE_MODELS = [
     "GPT4 Vision",
     "川虎助理",
     "川虎助理 Pro",
+    "DALL-E 3",
     "GooglePaLM",
     "xmchat",
     "Azure OpenAI",