Tuchuanhuhuhu commited on
Commit
a6ebff0
·
1 Parent(s): a27db7d

feat: 加入DALLE3支持

Browse files
modules/models/DALLE3.py ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re
2
+ import json
3
+ import openai
4
+ from openai import OpenAI
5
+ from .base_model import BaseLLMModel
6
+ from .. import shared
7
+ from ..config import retrieve_proxy
8
+
9
+
10
+ class OpenAI_DALLE3_Client(BaseLLMModel):
11
+ def __init__(self, model_name, api_key, user_name="") -> None:
12
+ super().__init__(model_name=model_name, user=user_name)
13
+ self.api_key = api_key
14
+
15
+ def _get_dalle3_prompt(self):
16
+ prompt = self.history[-1]["content"]
17
+ if prompt.endswith("--raw"):
18
+ prompt = "I NEED to test how the tool works with extremely simple prompts. DO NOT add any detail, just use it AS-IS:" + prompt
19
+ return prompt
20
+
21
+ @shared.state.switching_api_key
22
+ def get_answer_at_once(self):
23
+ prompt = self._get_dalle3_prompt()
24
+ with retrieve_proxy():
25
+ client = OpenAI(api_key=openai.api_key)
26
+ try:
27
+ response = client.images.generate(
28
+ model="dall-e-3",
29
+ prompt=prompt,
30
+ size="1024x1024",
31
+ quality="standard",
32
+ n=1,
33
+ )
34
+ except openai.BadRequestError as e:
35
+ msg = str(e)
36
+ match = re.search(r"'message': '([^']*)'", msg)
37
+ return match.group(1), 0
38
+ return f'<img src="{response.data[0].url}"> {response.data[0].revised_prompt}', 0
modules/models/base_model.py CHANGED
@@ -153,6 +153,7 @@ class ModelType(Enum):
153
  Qwen = 15
154
  OpenAIVision = 16
155
  ERNIE = 17
 
156
 
157
  @classmethod
158
  def get_type(cls, model_name: str):
@@ -195,6 +196,8 @@ class ModelType(Enum):
195
  model_type = ModelType.Qwen
196
  elif "ernie" in model_name_lower:
197
  model_type = ModelType.ERNIE
 
 
198
  else:
199
  model_type = ModelType.LLaMA
200
  return model_type
 
153
  Qwen = 15
154
  OpenAIVision = 16
155
  ERNIE = 17
156
+ DALLE3 = 18
157
 
158
  @classmethod
159
  def get_type(cls, model_name: str):
 
196
  model_type = ModelType.Qwen
197
  elif "ernie" in model_name_lower:
198
  model_type = ModelType.ERNIE
199
+ elif "dall" in model_name_lower:
200
+ model_type = ModelType.DALLE3
201
  else:
202
  model_type = ModelType.LLaMA
203
  return model_type
modules/models/models.py CHANGED
@@ -129,6 +129,10 @@ def get_model(
129
  elif model_type == ModelType.ERNIE:
130
  from .ERNIE import ERNIE_Client
131
  model = ERNIE_Client(model_name, api_key=os.getenv("ERNIE_APIKEY"),secret_key=os.getenv("ERNIE_SECRETKEY"))
 
 
 
 
132
  elif model_type == ModelType.Unknown:
133
  raise ValueError(f"未知模型: {model_name}")
134
  logging.info(msg)
 
129
  elif model_type == ModelType.ERNIE:
130
  from .ERNIE import ERNIE_Client
131
  model = ERNIE_Client(model_name, api_key=os.getenv("ERNIE_APIKEY"),secret_key=os.getenv("ERNIE_SECRETKEY"))
132
+ elif model_type == ModelType.DALLE3:
133
+ from .DALLE3 import OpenAI_DALLE3_Client
134
+ access_key = os.environ.get("OPENAI_API_KEY", access_key)
135
+ model = OpenAI_DALLE3_Client(model_name, api_key=access_key, user_name=user_name)
136
  elif model_type == ModelType.Unknown:
137
  raise ValueError(f"未知模型: {model_name}")
138
  logging.info(msg)
modules/presets.py CHANGED
@@ -62,6 +62,7 @@ ONLINE_MODELS = [
62
  "GPT4 Vision",
63
  "川虎助理",
64
  "川虎助理 Pro",
 
65
  "GooglePaLM",
66
  "xmchat",
67
  "Azure OpenAI",
 
62
  "GPT4 Vision",
63
  "川虎助理",
64
  "川虎助理 Pro",
65
+ "DALL-E 3",
66
  "GooglePaLM",
67
  "xmchat",
68
  "Azure OpenAI",