Spaces:
Sleeping
Sleeping
update
Browse files- app - コピー.py +669 -0
- app.py +8 -9
- gradio_chat_image.py → gradio_chat_image - コピー.py +5 -4
app - コピー.py
ADDED
@@ -0,0 +1,669 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import json
|
3 |
+
import base64
|
4 |
+
import re
|
5 |
+
import time
|
6 |
+
import datetime
|
7 |
+
from zoneinfo import ZoneInfo
|
8 |
+
from PIL import Image
|
9 |
+
from io import BytesIO
|
10 |
+
import shutil
|
11 |
+
import gradio as gr
|
12 |
+
from openai import (
|
13 |
+
OpenAI, AuthenticationError, NotFoundError, BadRequestError
|
14 |
+
)
|
15 |
+
|
16 |
+
|
17 |
+
# GPT用設定
|
18 |
+
SYS_PROMPT_DEFAULT = "あなたは優秀なアシスタントです。回答は日本語でお願いします。"
|
19 |
+
DUMMY = "********************"
|
20 |
+
file_format = {".png", ".jpeg", ".jpg", ".webp", ".gif", ".PNG", ".JPEG", ".JPG", ".WEBP", ".GIF"}
|
21 |
+
|
22 |
+
# 各種出力フォルダ
|
23 |
+
IMG_FOLDER = "sample_data" #"images"
|
24 |
+
|
25 |
+
# 各種メッセージ
|
26 |
+
PLACEHOLDER = ""
|
27 |
+
# IMG_MSG = "(画像ファイルを追加しました。リセットボタンの上に表示されています。)"
|
28 |
+
ANT_MSG = "(下部の[出力ファイル]にファイルを追加しました。)"
|
29 |
+
|
30 |
+
# 各種設定値
|
31 |
+
MAX_TRIAL = int(os.environ["MAX_TRIAL"]) # メッセージ取得最大試行数
|
32 |
+
INTER_SEC = int(os.environ["INTER_SEC"]) # 試行間隔(秒)
|
33 |
+
MAX_TOKENS = int(os.environ["MAX_TOKENS"]) # Vison最大トークン
|
34 |
+
|
35 |
+
# 正規表現用パターン
|
36 |
+
pt = r".*!\[(.*)\]\((.*)\)"
|
37 |
+
|
38 |
+
# サンプル用情報
|
39 |
+
examples = ["1980s anime girl with straight bob-cut in school uniform, roughly drawn drawing"
|
40 |
+
, "a minimalisit logo for a sporting goods company"]
|
41 |
+
|
42 |
+
|
43 |
+
# 各関数定義
|
44 |
+
def set_state(openai_key, size, quality, detail, state):
|
45 |
+
""" 設定タブの情報をセッションに保存する関数 """
|
46 |
+
|
47 |
+
state["openai_key"] = openai_key
|
48 |
+
state["size"] = size
|
49 |
+
state["quality"] = quality
|
50 |
+
state["detail"] = detail
|
51 |
+
|
52 |
+
return state
|
53 |
+
|
54 |
+
|
55 |
+
def init(state, text, image):
|
56 |
+
""" 入力チェックを行う関数 """
|
57 |
+
""" ※ここで例外を起こすと入力できなくなるので次の関数でエラーにする """
|
58 |
+
|
59 |
+
err_msg = ""
|
60 |
+
|
61 |
+
print(state)
|
62 |
+
|
63 |
+
# if state["openai_key"] == "" or state["openai_key"] is None:
|
64 |
+
|
65 |
+
# # OpenAI API Key未入力
|
66 |
+
# err_msg = "OpenAI API Keyを入力してください。(設定タブ)"
|
67 |
+
|
68 |
+
if not text:
|
69 |
+
|
70 |
+
# テキスト未入力
|
71 |
+
err_msg = "テキストを入力して下さい。"
|
72 |
+
|
73 |
+
return state, err_msg
|
74 |
+
|
75 |
+
elif image:
|
76 |
+
|
77 |
+
# 入力画像のファイル形式チェック
|
78 |
+
root, ext = os.path.splitext(image)
|
79 |
+
|
80 |
+
print(ext, file_format)
|
81 |
+
|
82 |
+
if ext not in file_format:
|
83 |
+
|
84 |
+
# ファイル形式チェック
|
85 |
+
err_msg = "指定した形式のファイルをアップしてください。(注意事項タブに記載)"
|
86 |
+
|
87 |
+
return state, err_msg
|
88 |
+
|
89 |
+
try:
|
90 |
+
|
91 |
+
if state["client"] is None:
|
92 |
+
|
93 |
+
# 初回起動時は初期処理をする
|
94 |
+
os.environ["OPENAI_API_KEY"] = os.environ["TEST_OPENAI_KEY"] # テスト時
|
95 |
+
# os.environ["OPENAI_API_KEY"] = state["openai_key"]
|
96 |
+
|
97 |
+
# クライアント新規作成
|
98 |
+
client = OpenAI()
|
99 |
+
|
100 |
+
# client作成後は消す
|
101 |
+
os.environ["OPENAI_API_KEY"] = ""
|
102 |
+
|
103 |
+
# セッションにセット
|
104 |
+
state["client"] = client
|
105 |
+
|
106 |
+
else:
|
107 |
+
|
108 |
+
# 既存のクライアントをセット
|
109 |
+
client = state["client"]
|
110 |
+
|
111 |
+
|
112 |
+
if state["thread_id"] == "":
|
113 |
+
|
114 |
+
# スレッド作成
|
115 |
+
thread = client.beta.threads.create()
|
116 |
+
|
117 |
+
state["thread_id"] = thread.id
|
118 |
+
|
119 |
+
|
120 |
+
if state["assistant_id"] == "":
|
121 |
+
|
122 |
+
# アシスタント作成
|
123 |
+
# assistant = client.beta.assistants.create(
|
124 |
+
# name="codeinter_test",
|
125 |
+
# instructions=state["system_prompt"],
|
126 |
+
# # model="gpt-4-1106-preview",
|
127 |
+
# model="gpt-3.5-turbo-1106",
|
128 |
+
# tools=[{"type": "code_interpreter"}]
|
129 |
+
# )
|
130 |
+
# state["assistant_id"] = assistant.id
|
131 |
+
|
132 |
+
state["assistant_id"] = os.environ["ASSIST_ID"] # テスト中アシスタントは固定
|
133 |
+
|
134 |
+
else:
|
135 |
+
|
136 |
+
# アシスタント確認(IDが存在しないならエラーとなる)
|
137 |
+
assistant = client.beta.assistants.retrieve(state["assistant_id"])
|
138 |
+
|
139 |
+
except NotFoundError as e:
|
140 |
+
err_msg = "アシスタントIDが間違っています。新しく作成する場合はアシスタントIDを空欄にして下さい。"
|
141 |
+
except AuthenticationError as e:
|
142 |
+
err_msg = "認証エラーとなりました。OpenAPIKeyが正しいか、支払い方法などが設定されているか確認して下さい。"
|
143 |
+
except Exception as e:
|
144 |
+
err_msg = "その他のエラーが発生しました。"
|
145 |
+
print(e)
|
146 |
+
finally:
|
147 |
+
return state, err_msg
|
148 |
+
|
149 |
+
|
150 |
+
def raise_exception(err_msg):
|
151 |
+
""" エラーの場合例外を起こす関数 """
|
152 |
+
|
153 |
+
if err_msg != "":
|
154 |
+
raise Exception("これは入力チェックでの例外です。")
|
155 |
+
|
156 |
+
return
|
157 |
+
|
158 |
+
|
159 |
+
def add_history(history, text, image):
|
160 |
+
""" Chat履歴"history"に追加を行う関数 """
|
161 |
+
|
162 |
+
err_msg = ""
|
163 |
+
|
164 |
+
if image is None or image == "":
|
165 |
+
|
166 |
+
# テキストだけの場合そのまま追加
|
167 |
+
history = history + [(text, None)]
|
168 |
+
|
169 |
+
elif image is not None:
|
170 |
+
|
171 |
+
# 画像があれば画像とテキストを追加
|
172 |
+
history = history + [((image,), DUMMY)]
|
173 |
+
history = history + [(text, None)]
|
174 |
+
|
175 |
+
# テキストは利用不可・初期化し、画像は利用不可に
|
176 |
+
update_text = gr.update(value="", placeholder = "",interactive=False)
|
177 |
+
update_file = gr.update(interactive=False)
|
178 |
+
|
179 |
+
return history, update_text, update_file, err_msg
|
180 |
+
|
181 |
+
|
182 |
+
def bot(state, history, image_path):
|
183 |
+
|
184 |
+
err_msg = ""
|
185 |
+
out_image_path = None
|
186 |
+
image_preview = False
|
187 |
+
|
188 |
+
ant_file = None
|
189 |
+
|
190 |
+
# セッション情報取得
|
191 |
+
client = state["client"]
|
192 |
+
assistant_id = state["assistant_id"]
|
193 |
+
thread_id = state["thread_id"]
|
194 |
+
last_msg_id = state["last_msg_id"]
|
195 |
+
history_outputs = state["tool_outputs"]
|
196 |
+
|
197 |
+
# メッセージ設定
|
198 |
+
message = client.beta.threads.messages.create(
|
199 |
+
thread_id=thread_id,
|
200 |
+
role="user",
|
201 |
+
content=history[-1][0],
|
202 |
+
)
|
203 |
+
|
204 |
+
# RUNスタート
|
205 |
+
run = client.beta.threads.runs.create(
|
206 |
+
thread_id=thread_id,
|
207 |
+
assistant_id=assistant_id,
|
208 |
+
# instructions=system_prompt
|
209 |
+
)
|
210 |
+
|
211 |
+
# "completed"となるまで繰り返す(指定秒おき)
|
212 |
+
for i in range(0, MAX_TRIAL, 1):
|
213 |
+
|
214 |
+
if i > 0:
|
215 |
+
time.sleep(INTER_SEC)
|
216 |
+
|
217 |
+
# 変数初期化
|
218 |
+
tool_outputs = []
|
219 |
+
|
220 |
+
# メッセージ受け取り
|
221 |
+
run = client.beta.threads.runs.retrieve(
|
222 |
+
thread_id=thread_id,
|
223 |
+
run_id=run.id
|
224 |
+
)
|
225 |
+
|
226 |
+
print(run.status)
|
227 |
+
|
228 |
+
if run.status == "requires_action": # 関数の結果の待ちの場合
|
229 |
+
|
230 |
+
print(run.required_action)
|
231 |
+
|
232 |
+
# tool_callsの各項目取得
|
233 |
+
tool_calls = run.required_action.submit_tool_outputs.tool_calls
|
234 |
+
|
235 |
+
print(len(tool_calls))
|
236 |
+
print(tool_calls)
|
237 |
+
|
238 |
+
# 一つ目だけ取得
|
239 |
+
tool_id = tool_calls[0].id
|
240 |
+
func_name = tool_calls[0].function.name
|
241 |
+
func_args = json.loads(tool_calls[0].function.arguments)
|
242 |
+
|
243 |
+
if func_name == "request_DallE3":
|
244 |
+
|
245 |
+
# ファイル名は現在時刻に
|
246 |
+
dt = datetime.datetime.now(ZoneInfo("Asia/Tokyo"))
|
247 |
+
image_name = dt.strftime("%Y%m%d%H%M%S") + ".png"
|
248 |
+
|
249 |
+
# ファイルパスは手動設定(誤りがないように)
|
250 |
+
out_image_path = IMG_FOLDER + "/" + image_name
|
251 |
+
|
252 |
+
# dall-e3のとき"image_path"は出力ファイルパス
|
253 |
+
func_args["image_path"] = out_image_path
|
254 |
+
|
255 |
+
elif func_name == "request_Vision":
|
256 |
+
|
257 |
+
if image_path is None:
|
258 |
+
|
259 |
+
# 画像がない場合エラーとなるようにする
|
260 |
+
func_args["image_path"] = ""
|
261 |
+
|
262 |
+
else:
|
263 |
+
|
264 |
+
# ファイルパスは手動設定
|
265 |
+
func_args["image_path"] = image_path
|
266 |
+
|
267 |
+
else:
|
268 |
+
|
269 |
+
# 関数名がないなら次へ
|
270 |
+
continue
|
271 |
+
|
272 |
+
# 関数を実行
|
273 |
+
func_output = func_action(state, func_name, func_args)
|
274 |
+
|
275 |
+
print(func_output)
|
276 |
+
|
277 |
+
# tool_outputリストに追加
|
278 |
+
tool_outputs.append({"tool_call_id": tool_id, "output": func_output})
|
279 |
+
|
280 |
+
# 複数の関数が必要な場合
|
281 |
+
if len(tool_calls) > 1:
|
282 |
+
|
283 |
+
# for i in range(len(tool_calls) - 1):
|
284 |
+
for i, tool_call in enumerate(tool_calls):
|
285 |
+
|
286 |
+
if i > 0:
|
287 |
+
# print(history_outputs[-(i+1)])
|
288 |
+
# # 最新のものからセット
|
289 |
+
# tool_outputs.append({"tool_call_id": history_outputs[-(i+1)]["tool_call_id"], "output": history_outputs[-(i+1)]["output"]})
|
290 |
+
|
291 |
+
# ダミー をセットする
|
292 |
+
tool_outputs.append({"tool_call_id": tool_calls.id, "output": {"answer" : ""}})
|
293 |
+
|
294 |
+
print(tool_outputs)
|
295 |
+
|
296 |
+
# 関数の出力を提出
|
297 |
+
run = client.beta.threads.runs.submit_tool_outputs(
|
298 |
+
thread_id=thread_id,
|
299 |
+
run_id=run.id,
|
300 |
+
tool_outputs=tool_outputs
|
301 |
+
)
|
302 |
+
|
303 |
+
if func_name == "request_DallE3":
|
304 |
+
|
305 |
+
# 画像の表示をする
|
306 |
+
image_preview = True
|
307 |
+
|
308 |
+
# セッション更新
|
309 |
+
history_outputs += tool_outputs
|
310 |
+
state["tool_outputs"] = history_outputs
|
311 |
+
|
312 |
+
else:
|
313 |
+
|
314 |
+
# 前回のメッセージより後を昇順で取り出す
|
315 |
+
messages = client.beta.threads.messages.list(
|
316 |
+
thread_id=thread_id,
|
317 |
+
after=last_msg_id,
|
318 |
+
order="asc"
|
319 |
+
)
|
320 |
+
|
321 |
+
print(messages)
|
322 |
+
|
323 |
+
# messageを取り出す
|
324 |
+
for msg in messages:
|
325 |
+
|
326 |
+
if msg.role == "assistant":
|
327 |
+
|
328 |
+
for content in msg.content:
|
329 |
+
|
330 |
+
res_text = ""
|
331 |
+
file_id = ""
|
332 |
+
|
333 |
+
cont_dict = content.model_dump() # 辞書型に変換
|
334 |
+
|
335 |
+
# 返答テキスト取得
|
336 |
+
res_text = cont_dict["text"].get("value")
|
337 |
+
|
338 |
+
if res_text != "":
|
339 |
+
|
340 |
+
# テキストを変換("sandbox:"などを」消す)
|
341 |
+
result = re.search(pt, res_text)
|
342 |
+
|
343 |
+
if result:
|
344 |
+
|
345 |
+
# パターン一致の場合はプロンプトだけ抜き出す
|
346 |
+
res_text = result.group(1)
|
347 |
+
|
348 |
+
# Chat画面更新
|
349 |
+
if history[-1][1] is not None:
|
350 |
+
|
351 |
+
# 新しい行を追加
|
352 |
+
history = history + [[None, res_text]]
|
353 |
+
|
354 |
+
else:
|
355 |
+
|
356 |
+
history[-1][1] = res_text
|
357 |
+
|
358 |
+
if image_preview:
|
359 |
+
|
360 |
+
print(out_image_path)
|
361 |
+
|
362 |
+
# Functionで画像を取得していた場合表示
|
363 |
+
history = history + [(None, (out_image_path,))]
|
364 |
+
|
365 |
+
image_preview = False
|
366 |
+
|
367 |
+
# 最終メッセージID更新
|
368 |
+
last_msg_id = msg.id
|
369 |
+
|
370 |
+
# Chatbotを返す(labelとhistoryを更新)
|
371 |
+
yield gr.Chatbot(label=run.status ,value=history), out_image_path, ant_file, err_msg
|
372 |
+
|
373 |
+
# セッションのメッセージID更新
|
374 |
+
state["last_msg_id"] = last_msg_id
|
375 |
+
|
376 |
+
# 完了なら終了
|
377 |
+
if run.status == "completed":
|
378 |
+
|
379 |
+
yield gr.Chatbot(label=run.status ,value=history), out_image_path, ant_file, err_msg
|
380 |
+
break
|
381 |
+
|
382 |
+
|
383 |
+
elif run.status == "failed":
|
384 |
+
|
385 |
+
# エラーとして終了
|
386 |
+
err_msg = "※メッセージ取得に失敗しました。"
|
387 |
+
yield gr.Chatbot(label=run.status ,value=history), out_image_path, ant_file, err_msg
|
388 |
+
break
|
389 |
+
|
390 |
+
elif i == MAX_TRIAL:
|
391 |
+
|
392 |
+
# エラーとして終了
|
393 |
+
err_msg = "※メッセージ取得の際にタイムアウトしました。"
|
394 |
+
yield gr.Chatbot(label=run.status ,value=history), out_image_path, ant_file, err_msg
|
395 |
+
break
|
396 |
+
|
397 |
+
else:
|
398 |
+
if i > 3:
|
399 |
+
|
400 |
+
# 作業中とわかるようにする
|
401 |
+
yield gr.Chatbot(label=run.status + " (Request:" + str(i) + ")" ,value=history), out_image_path, ant_file, err_msg
|
402 |
+
|
403 |
+
|
404 |
+
def func_action(state, func_name, func_args):
|
405 |
+
|
406 |
+
# セッションから情報取得
|
407 |
+
client = state["client"]
|
408 |
+
size = state["size"]
|
409 |
+
quality = state["quality"]
|
410 |
+
detail = state["detail"]
|
411 |
+
|
412 |
+
print("name:", func_name)
|
413 |
+
print("arguments:", func_args)
|
414 |
+
|
415 |
+
if func_name == "request_DallE3":
|
416 |
+
|
417 |
+
func_output = request_DallE3(
|
418 |
+
client,
|
419 |
+
func_args["prompt"],
|
420 |
+
size,
|
421 |
+
quality,
|
422 |
+
func_args["image_path"] # 出力パス
|
423 |
+
)
|
424 |
+
|
425 |
+
elif func_name == "request_Vision":
|
426 |
+
|
427 |
+
func_output = request_Vision(
|
428 |
+
client,
|
429 |
+
func_args["prompt"],
|
430 |
+
func_args["image_path"],
|
431 |
+
detail,
|
432 |
+
MAX_TOKENS
|
433 |
+
)
|
434 |
+
|
435 |
+
return func_output
|
436 |
+
|
437 |
+
def finally_proc():
|
438 |
+
""" 最終処理用関数 """
|
439 |
+
|
440 |
+
# テキストを使えるように
|
441 |
+
new_text = gr.update(interactive = True)
|
442 |
+
|
443 |
+
# 画像はリセット
|
444 |
+
new_image = gr.update(value=None, interactive = True)
|
445 |
+
|
446 |
+
return new_text, new_image
|
447 |
+
|
448 |
+
|
449 |
+
def clear_click(state):
|
450 |
+
""" クリアボタンクリック時 """
|
451 |
+
|
452 |
+
# セッションの一部をリセット
|
453 |
+
state["thread_id"] = ""
|
454 |
+
state["last_msg_id"] = ""
|
455 |
+
state["tool_outputs"] = []
|
456 |
+
|
457 |
+
return state
|
458 |
+
|
459 |
+
def make_archive():
|
460 |
+
|
461 |
+
shutil.make_archive("output_image", format='zip', root_dir=IMG_FOLDER)
|
462 |
+
|
463 |
+
return "output_image.zip", "下部の出力ファイルからダウンロードして下さい。"
|
464 |
+
|
465 |
+
def encode_image(image_path):
|
466 |
+
with open(image_path, "rb") as image_file:
|
467 |
+
return base64.b64encode(image_file.read()).decode("utf-8")
|
468 |
+
|
469 |
+
def make_prompt(prompt):
|
470 |
+
|
471 |
+
return "次のプロンプトで画像を作ってください「" + prompt + "」。"
|
472 |
+
|
473 |
+
# 画面構成
|
474 |
+
with gr.Blocks() as demo:
|
475 |
+
|
476 |
+
title = "<h2>GPT画像入出力対応チャット</h2>"
|
477 |
+
message = "<h3>・DallE3の画像生成とGPT-4 with Visionの画像解析が利用できます。<br>"
|
478 |
+
message += "・DallE3を利用する場合はプロンプト、GPT-4 Visionを利用する場合は画像とプロンプトを入力して下さい。<br>"
|
479 |
+
# message += "・テスト中でAPIKEY無しで動きます。<br>"
|
480 |
+
# message += "※現在画像から画像を作るimg2imgはできません。<br>"
|
481 |
+
# message += "・動画での紹介はこちら→https://www.youtube.com/watch?v=<br></h3>"
|
482 |
+
message += "</h3>"
|
483 |
+
|
484 |
+
gr.Markdown(title + message)
|
485 |
+
|
486 |
+
# セッションの宣言
|
487 |
+
state = gr.State({
|
488 |
+
# "system_prompt": SYS_PROMPT_DEFAULT,
|
489 |
+
"openai_key" : "",
|
490 |
+
"size" : "1024x1024",
|
491 |
+
"quality" : "standard",
|
492 |
+
"detail" : "low",
|
493 |
+
"client" : None,
|
494 |
+
"assistant_id" : "",
|
495 |
+
"thread_id" : "",
|
496 |
+
"last_msg_id" : "",
|
497 |
+
"tool_outputs" : []
|
498 |
+
})
|
499 |
+
|
500 |
+
with gr.Tab("Chat画面") as chat:
|
501 |
+
|
502 |
+
# 各コンポーネント定義
|
503 |
+
chatbot = gr.Chatbot(label="チャット画面")
|
504 |
+
text_msg = gr.Textbox(label="プロンプト", placeholder = PLACEHOLDER)
|
505 |
+
text_dummy = gr.Textbox(visible=False)
|
506 |
+
gr.Examples(label="サンプルプロンプト", examples=examples, inputs=text_dummy, outputs=text_msg, fn=make_prompt, cache_examples=True)
|
507 |
+
|
508 |
+
|
509 |
+
with gr.Row():
|
510 |
+
btn = gr.Button(value="送信")
|
511 |
+
btn_dl = gr.Button(value="画像の一括ダウンロード") # 保留中
|
512 |
+
btn_clear = gr.ClearButton(value="リセット", components=[chatbot, text_msg, image ,out_image])
|
513 |
+
|
514 |
+
with gr.Row():
|
515 |
+
image = gr.Image(label="ファイルアップロード", type="filepath",interactive = True)
|
516 |
+
out_image = gr.Image(label="出力画像", type="filepath", interactive = False)
|
517 |
+
|
518 |
+
sys_msg = gr.Textbox(label="システムメッセージ", interactive = False)
|
519 |
+
# out_text = gr.Textbox(label="出力テキスト", lines = 5, interactive = False)
|
520 |
+
out_file = gr.File(label="出力ファイル", type="filepath",interactive = False)
|
521 |
+
|
522 |
+
# 送信ボタンクリック時の処理
|
523 |
+
bc = btn.click(init, [state, text_msg, image], [state, sys_msg], queue=False).success(
|
524 |
+
raise_exception, sys_msg, None).success(
|
525 |
+
add_history, [chatbot, text_msg, image], [chatbot, text_msg, image, sys_msg], queue=False).success(
|
526 |
+
bot, [state, chatbot, image],[chatbot, out_image, out_file, sys_msg]).then(
|
527 |
+
finally_proc, None, [text_msg, image], queue=False
|
528 |
+
)
|
529 |
+
btn_dl.click(make_archive, None, [out_file, sys_msg])
|
530 |
+
# クリア時でもセッションの設定(OpenAIKeyなどは残す)
|
531 |
+
btn_clear.click(clear_click, state, state)
|
532 |
+
|
533 |
+
# テキスト入力Enter時の処理
|
534 |
+
# txt_msg = text_msg.submit(respond, inputs=[text_msg, image, chatbot], outputs=[text_msg, image, chatbot])
|
535 |
+
|
536 |
+
with gr.Tab("設定") as set:
|
537 |
+
|
538 |
+
gr.Markdown("<h4>OpenAI設定</h4>")
|
539 |
+
with gr.Row():
|
540 |
+
openai_key = gr.Textbox(label="OpenAI API Key", visible=True) # テスト中は表示せず
|
541 |
+
# system_prompt = gr.Textbox(value = SYS_PROMPT_DEFAULT,lines = 5, label="Custom instructions", interactive = True)
|
542 |
+
gr.Markdown("<h4>DaLL-E3用設定</h4>")
|
543 |
+
with gr.Row():
|
544 |
+
size = gr.Dropdown(label="サイズ", choices=["1024x1024"], value = "1024x1024", interactive = True)
|
545 |
+
quality = gr.Dropdown(label="クオリティ", choices=["standard"], value = "standard", interactive = True)
|
546 |
+
gr.Markdown("<h4>Vison用設定</h4>")
|
547 |
+
with gr.Row():
|
548 |
+
detail = gr.Dropdown(label="コード出力", choices=["low", "high" , "auto"], value = "low", interactive = True)
|
549 |
+
|
550 |
+
# 設定タブからChatタブに戻った時の処理
|
551 |
+
chat.select(set_state, [openai_key, size, quality, detail, state], state)
|
552 |
+
|
553 |
+
with gr.Tab("注意事項") as notes:
|
554 |
+
caution = "現在Assistant APIはβ版でのリリースとなっています。<br>"
|
555 |
+
caution += "そのためか一部のファイルのアップロードが上手くいかないため、制限をかけています。<br>"
|
556 |
+
caution += "(現在アップできるファイル形式は.txtと.csvのみ)<br>"
|
557 |
+
caution += "本来はPDFなども利用できるはずなので、今後更新したいと思います。また日本語文字化けも調査中です。"
|
558 |
+
|
559 |
+
gr.Markdown("<h3>" + caution + "</h3>")
|
560 |
+
|
561 |
+
|
562 |
+
demo.queue()
|
563 |
+
demo.launch(debug=True)
|
564 |
+
|
565 |
+
|
566 |
+
def request_DallE3(client, prompt, size, quality, out_image_path):
|
567 |
+
|
568 |
+
err_msg = ""
|
569 |
+
|
570 |
+
try:
|
571 |
+
|
572 |
+
response = client.images.generate(
|
573 |
+
model="dall-e-3",
|
574 |
+
prompt=prompt,
|
575 |
+
size=size,
|
576 |
+
quality=quality,
|
577 |
+
n=1,
|
578 |
+
response_format="b64_json"
|
579 |
+
)
|
580 |
+
|
581 |
+
print(response.data[0])
|
582 |
+
|
583 |
+
# データを受け取りデコード
|
584 |
+
image_data_json = response.data[0].b64_json
|
585 |
+
image_data = base64.b64decode(image_data_json)
|
586 |
+
|
587 |
+
# 画像として扱えるように保存
|
588 |
+
image_stream = BytesIO(image_data)
|
589 |
+
image = Image.open(image_stream)
|
590 |
+
image.save(out_image_path)
|
591 |
+
|
592 |
+
except BadRequestError as e:
|
593 |
+
print(e)
|
594 |
+
out_image_path = ""
|
595 |
+
err_msg = "リクエストエラーです。著作権侵害などプロンプトを確認して下さい。"
|
596 |
+
except Exception as e:
|
597 |
+
print(e)
|
598 |
+
out_image_path = ""
|
599 |
+
err_msg = "その他のエラーが発生しました。"
|
600 |
+
|
601 |
+
finally:
|
602 |
+
|
603 |
+
# 結果をJSONで返す
|
604 |
+
dalle3_result = {
|
605 |
+
"image_path" : out_image_path,
|
606 |
+
"error_message" : err_msg
|
607 |
+
}
|
608 |
+
return json.dumps(dalle3_result)
|
609 |
+
|
610 |
+
|
611 |
+
def request_Vision(client, prompt, image_path, detail, max_tokens):
|
612 |
+
|
613 |
+
response_text = ""
|
614 |
+
err_msg = ""
|
615 |
+
|
616 |
+
if image_path == "":
|
617 |
+
|
618 |
+
# 画像がない時はエラーとして返す
|
619 |
+
vision_result = {"answer" : "", "error_message" : "画像をセットして下さい。"}
|
620 |
+
return json.dumps(vision_result)
|
621 |
+
|
622 |
+
try:
|
623 |
+
|
624 |
+
# 画像をbase64に変換
|
625 |
+
image = encode_image(image_path)
|
626 |
+
|
627 |
+
# メッセージの作成
|
628 |
+
messages = [
|
629 |
+
{
|
630 |
+
"role": "user",
|
631 |
+
"content": [
|
632 |
+
{"type": "text", "text": prompt},
|
633 |
+
{
|
634 |
+
"type": "image_url",
|
635 |
+
"image_url": {
|
636 |
+
"url": f"data:image/jpeg;base64,{image}",
|
637 |
+
"detail": detail,
|
638 |
+
}
|
639 |
+
},
|
640 |
+
],
|
641 |
+
}
|
642 |
+
]
|
643 |
+
|
644 |
+
# gpt-4-visionに問い合わせて回答を表示
|
645 |
+
response = client.chat.completions.create(
|
646 |
+
model="gpt-4-vision-preview", # Visionはこのモデル指定
|
647 |
+
messages=messages,
|
648 |
+
max_tokens=max_tokens,
|
649 |
+
)
|
650 |
+
|
651 |
+
response_text = response.choices[0].message.content
|
652 |
+
|
653 |
+
print(response_text)
|
654 |
+
|
655 |
+
except BadRequestError as e:
|
656 |
+
print(e)
|
657 |
+
err_msg = "リクエストエラーです。画像がポリシー違反でないか確認して下さい。"
|
658 |
+
except Exception as e:
|
659 |
+
print(e)
|
660 |
+
err_msg = "その他のエラーが発生しました。"
|
661 |
+
|
662 |
+
finally:
|
663 |
+
|
664 |
+
# 結果をJSONで返す
|
665 |
+
vision_result = {
|
666 |
+
"answer" : response_text,
|
667 |
+
"error_message" : err_msg
|
668 |
+
}
|
669 |
+
return json.dumps(vision_result)
|
app.py
CHANGED
@@ -454,7 +454,8 @@ def clear_click(state):
|
|
454 |
state["last_msg_id"] = ""
|
455 |
state["tool_outputs"] = []
|
456 |
|
457 |
-
|
|
|
458 |
|
459 |
def make_archive():
|
460 |
|
@@ -505,16 +506,14 @@ with gr.Blocks() as demo:
|
|
505 |
text_dummy = gr.Textbox(visible=False)
|
506 |
gr.Examples(label="サンプルプロンプト", examples=examples, inputs=text_dummy, outputs=text_msg, fn=make_prompt, cache_examples=True)
|
507 |
|
508 |
-
|
|
|
|
|
509 |
with gr.Row():
|
510 |
btn = gr.Button(value="送信")
|
511 |
btn_dl = gr.Button(value="画像の一括ダウンロード") # 保留中
|
512 |
-
btn_clear = gr.ClearButton(value="リセット", components=[chatbot, text_msg
|
513 |
|
514 |
-
with gr.Row():
|
515 |
-
image = gr.Image(label="ファイルアップロード", type="filepath",interactive = True)
|
516 |
-
out_image = gr.Image(label="出力画像", type="filepath", interactive = False)
|
517 |
-
|
518 |
sys_msg = gr.Textbox(label="システムメッセージ", interactive = False)
|
519 |
# out_text = gr.Textbox(label="出力テキスト", lines = 5, interactive = False)
|
520 |
out_file = gr.File(label="出力ファイル", type="filepath",interactive = False)
|
@@ -527,8 +526,8 @@ with gr.Blocks() as demo:
|
|
527 |
finally_proc, None, [text_msg, image], queue=False
|
528 |
)
|
529 |
btn_dl.click(make_archive, None, [out_file, sys_msg])
|
530 |
-
#
|
531 |
-
btn_clear.click(clear_click, state, state)
|
532 |
|
533 |
# テキスト入力Enter時の処理
|
534 |
# txt_msg = text_msg.submit(respond, inputs=[text_msg, image, chatbot], outputs=[text_msg, image, chatbot])
|
|
|
454 |
state["last_msg_id"] = ""
|
455 |
state["tool_outputs"] = []
|
456 |
|
457 |
+
# 順番的にgr.ClearButtonで消せないImageなども初期化
|
458 |
+
return state, None, None
|
459 |
|
460 |
def make_archive():
|
461 |
|
|
|
506 |
text_dummy = gr.Textbox(visible=False)
|
507 |
gr.Examples(label="サンプルプロンプト", examples=examples, inputs=text_dummy, outputs=text_msg, fn=make_prompt, cache_examples=True)
|
508 |
|
509 |
+
with gr.Row():
|
510 |
+
image = gr.Image(label="ファイルアップロード", type="filepath",interactive = True)
|
511 |
+
out_image = gr.Image(label="出力画像", type="filepath", interactive = False)
|
512 |
with gr.Row():
|
513 |
btn = gr.Button(value="送信")
|
514 |
btn_dl = gr.Button(value="画像の一括ダウンロード") # 保留中
|
515 |
+
btn_clear = gr.ClearButton(value="リセット", components=[chatbot, text_msg])
|
516 |
|
|
|
|
|
|
|
|
|
517 |
sys_msg = gr.Textbox(label="システムメッセージ", interactive = False)
|
518 |
# out_text = gr.Textbox(label="出力テキスト", lines = 5, interactive = False)
|
519 |
out_file = gr.File(label="出力ファイル", type="filepath",interactive = False)
|
|
|
526 |
finally_proc, None, [text_msg, image], queue=False
|
527 |
)
|
528 |
btn_dl.click(make_archive, None, [out_file, sys_msg])
|
529 |
+
# クリア時でもセッションの一部は残す(OpenAIKeyなど)
|
530 |
+
btn_clear.click(clear_click, state, [state, image ,out_image])
|
531 |
|
532 |
# テキスト入力Enter時の処理
|
533 |
# txt_msg = text_msg.submit(respond, inputs=[text_msg, image, chatbot], outputs=[text_msg, image, chatbot])
|
gradio_chat_image.py → gradio_chat_image - コピー.py
RENAMED
@@ -474,7 +474,8 @@ def clear_click(state):
|
|
474 |
state["last_msg_id"] = ""
|
475 |
state["tool_outputs"] = []
|
476 |
|
477 |
-
|
|
|
478 |
|
479 |
def make_archive():
|
480 |
|
@@ -531,7 +532,7 @@ with gr.Blocks() as demo:
|
|
531 |
with gr.Row():
|
532 |
btn = gr.Button(value="送信")
|
533 |
btn_dl = gr.Button(value="画像の一括ダウンロード") # 保留中
|
534 |
-
btn_clear = gr.ClearButton(value="リセット", components=[chatbot, text_msg
|
535 |
|
536 |
sys_msg = gr.Textbox(label="システムメッセージ", interactive = False)
|
537 |
# out_text = gr.Textbox(label="出力テキスト", lines = 5, interactive = False)
|
@@ -545,8 +546,8 @@ with gr.Blocks() as demo:
|
|
545 |
finally_proc, None, [text_msg, image], queue=False
|
546 |
)
|
547 |
btn_dl.click(make_archive, None, [out_file, sys_msg])
|
548 |
-
#
|
549 |
-
btn_clear.click(clear_click, state, state)
|
550 |
|
551 |
# テキスト入力Enter時の処理
|
552 |
# txt_msg = text_msg.submit(respond, inputs=[text_msg, image, chatbot], outputs=[text_msg, image, chatbot])
|
|
|
474 |
state["last_msg_id"] = ""
|
475 |
state["tool_outputs"] = []
|
476 |
|
477 |
+
# 順番的にgr.ClearButtonで消せないImageなども初期化
|
478 |
+
return state, None, None
|
479 |
|
480 |
def make_archive():
|
481 |
|
|
|
532 |
with gr.Row():
|
533 |
btn = gr.Button(value="送信")
|
534 |
btn_dl = gr.Button(value="画像の一括ダウンロード") # 保留中
|
535 |
+
btn_clear = gr.ClearButton(value="リセット", components=[chatbot, text_msg])
|
536 |
|
537 |
sys_msg = gr.Textbox(label="システムメッセージ", interactive = False)
|
538 |
# out_text = gr.Textbox(label="出力テキスト", lines = 5, interactive = False)
|
|
|
546 |
finally_proc, None, [text_msg, image], queue=False
|
547 |
)
|
548 |
btn_dl.click(make_archive, None, [out_file, sys_msg])
|
549 |
+
# クリア時でもセッションの一部は残す(OpenAIKeyなど)
|
550 |
+
btn_clear.click(clear_click, state, [state, image ,out_image])
|
551 |
|
552 |
# テキスト入力Enter時の処理
|
553 |
# txt_msg = text_msg.submit(respond, inputs=[text_msg, image, chatbot], outputs=[text_msg, image, chatbot])
|