nekoniii3 commited on
Commit
2803f09
·
1 Parent(s): fb44093

new create

Browse files
Files changed (4) hide show
  1. app.py +669 -0
  2. gradio_chat_image.py +712 -0
  3. requirements.txt +2 -0
  4. sample_data/dummy.txt +1 -0
app.py ADDED
@@ -0,0 +1,669 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import json
3
+ import base64
4
+ import re
5
+ import time
6
+ import datetime
7
+ from zoneinfo import ZoneInfo
8
+ from PIL import Image
9
+ from io import BytesIO
10
+ import shutil
11
+ import gradio as gr
12
+ from openai import (
13
+ OpenAI, AuthenticationError, NotFoundError, BadRequestError
14
+ )
15
+
16
+
17
+ # GPT用設定
18
+ SYS_PROMPT_DEFAULT = "あなたは優秀なアシスタントです。回答は日本語でお願いします。"
19
+ DUMMY = "********************"
20
+ file_format = {".png", ".jpeg", ".jpg", ".webp", ".gif", ".PNG", ".JPEG", ".JPG", ".WEBP", ".GIF"}
21
+
22
+ # 各種出力フォルダ
23
+ IMG_FOLDER = "sample_data" #"images"
24
+
25
+ # 各種メッセージ
26
+ PLACEHOLDER = ""
27
+ # IMG_MSG = "(画像ファイルを追加しました。リセットボタンの上に表示されています。)"
28
+ ANT_MSG = "(下部の[出力ファイル]にファイルを追加しました。)"
29
+
30
+ # 各種設定値
31
+ MAX_TRIAL = int(os.environ["MAX_TRIAL"]) # メッセージ取得最大試行数
32
+ INTER_SEC = int(os.environ["INTER_SEC"]) # 試行間隔(秒)
33
+ MAX_TOKENS = int(os.environ["MAX_TOKENS"]) # Vison最大トークン
34
+
35
+ # 正規表現用パターン
36
+ pt = r".*!\[(.*)\]\((.*)\)"
37
+
38
+ # サンプル用情報
39
+ examples = ["1980s anime girl with straight bob-cut in school uniform, roughly drawn drawing"
40
+ , "a minimalisit logo for a sporting goods company"]
41
+
42
+
43
+ # 各関数定義
44
+ def set_state(openai_key, size, quality, detail, state):
45
+ """ 設定タブの情報をセッションに保存する関数 """
46
+
47
+ state["openai_key"] = openai_key
48
+ state["size"] = size
49
+ state["quality"] = quality
50
+ state["detail"] = detail
51
+
52
+ return state
53
+
54
+
55
+ def init(state, text, image):
56
+ """ 入力チェックを行う関数 """
57
+ """ ※ここで例外を起こすと入力できなくなるので次の関数でエラーにする """
58
+
59
+ err_msg = ""
60
+
61
+ print(state)
62
+
63
+ # if state["openai_key"] == "" or state["openai_key"] is None:
64
+
65
+ # # OpenAI API Key未入力
66
+ # err_msg = "OpenAI API Keyを入力してください。(設定タブ)"
67
+
68
+ if not text:
69
+
70
+ # テキスト未入力
71
+ err_msg = "テキストを入力して下さい。"
72
+
73
+ return state, err_msg
74
+
75
+ elif image:
76
+
77
+ # 入力画像のファイル形式チェック
78
+ root, ext = os.path.splitext(image)
79
+
80
+ print(ext, file_format)
81
+
82
+ if ext not in file_format:
83
+
84
+ # ファイル形式チェック
85
+ err_msg = "指定した形式のファイルをアップしてください。(注意事項タブに記載)"
86
+
87
+ return state, err_msg
88
+
89
+ try:
90
+
91
+ if state["client"] is None:
92
+
93
+ # 初回起動時は初期処理をする
94
+ os.environ["OPENAI_API_KEY"] = os.environ["TEST_OPENAI_KEY"] # テスト時
95
+ # os.environ["OPENAI_API_KEY"] = state["openai_key"]
96
+
97
+ # クライアント新規作成
98
+ client = OpenAI()
99
+
100
+ # client作成後は消す
101
+ os.environ["OPENAI_API_KEY"] = ""
102
+
103
+ # セッションにセット
104
+ state["client"] = client
105
+
106
+ else:
107
+
108
+ # 既存のクライアントをセット
109
+ client = state["client"]
110
+
111
+
112
+ if state["thread_id"] == "":
113
+
114
+ # スレッド作成
115
+ thread = client.beta.threads.create()
116
+
117
+ state["thread_id"] = thread.id
118
+
119
+
120
+ if state["assistant_id"] == "":
121
+
122
+ # アシスタント作成
123
+ # assistant = client.beta.assistants.create(
124
+ # name="codeinter_test",
125
+ # instructions=state["system_prompt"],
126
+ # # model="gpt-4-1106-preview",
127
+ # model="gpt-3.5-turbo-1106",
128
+ # tools=[{"type": "code_interpreter"}]
129
+ # )
130
+ # state["assistant_id"] = assistant.id
131
+
132
+ state["assistant_id"] = os.environ["ASSIST_ID"] # テスト中アシスタントは固定
133
+
134
+ else:
135
+
136
+ # アシスタント確認(IDが存在しないならエラーとなる)
137
+ assistant = client.beta.assistants.retrieve(state["assistant_id"])
138
+
139
+ except NotFoundError as e:
140
+ err_msg = "アシスタントIDが間違っています。新しく作成する場合はアシスタントIDを空欄にして下さい。"
141
+ except AuthenticationError as e:
142
+ err_msg = "認証エラーとなりました。OpenAPIKeyが正しいか、支払い方法などが設定されているか確認して下さい。"
143
+ except Exception as e:
144
+ err_msg = "その他のエラーが発生しました。"
145
+ print(e)
146
+ finally:
147
+ return state, err_msg
148
+
149
+
150
+ def raise_exception(err_msg):
151
+ """ エラーの場合例外を起こす関数 """
152
+
153
+ if err_msg != "":
154
+ raise Exception("これは入力チェックでの例外です。")
155
+
156
+ return
157
+
158
+
159
+ def add_history(history, text, image):
160
+ """ Chat履歴"history"に追加を行う関数 """
161
+
162
+ err_msg = ""
163
+
164
+ if image is None or image == "":
165
+
166
+ # テキストだけの場合そのまま追加
167
+ history = history + [(text, None)]
168
+
169
+ elif image is not None:
170
+
171
+ # 画像があれば画像とテキストを追加
172
+ history = history + [((image,), DUMMY)]
173
+ history = history + [(text, None)]
174
+
175
+ # テキストは利用不可・初期化し、画像は利用不可に
176
+ update_text = gr.update(value="", placeholder = "",interactive=False)
177
+ update_file = gr.update(interactive=False)
178
+
179
+ return history, update_text, update_file, err_msg
180
+
181
+
182
+ def bot(state, history, image_path):
183
+
184
+ err_msg = ""
185
+ out_image_path = None
186
+ image_preview = False
187
+
188
+ ant_file = None
189
+
190
+ # セッション情報取得
191
+ client = state["client"]
192
+ assistant_id = state["assistant_id"]
193
+ thread_id = state["thread_id"]
194
+ last_msg_id = state["last_msg_id"]
195
+ history_outputs = state["tool_outputs"]
196
+
197
+ # メッセージ設定
198
+ message = client.beta.threads.messages.create(
199
+ thread_id=thread_id,
200
+ role="user",
201
+ content=history[-1][0],
202
+ )
203
+
204
+ # RUNスタート
205
+ run = client.beta.threads.runs.create(
206
+ thread_id=thread_id,
207
+ assistant_id=assistant_id,
208
+ # instructions=system_prompt
209
+ )
210
+
211
+ # "completed"となるまで繰り返す(指定秒おき)
212
+ for i in range(0, MAX_TRIAL, 1):
213
+
214
+ if i > 0:
215
+ time.sleep(INTER_SEC)
216
+
217
+ # 変数初期化
218
+ tool_outputs = []
219
+
220
+ # メッセージ受け取り
221
+ run = client.beta.threads.runs.retrieve(
222
+ thread_id=thread_id,
223
+ run_id=run.id
224
+ )
225
+
226
+ print(run.status)
227
+
228
+ if run.status == "requires_action": # 関数の結果の待ちの場合
229
+
230
+ print(run.required_action)
231
+
232
+ # tool_callsの各項目取得
233
+ tool_calls = run.required_action.submit_tool_outputs.tool_calls
234
+
235
+ print(len(tool_calls))
236
+ print(tool_calls)
237
+
238
+ # 一つ目だけ取得
239
+ tool_id = tool_calls[0].id
240
+ func_name = tool_calls[0].function.name
241
+ func_args = json.loads(tool_calls[0].function.arguments)
242
+
243
+ if func_name == "request_DallE3":
244
+
245
+ # ファイル名は現在時刻に
246
+ dt = datetime.datetime.now(ZoneInfo("Asia/Tokyo"))
247
+ image_name = dt.strftime("%Y%m%d%H%M%S") + ".png"
248
+
249
+ # ファイルパスは手動設定(誤りがないように)
250
+ out_image_path = IMG_FOLDER + "/" + image_name
251
+
252
+ # dall-e3のとき"image_path"は出力ファイルパス
253
+ func_args["image_path"] = out_image_path
254
+
255
+ elif func_name == "request_Vision":
256
+
257
+ if image_path is None:
258
+
259
+ # 画像がない場合エラーとなるようにする
260
+ func_args["image_path"] = ""
261
+
262
+ else:
263
+
264
+ # ファイルパスは手動設定
265
+ func_args["image_path"] = image_path
266
+
267
+ else:
268
+
269
+ # 関数名がないなら次へ
270
+ continue
271
+
272
+ # 関数を実行
273
+ func_output = func_action(state, func_name, func_args)
274
+
275
+ print(func_output)
276
+
277
+ # tool_outputリストに追加
278
+ tool_outputs.append({"tool_call_id": tool_id, "output": func_output})
279
+
280
+ # 複数の関数が必要な場合
281
+ if len(tool_calls) > 1:
282
+
283
+ # for i in range(len(tool_calls) - 1):
284
+ for i, tool_call in enumerate(tool_calls):
285
+
286
+ if i > 0:
287
+ # print(history_outputs[-(i+1)])
288
+ # # 最新のものからセット
289
+ # tool_outputs.append({"tool_call_id": history_outputs[-(i+1)]["tool_call_id"], "output": history_outputs[-(i+1)]["output"]})
290
+
291
+ # ダミー をセットする
292
+ tool_outputs.append({"tool_call_id": tool_calls.id, "output": {"answer" : ""}})
293
+
294
+ print(tool_outputs)
295
+
296
+ # 関数の出力を提出
297
+ run = client.beta.threads.runs.submit_tool_outputs(
298
+ thread_id=thread_id,
299
+ run_id=run.id,
300
+ tool_outputs=tool_outputs
301
+ )
302
+
303
+ if func_name == "request_DallE3":
304
+
305
+ # 画像の表示をする
306
+ image_preview = True
307
+
308
+ # セッション更新
309
+ history_outputs += tool_outputs
310
+ state["tool_outputs"] = history_outputs
311
+
312
+ else:
313
+
314
+ # 前回のメッセージより後を昇順で取り出す
315
+ messages = client.beta.threads.messages.list(
316
+ thread_id=thread_id,
317
+ after=last_msg_id,
318
+ order="asc"
319
+ )
320
+
321
+ print(messages)
322
+
323
+ # messageを取り出す
324
+ for msg in messages:
325
+
326
+ if msg.role == "assistant":
327
+
328
+ for content in msg.content:
329
+
330
+ res_text = ""
331
+ file_id = ""
332
+
333
+ cont_dict = content.model_dump() # 辞書型に変換
334
+
335
+ # 返答テキスト取得
336
+ res_text = cont_dict["text"].get("value")
337
+
338
+ if res_text != "":
339
+
340
+ # テキストを変換("sandbox:"などを」消す)
341
+ result = re.search(pt, res_text)
342
+
343
+ if result:
344
+
345
+ # パターン一致の場合はプロンプトだけ抜き出す
346
+ res_text = result.group(1)
347
+
348
+ # Chat画面更新
349
+ if history[-1][1] is not None:
350
+
351
+ # 新しい行を追加
352
+ history = history + [[None, res_text]]
353
+
354
+ else:
355
+
356
+ history[-1][1] = res_text
357
+
358
+ if image_preview:
359
+
360
+ print(out_image_path)
361
+
362
+ # Functionで画像を取得していた場合表示
363
+ history = history + [(None, (out_image_path,))]
364
+
365
+ image_preview = False
366
+
367
+ # 最終メッセージID更新
368
+ last_msg_id = msg.id
369
+
370
+ # Chatbotを返す(labelとhistoryを更新)
371
+ yield gr.Chatbot(label=run.status ,value=history), out_image_path, ant_file, err_msg
372
+
373
+ # セッションのメッセージID更新
374
+ state["last_msg_id"] = last_msg_id
375
+
376
+ # 完了なら終了
377
+ if run.status == "completed":
378
+
379
+ yield gr.Chatbot(label=run.status ,value=history), out_image_path, ant_file, err_msg
380
+ break
381
+
382
+
383
+ elif run.status == "failed":
384
+
385
+ # エラーとして終了
386
+ err_msg = "※メッセージ取得に失敗しました。"
387
+ yield gr.Chatbot(label=run.status ,value=history), out_image_path, ant_file, err_msg
388
+ break
389
+
390
+ elif i == MAX_TRIAL:
391
+
392
+ # エラーとして終了
393
+ err_msg = "※メッセージ取得の際にタイムアウトしました。"
394
+ yield gr.Chatbot(label=run.status ,value=history), out_image_path, ant_file, err_msg
395
+ break
396
+
397
+ else:
398
+ if i > 3:
399
+
400
+ # 作業中とわかるようにする
401
+ yield gr.Chatbot(label=run.status + " (Request:" + str(i) + ")" ,value=history), out_image_path, ant_file, err_msg
402
+
403
+
404
+ def func_action(state, func_name, func_args):
405
+
406
+ # セッションから情報取得
407
+ client = state["client"]
408
+ size = state["size"]
409
+ quality = state["quality"]
410
+ detail = state["detail"]
411
+
412
+ print("name:", func_name)
413
+ print("arguments:", func_args)
414
+
415
+ if func_name == "request_DallE3":
416
+
417
+ func_output = request_DallE3(
418
+ client,
419
+ func_args["prompt"],
420
+ size,
421
+ quality,
422
+ func_args["image_path"] # 出力パス
423
+ )
424
+
425
+ elif func_name == "request_Vision":
426
+
427
+ func_output = request_Vision(
428
+ client,
429
+ func_args["prompt"],
430
+ func_args["image_path"],
431
+ detail,
432
+ MAX_TOKENS
433
+ )
434
+
435
+ return func_output
436
+
437
+ def finally_proc():
438
+ """ 最終処理用関数 """
439
+
440
+ # テキストを使えるように
441
+ new_text = gr.update(interactive = True)
442
+
443
+ # 画像はリセット
444
+ new_image = gr.update(value=None, interactive = True)
445
+
446
+ return new_text, new_image
447
+
448
+
449
+ def clear_click(state):
450
+ """ クリアボタンクリック時 """
451
+
452
+ # セッションの一部をリセット
453
+ state["thread_id"] = ""
454
+ state["last_msg_id"] = ""
455
+ state["tool_outputs"] = []
456
+
457
+ return state
458
+
459
+ def make_archive():
460
+
461
+ shutil.make_archive("output_image", format='zip', root_dir=IMG_FOLDER)
462
+
463
+ return "output_image.zip", "下部の出力ファイルからダウンロードして下さい。"
464
+
465
+ def encode_image(image_path):
466
+ with open(image_path, "rb") as image_file:
467
+ return base64.b64encode(image_file.read()).decode("utf-8")
468
+
469
+ def make_prompt(prompt):
470
+
471
+ return "次のプロンプトで画像を作ってください「" + prompt + "」。"
472
+
473
+ # 画面構成
474
+ with gr.Blocks() as demo:
475
+
476
+ title = "<h2>GPT画像入出力対応チャット</h2>"
477
+ message = "<h3>・DallE3の画像生成とGPT-4 with Visionの画像解析が利用できます。<br>"
478
+ message += "・DallE3を利用する場合はプロンプト、GPT-4 Visionを利用する場合は画像とプロンプトを入力して下さい。<br>"
479
+ # message += "・テスト中でAPIKEY無しで動きます。<br>"
480
+ # message += "※現在画像から画像を作るimg2imgはできません。<br>"
481
+ # message += "・動画での紹介はこちら→https://www.youtube.com/watch?v=<br></h3>"
482
+ message += "</h3>"
483
+
484
+ gr.Markdown(title + message)
485
+
486
+ # セッションの宣言
487
+ state = gr.State({
488
+ # "system_prompt": SYS_PROMPT_DEFAULT,
489
+ "openai_key" : "",
490
+ "size" : "1024x1024",
491
+ "quality" : "standard",
492
+ "detail" : "low",
493
+ "client" : None,
494
+ "assistant_id" : "",
495
+ "thread_id" : "",
496
+ "last_msg_id" : "",
497
+ "tool_outputs" : []
498
+ })
499
+
500
+ with gr.Tab("Chat画面") as chat:
501
+
502
+ # 各コンポーネント定義
503
+ chatbot = gr.Chatbot(label="チャット画面")
504
+ text_msg = gr.Textbox(label="プロンプト", placeholder = PLACEHOLDER)
505
+ text_dummy = gr.Textbox(visible=False)
506
+ gr.Examples(label="サンプルプロンプト", examples=examples, inputs=text_dummy, outputs=text_msg, fn=make_prompt, cache_examples=True)
507
+
508
+
509
+ with gr.Row():
510
+ btn = gr.Button(value="送信")
511
+ btn_dl = gr.Button(value="画像の一括ダウンロード") # 保留中
512
+ btn_clear = gr.ClearButton(value="リセット", components=[chatbot, text_msg, image ,out_image])
513
+
514
+ with gr.Row():
515
+ image = gr.Image(label="ファイルアップロード", type="filepath",interactive = True)
516
+ out_image = gr.Image(label="出力画像", type="filepath", interactive = False)
517
+
518
+ sys_msg = gr.Textbox(label="システムメッセージ", interactive = False)
519
+ # out_text = gr.Textbox(label="出力テキスト", lines = 5, interactive = False)
520
+ out_file = gr.File(label="出力ファイル", type="filepath",interactive = False)
521
+
522
+ # 送信ボタンクリック時の処理
523
+ bc = btn.click(init, [state, text_msg, image], [state, sys_msg], queue=False).success(
524
+ raise_exception, sys_msg, None).success(
525
+ add_history, [chatbot, text_msg, image], [chatbot, text_msg, image, sys_msg], queue=False).success(
526
+ bot, [state, chatbot, image],[chatbot, out_image, out_file, sys_msg]).then(
527
+ finally_proc, None, [text_msg, image], queue=False
528
+ )
529
+ btn_dl.click(make_archive, None, [out_file, sys_msg])
530
+ # クリア時でもセッションの設定(OpenAIKeyなどは残す)
531
+ btn_clear.click(clear_click, state, state)
532
+
533
+ # テキスト入力Enter時の処理
534
+ # txt_msg = text_msg.submit(respond, inputs=[text_msg, image, chatbot], outputs=[text_msg, image, chatbot])
535
+
536
+ with gr.Tab("設定") as set:
537
+
538
+ gr.Markdown("<h4>OpenAI設定</h4>")
539
+ with gr.Row():
540
+ openai_key = gr.Textbox(label="OpenAI API Key", visible=True) # テスト中は表示せず
541
+ # system_prompt = gr.Textbox(value = SYS_PROMPT_DEFAULT,lines = 5, label="Custom instructions", interactive = True)
542
+ gr.Markdown("<h4>DaLL-E3用設定</h4>")
543
+ with gr.Row():
544
+ size = gr.Dropdown(label="サイズ", choices=["1024x1024"], value = "1024x1024", interactive = True)
545
+ quality = gr.Dropdown(label="クオリティ", choices=["standard"], value = "standard", interactive = True)
546
+ gr.Markdown("<h4>Vison用設定</h4>")
547
+ with gr.Row():
548
+ detail = gr.Dropdown(label="コード出力", choices=["low", "high" , "auto"], value = "low", interactive = True)
549
+
550
+ # 設定タブからChatタブに戻った時の処理
551
+ chat.select(set_state, [openai_key, size, quality, detail, state], state)
552
+
553
+ with gr.Tab("注意事項") as notes:
554
+ caution = "現在Assistant APIはβ版でのリリースとなっています。<br>"
555
+ caution += "そのためか一部のファイルのアップロードが上手くいかないため、制限をかけています。<br>"
556
+ caution += "(現在アップできるファイル形式は.txtと.csvのみ)<br>"
557
+ caution += "本来はPDFなども利用できるはずなので、今後更新したいと思います。また日本語文字化けも調査中です。"
558
+
559
+ gr.Markdown("<h3>" + caution + "</h3>")
560
+
561
+
562
+ demo.queue()
563
+ demo.launch(debug=True)
564
+
565
+
566
+ def request_DallE3(client, prompt, size, quality, out_image_path):
567
+
568
+ err_msg = ""
569
+
570
+ try:
571
+
572
+ response = client.images.generate(
573
+ model="dall-e-3",
574
+ prompt=prompt,
575
+ size=size,
576
+ quality=quality,
577
+ n=1,
578
+ response_format="b64_json"
579
+ )
580
+
581
+ print(response.data[0])
582
+
583
+ # データを受け取りデコード
584
+ image_data_json = response.data[0].b64_json
585
+ image_data = base64.b64decode(image_data_json)
586
+
587
+ # 画像として扱えるように保存
588
+ image_stream = BytesIO(image_data)
589
+ image = Image.open(image_stream)
590
+ image.save(out_image_path)
591
+
592
+ except BadRequestError as e:
593
+ print(e)
594
+ out_image_path = ""
595
+ err_msg = "リクエストエラーです。著作権侵害などプロンプトを確認して下さい。"
596
+ except Exception as e:
597
+ print(e)
598
+ out_image_path = ""
599
+ err_msg = "その他のエラーが発生しました。"
600
+
601
+ finally:
602
+
603
+ # 結果をJSONで返す
604
+ dalle3_result = {
605
+ "image_path" : out_image_path,
606
+ "error_message" : err_msg
607
+ }
608
+ return json.dumps(dalle3_result)
609
+
610
+
611
+ def request_Vision(client, prompt, image_path, detail, max_tokens):
612
+
613
+ response_text = ""
614
+ err_msg = ""
615
+
616
+ if image_path == "":
617
+
618
+ # 画像がない時はエラーとして返す
619
+ vision_result = {"answer" : "", "error_message" : "画像をセットして下さい。"}
620
+ return json.dumps(vision_result)
621
+
622
+ try:
623
+
624
+ # 画像をbase64に変換
625
+ image = encode_image(image_path)
626
+
627
+ # メッセージの作成
628
+ messages = [
629
+ {
630
+ "role": "user",
631
+ "content": [
632
+ {"type": "text", "text": prompt},
633
+ {
634
+ "type": "image_url",
635
+ "image_url": {
636
+ "url": f"data:image/jpeg;base64,{image}",
637
+ "detail": detail,
638
+ }
639
+ },
640
+ ],
641
+ }
642
+ ]
643
+
644
+ # gpt-4-visionに問い合わせて回答を表示
645
+ response = client.chat.completions.create(
646
+ model="gpt-4-vision-preview", # Visionはこのモデル指定
647
+ messages=messages,
648
+ max_tokens=max_tokens,
649
+ )
650
+
651
+ response_text = response.choices[0].message.content
652
+
653
+ print(response_text)
654
+
655
+ except BadRequestError as e:
656
+ print(e)
657
+ err_msg = "リクエストエラーです。画像がポリシー違反でないか確認して下さい。"
658
+ except Exception as e:
659
+ print(e)
660
+ err_msg = "その他のエラーが発生しました。"
661
+
662
+ finally:
663
+
664
+ # 結果をJSONで返す
665
+ vision_result = {
666
+ "answer" : response_text,
667
+ "error_message" : err_msg
668
+ }
669
+ return json.dumps(vision_result)
gradio_chat_image.py ADDED
@@ -0,0 +1,712 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ """gradio_chat_image.ipynb
3
+
4
+ Automatically generated by Colaboratory.
5
+
6
+ Original file is located at
7
+ https://colab.research.google.com/drive/1UOcwXwg1bHrPbhkM9tB5ivx6OnI9lB7q
8
+ """
9
+
10
+ !pip install gradio
11
+ !pip install openai
12
+
13
+ import os
14
+
15
+ os.environ["TEST_OPENAI_KEY"] = "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"
16
+ os.environ["ASSIST_ID"] = "asst_ePkuRLY0DUmChyCv8tXtf0gC"
17
+ os.environ["MAX_TRIAL"] = "50"
18
+ os.environ["INTER_SEC"] = "1" # 画像は1
19
+ os.environ["MAX_TOKENS"] = "300"
20
+
21
+ import os
22
+ import json
23
+ import base64
24
+ import re
25
+ import time
26
+ import datetime
27
+ from zoneinfo import ZoneInfo
28
+ from PIL import Image
29
+ from io import BytesIO
30
+ import shutil
31
+ import gradio as gr
32
+ from openai import (
33
+ OpenAI, AuthenticationError, NotFoundError, BadRequestError
34
+ )
35
+
36
+
37
+ # GPT用設定
38
+ SYS_PROMPT_DEFAULT = "あなたは優秀なアシスタントです。回答は日本語でお願いします。"
39
+ DUMMY = "********************"
40
+ file_format = {".png", ".jpeg", ".jpg", ".webp", ".gif", ".PNG", ".JPEG", ".JPG", ".WEBP", ".GIF"}
41
+
42
+ # 各種出力フォルダ
43
+ IMG_FOLDER = "sample_data" #"images"
44
+
45
+ # 各種メッセージ
46
+ PLACEHOLDER = ""
47
+ # IMG_MSG = "(画像ファイルを追加しました。リセットボタンの上に表示されています。)"
48
+ ANT_MSG = "(下部の[出力ファイル]にファイルを追加しました。)"
49
+
50
+ # 各種設定値
51
+ MAX_TRIAL = int(os.environ["MAX_TRIAL"]) # メッセージ取得最大試行数
52
+ INTER_SEC = int(os.environ["INTER_SEC"]) # 試行間隔(秒)
53
+ MAX_TOKENS = int(os.environ["MAX_TOKENS"]) # Vison最大トークン
54
+
55
+ # 正規表現用パターン
56
+ pt = r".*!\[(.*)\]\((.*)\)"
57
+
58
+ # サンプル用情報
59
+ examples = ["1980s anime girl with straight bob-cut in school uniform, roughly drawn drawing"
60
+ , "a minimalisit logo for a sporting goods company"]
61
+
62
+
63
+ # 各関数定義
64
+ def set_state(openai_key, size, quality, detail, state):
65
+ """ 設定タブの情報をセッションに保存する関数 """
66
+
67
+ state["openai_key"] = openai_key
68
+ state["size"] = size
69
+ state["quality"] = quality
70
+ state["detail"] = detail
71
+
72
+ return state
73
+
74
+
75
+ def init(state, text, image):
76
+ """ 入力チェックを行う関数 """
77
+ """ ※ここで例外を起こすと入力できなくなるので次の関数でエラーにする """
78
+
79
+ err_msg = ""
80
+
81
+ print(state)
82
+
83
+ # if state["openai_key"] == "" or state["openai_key"] is None:
84
+
85
+ # # OpenAI API Key未入力
86
+ # err_msg = "OpenAI API Keyを入力してください。(設定タブ)"
87
+
88
+ if not text:
89
+
90
+ # テキスト未入力
91
+ err_msg = "テキストを入力して下さい。"
92
+
93
+ return state, err_msg
94
+
95
+ elif image:
96
+
97
+ # 入力画像のファイル形式チェック
98
+ root, ext = os.path.splitext(image)
99
+
100
+ print(ext, file_format)
101
+
102
+ if ext not in file_format:
103
+
104
+ # ファイル形式チェック
105
+ err_msg = "指定した形式のファイルをアップしてください。(注意事項タブに記載)"
106
+
107
+ return state, err_msg
108
+
109
+ try:
110
+
111
+ if state["client"] is None:
112
+
113
+ # 初回起動時は初期処理をする
114
+ os.environ["OPENAI_API_KEY"] = os.environ["TEST_OPENAI_KEY"] # テスト時
115
+ # os.environ["OPENAI_API_KEY"] = state["openai_key"]
116
+
117
+ # クライアント新規作成
118
+ client = OpenAI()
119
+
120
+ # client作成後は消す
121
+ os.environ["OPENAI_API_KEY"] = ""
122
+
123
+ # セッションにセット
124
+ state["client"] = client
125
+
126
+ else:
127
+
128
+ # 既存のクライアントをセット
129
+ client = state["client"]
130
+
131
+
132
+ if state["thread_id"] == "":
133
+
134
+ # スレッド作成
135
+ thread = client.beta.threads.create()
136
+
137
+ state["thread_id"] = thread.id
138
+
139
+
140
+ if state["assistant_id"] == "":
141
+
142
+ # アシスタント作成
143
+ # assistant = client.beta.assistants.create(
144
+ # name="codeinter_test",
145
+ # instructions=state["system_prompt"],
146
+ # # model="gpt-4-1106-preview",
147
+ # model="gpt-3.5-turbo-1106",
148
+ # tools=[{"type": "code_interpreter"}]
149
+ # )
150
+ # state["assistant_id"] = assistant.id
151
+
152
+ state["assistant_id"] = os.environ["ASSIST_ID"] # テスト中アシスタントは固定
153
+
154
+ else:
155
+
156
+ # アシスタント確認(IDが存在しないならエラーとなる)
157
+ assistant = client.beta.assistants.retrieve(state["assistant_id"])
158
+
159
+ except NotFoundError as e:
160
+ err_msg = "アシスタントIDが間違っています。新しく作成する場合はアシスタントIDを空欄にして下さい。"
161
+ except AuthenticationError as e:
162
+ err_msg = "認証エラーとなりました。OpenAPIKeyが正しいか、支払い方法などが設定されているか確認して下さ���。"
163
+ except Exception as e:
164
+ err_msg = "その他のエラーが発生しました。"
165
+ print(e)
166
+ finally:
167
+ return state, err_msg
168
+
169
+
170
+ def raise_exception(err_msg):
171
+ """ エラーの場合例外を起こす関数 """
172
+
173
+ if err_msg != "":
174
+ raise Exception("これは入力チェックでの例外です。")
175
+
176
+ return
177
+
178
+
179
+ def add_history(history, text, image):
180
+ """ Chat履歴"history"に追加を行う関数 """
181
+
182
+ err_msg = ""
183
+
184
+ if image is None or image == "":
185
+
186
+ # テキストだけの場合そのまま追加
187
+ history = history + [(text, None)]
188
+
189
+ elif image is not None:
190
+
191
+ # 画像があれば画像とテキストを追加
192
+ history = history + [((image,), DUMMY)]
193
+ history = history + [(text, None)]
194
+
195
+ # テキストは利用不可・初期化し、画像は利用不可に
196
+ update_text = gr.update(value="", placeholder = "",interactive=False)
197
+ update_file = gr.update(interactive=False)
198
+
199
+ return history, update_text, update_file, err_msg
200
+
201
+
202
+ def bot(state, history, image_path):
203
+
204
+ err_msg = ""
205
+ out_image_path = None
206
+ image_preview = False
207
+
208
+ ant_file = None
209
+
210
+ # セッション情報取得
211
+ client = state["client"]
212
+ assistant_id = state["assistant_id"]
213
+ thread_id = state["thread_id"]
214
+ last_msg_id = state["last_msg_id"]
215
+ history_outputs = state["tool_outputs"]
216
+
217
+ # メッセージ設定
218
+ message = client.beta.threads.messages.create(
219
+ thread_id=thread_id,
220
+ role="user",
221
+ content=history[-1][0],
222
+ )
223
+
224
+ # RUNスタート
225
+ run = client.beta.threads.runs.create(
226
+ thread_id=thread_id,
227
+ assistant_id=assistant_id,
228
+ # instructions=system_prompt
229
+ )
230
+
231
+ # "completed"となるまで繰り返す(指定秒おき)
232
+ for i in range(0, MAX_TRIAL, 1):
233
+
234
+ if i > 0:
235
+ time.sleep(INTER_SEC)
236
+
237
+ # 変数初期化
238
+ tool_outputs = []
239
+
240
+ # メッセージ受け取り
241
+ run = client.beta.threads.runs.retrieve(
242
+ thread_id=thread_id,
243
+ run_id=run.id
244
+ )
245
+
246
+ print(run.status)
247
+
248
+ if run.status == "requires_action": # 関数の結果の待ちの場合
249
+
250
+ print(run.required_action)
251
+
252
+ # tool_callsの各項目取得
253
+ tool_calls = run.required_action.submit_tool_outputs.tool_calls
254
+
255
+ print(len(tool_calls))
256
+ print(tool_calls)
257
+
258
+ # 一つ目だけ取得
259
+ tool_id = tool_calls[0].id
260
+ func_name = tool_calls[0].function.name
261
+ func_args = json.loads(tool_calls[0].function.arguments)
262
+
263
+ if func_name == "request_DallE3":
264
+
265
+ # ファイル名は現在時刻に
266
+ dt = datetime.datetime.now(ZoneInfo("Asia/Tokyo"))
267
+ image_name = dt.strftime("%Y%m%d%H%M%S") + ".png"
268
+
269
+ # ファイルパスは手動設定(誤りがないように)
270
+ out_image_path = IMG_FOLDER + "/" + image_name
271
+
272
+ # dall-e3のとき"image_path"は出力ファイルパス
273
+ func_args["image_path"] = out_image_path
274
+
275
+ elif func_name == "request_Vision":
276
+
277
+ if image_path is None:
278
+
279
+ # 画像がない場合エラーとなるようにする
280
+ func_args["image_path"] = ""
281
+
282
+ else:
283
+
284
+ # ファイルパスは手動設定
285
+ func_args["image_path"] = image_path
286
+
287
+ else:
288
+
289
+ # 関数名がないなら次へ
290
+ continue
291
+
292
+ # 関数を実行
293
+ func_output = func_action(state, func_name, func_args)
294
+
295
+ print(func_output)
296
+
297
+ # tool_outputリストに追加
298
+ tool_outputs.append({"tool_call_id": tool_id, "output": func_output})
299
+
300
+ # 複数の関数が必要な場合
301
+ if len(tool_calls) > 1:
302
+
303
+ # for i in range(len(tool_calls) - 1):
304
+ for i, tool_call in enumerate(tool_calls):
305
+
306
+ if i > 0:
307
+ # print(history_outputs[-(i+1)])
308
+ # # 最新のものからセット
309
+ # tool_outputs.append({"tool_call_id": history_outputs[-(i+1)]["tool_call_id"], "output": history_outputs[-(i+1)]["output"]})
310
+
311
+ # ダミー をセットする
312
+ tool_outputs.append({"tool_call_id": tool_calls.id, "output": {"answer" : ""}})
313
+
314
+ print(tool_outputs)
315
+
316
+ # 関数の出力を提出
317
+ run = client.beta.threads.runs.submit_tool_outputs(
318
+ thread_id=thread_id,
319
+ run_id=run.id,
320
+ tool_outputs=tool_outputs
321
+ )
322
+
323
+ if func_name == "request_DallE3":
324
+
325
+ # 画像の表示をする
326
+ image_preview = True
327
+
328
+ # セッション更新
329
+ history_outputs += tool_outputs
330
+ state["tool_outputs"] = history_outputs
331
+
332
+ else:
333
+
334
+ # 前回のメッセージより後を昇順で取り出す
335
+ messages = client.beta.threads.messages.list(
336
+ thread_id=thread_id,
337
+ after=last_msg_id,
338
+ order="asc"
339
+ )
340
+
341
+ print(messages)
342
+
343
+ # messageを取り出す
344
+ for msg in messages:
345
+
346
+ if msg.role == "assistant":
347
+
348
+ for content in msg.content:
349
+
350
+ res_text = ""
351
+ file_id = ""
352
+
353
+ cont_dict = content.model_dump() # 辞書型に変換
354
+
355
+ # 返答テキスト取得
356
+ res_text = cont_dict["text"].get("value")
357
+
358
+ if res_text != "":
359
+
360
+ # テキストを変換("sandbox:"などを」消す)
361
+ result = re.search(pt, res_text)
362
+
363
+ if result:
364
+
365
+ # パターン一致の場合はプロンプトだけ抜き出す
366
+ res_text = result.group(1)
367
+
368
+ # Chat画面更新
369
+ if history[-1][1] is not None:
370
+
371
+ # 新しい行を追加
372
+ history = history + [[None, res_text]]
373
+
374
+ else:
375
+
376
+ history[-1][1] = res_text
377
+
378
+ if image_preview:
379
+
380
+ print(out_image_path)
381
+
382
+ # Functionで画像を取得していた場合表示
383
+ history = history + [(None, (out_image_path,))]
384
+
385
+ image_preview = False
386
+
387
+ # 最終メッセージID更新
388
+ last_msg_id = msg.id
389
+
390
+ # Chatbotを返す(labelとhistoryを更新)
391
+ yield gr.Chatbot(label=run.status ,value=history), out_image_path, ant_file, err_msg
392
+
393
+ # セッションのメッセージID更新
394
+ state["last_msg_id"] = last_msg_id
395
+
396
+ # 完了なら終了
397
+ if run.status == "completed":
398
+
399
+ yield gr.Chatbot(label=run.status ,value=history), out_image_path, ant_file, err_msg
400
+ break
401
+
402
+
403
+ elif run.status == "failed":
404
+
405
+ # エラーとして終了
406
+ err_msg = "※メッセージ取得に失敗しました。"
407
+ yield gr.Chatbot(label=run.status ,value=history), out_image_path, ant_file, err_msg
408
+ break
409
+
410
+ elif i == MAX_TRIAL:
411
+
412
+ # エラーとして終了
413
+ err_msg = "※メッセージ取得の際にタイムアウトしました。"
414
+ yield gr.Chatbot(label=run.status ,value=history), out_image_path, ant_file, err_msg
415
+ break
416
+
417
+ else:
418
+ if i > 3:
419
+
420
+ # 作業中とわかるようにする
421
+ yield gr.Chatbot(label=run.status + " (Request:" + str(i) + ")" ,value=history), out_image_path, ant_file, err_msg
422
+
423
+
424
+ def func_action(state, func_name, func_args):
425
+
426
+ # セッションから情報取得
427
+ client = state["client"]
428
+ size = state["size"]
429
+ quality = state["quality"]
430
+ detail = state["detail"]
431
+
432
+ print("name:", func_name)
433
+ print("arguments:", func_args)
434
+
435
+ if func_name == "request_DallE3":
436
+
437
+ func_output = request_DallE3(
438
+ client,
439
+ func_args["prompt"],
440
+ size,
441
+ quality,
442
+ func_args["image_path"] # 出力パス
443
+ )
444
+
445
+ elif func_name == "request_Vision":
446
+
447
+ func_output = request_Vision(
448
+ client,
449
+ func_args["prompt"],
450
+ func_args["image_path"],
451
+ detail,
452
+ MAX_TOKENS
453
+ )
454
+
455
+ return func_output
456
+
457
+ def finally_proc():
458
+ """ 最終処理用関数 """
459
+
460
+ # テキストを使えるように
461
+ new_text = gr.update(interactive = True)
462
+
463
+ # 画像はリセット
464
+ new_image = gr.update(value=None, interactive = True)
465
+
466
+ return new_text, new_image
467
+
468
+
469
+ def clear_click(state):
470
+ """ クリアボタンクリック時 """
471
+
472
+ # セッションの一部をリセット
473
+ state["thread_id"] = ""
474
+ state["last_msg_id"] = ""
475
+ state["tool_outputs"] = []
476
+
477
+ return state
478
+
479
+ def make_archive():
480
+
481
+ shutil.make_archive("output_image", format='zip', root_dir=IMG_FOLDER)
482
+
483
+ return "output_image.zip", "下部の出力ファイルからダウンロードして下さい。"
484
+
485
+ def encode_image(image_path):
486
+ with open(image_path, "rb") as image_file:
487
+ return base64.b64encode(image_file.read()).decode("utf-8")
488
+
489
+ def make_prompt(prompt):
490
+
491
+ return "次のプロンプトで画像を作ってください「" + prompt + "」。"
492
+
493
+ # 画面構成
494
+ with gr.Blocks() as demo:
495
+
496
+ title = "<h2>GPT画像入出力対応チャット</h2>"
497
+ message = "<h3>・DallE3の画像生成とGPT-4 with Visionの画像解析が利用できます。<br>"
498
+ message += "・DallE3を利用する場合はプロンプト、GPT-4 Visionを利用する場合は画像とプロンプトを入力して下さい。<br>"
499
+ # message += "・テスト中でAPIKEY無しで動きます。<br>"
500
+ # message += "※現在画像から画像を作るimg2imgはできません。<br>"
501
+ # message += "・動画での紹介はこちら→https://www.youtube.com/watch?v=<br></h3>"
502
+ message += "</h3>"
503
+
504
+ gr.Markdown(title + message)
505
+
506
+ # セッションの宣言
507
+ state = gr.State({
508
+ # "system_prompt": SYS_PROMPT_DEFAULT,
509
+ "openai_key" : "",
510
+ "size" : "1024x1024",
511
+ "quality" : "standard",
512
+ "detail" : "low",
513
+ "client" : None,
514
+ "assistant_id" : "",
515
+ "thread_id" : "",
516
+ "last_msg_id" : "",
517
+ "tool_outputs" : []
518
+ })
519
+
520
+ with gr.Tab("Chat画面") as chat:
521
+
522
+ # 各コンポーネント定義
523
+ chatbot = gr.Chatbot(label="チャット画面")
524
+ text_msg = gr.Textbox(label="プロンプト", placeholder = PLACEHOLDER)
525
+ text_dummy = gr.Textbox(visible=False)
526
+ gr.Examples(label="サンプルプロンプト", examples=examples, inputs=text_dummy, outputs=text_msg, fn=make_prompt, cache_examples=True)
527
+
528
+ with gr.Row():
529
+ image = gr.Image(label="ファイルアップロード", type="filepath",interactive = True)
530
+ out_image = gr.Image(label="出力画像", type="filepath", interactive = False)
531
+ with gr.Row():
532
+ btn = gr.Button(value="送信")
533
+ btn_dl = gr.Button(value="画像の一括ダウンロード") # 保留中
534
+ btn_clear = gr.ClearButton(value="リセット", components=[chatbot, text_msg, image ,out_image])
535
+
536
+ sys_msg = gr.Textbox(label="システムメッセージ", interactive = False)
537
+ # out_text = gr.Textbox(label="出力テキスト", lines = 5, interactive = False)
538
+ out_file = gr.File(label="出力ファイル", type="filepath",interactive = False)
539
+
540
+ # 送信ボタンクリック時の処理
541
+ bc = btn.click(init, [state, text_msg, image], [state, sys_msg], queue=False).success(
542
+ raise_exception, sys_msg, None).success(
543
+ add_history, [chatbot, text_msg, image], [chatbot, text_msg, image, sys_msg], queue=False).success(
544
+ bot, [state, chatbot, image],[chatbot, out_image, out_file, sys_msg]).then(
545
+ finally_proc, None, [text_msg, image], queue=False
546
+ )
547
+ btn_dl.click(make_archive, None, [out_file, sys_msg])
548
+ # クリア時でもセッションの設定(OpenAIKeyなどは残す)
549
+ btn_clear.click(clear_click, state, state)
550
+
551
+ # テキスト入力Enter時の処理
552
+ # txt_msg = text_msg.submit(respond, inputs=[text_msg, image, chatbot], outputs=[text_msg, image, chatbot])
553
+
554
+ with gr.Tab("設定") as set:
555
+
556
+ gr.Markdown("<h4>OpenAI設定</h4>")
557
+ with gr.Row():
558
+ openai_key = gr.Textbox(label="OpenAI API Key", visible=True) # テスト中は表示せず
559
+ # system_prompt = gr.Textbox(value = SYS_PROMPT_DEFAULT,lines = 5, label="Custom instructions", interactive = True)
560
+ gr.Markdown("<h4>DaLL-E3用設定</h4>")
561
+ with gr.Row():
562
+ size = gr.Dropdown(label="サイズ", choices=["1024x1024"], value = "1024x1024", interactive = True)
563
+ quality = gr.Dropdown(label="クオリティ", choices=["standard"], value = "standard", interactive = True)
564
+ gr.Markdown("<h4>Vison用設定</h4>")
565
+ with gr.Row():
566
+ detail = gr.Dropdown(label="コード出力", choices=["low", "high" , "auto"], value = "low", interactive = True)
567
+
568
+ # 設定タブからChatタブに戻った時の処理
569
+ chat.select(set_state, [openai_key, size, quality, detail, state], state)
570
+
571
+ with gr.Tab("注意事項") as notes:
572
+ caution = "現在Assistant APIはβ版でのリリースとなっています。<br>"
573
+ caution += "そのためか一部のファイルのアップロードが上手くいかないため、制限をかけています。<br>"
574
+ caution += "(現在アップできるファイル形式は.txtと.csvのみ)<br>"
575
+ caution += "本来はPDFなども利用できるはずなので、今後更新したいと思います。また日本語文字化けも調査中です。"
576
+
577
+ gr.Markdown("<h3>" + caution + "</h3>")
578
+
579
+
580
+ demo.queue()
581
+ demo.launch(debug=True)
582
+
583
+ def request_DallE3(client, prompt, size, quality, out_image_path):
584
+ """ DallE3を呼び出す """
585
+
586
+ err_msg = ""
587
+
588
+ imgage_path = "/content/sample_datacat1.png"
589
+
590
+ dalle3_result = {
591
+ "imgage_path" : imgage_path,
592
+ "error_message" : err_msg
593
+ }
594
+
595
+ return json.dumps(dalle3_result)
596
+
597
+
598
+ def request_Vision(prompt, detail, image_path):
599
+ """ GPT4 Visionを呼び出す """
600
+
601
+ response_text = "この画像は驚いた表情をしている人物を写した写真です。"
602
+
603
+ vision_result = {
604
+ "answer" : response_text
605
+ }
606
+
607
+ return json.dumps(vision_result)
608
+
609
+ def request_DallE3(client, prompt, size, quality, out_image_path):
610
+
611
+ err_msg = ""
612
+
613
+ try:
614
+
615
+ response = client.images.generate(
616
+ model="dall-e-3",
617
+ prompt=prompt,
618
+ size=size,
619
+ quality=quality,
620
+ n=1,
621
+ response_format="b64_json"
622
+ )
623
+
624
+ print(response.data[0])
625
+
626
+ # データを受け取りデコード
627
+ image_data_json = response.data[0].b64_json
628
+ image_data = base64.b64decode(image_data_json)
629
+
630
+ # 画像として扱えるように保存
631
+ image_stream = BytesIO(image_data)
632
+ image = Image.open(image_stream)
633
+ image.save(out_image_path)
634
+
635
+ except BadRequestError as e:
636
+ print(e)
637
+ out_image_path = ""
638
+ err_msg = "リクエストエラーです。著作権侵害などプロンプトを確認して下さい。"
639
+ except Exception as e:
640
+ print(e)
641
+ out_image_path = ""
642
+ err_msg = "その他のエラーが発生しました。"
643
+
644
+ finally:
645
+
646
+ # 結果をJSONで返す
647
+ dalle3_result = {
648
+ "image_path" : out_image_path,
649
+ "error_message" : err_msg
650
+ }
651
+ return json.dumps(dalle3_result)
652
+
653
+
654
+ def request_Vision(client, prompt, image_path, detail, max_tokens):
655
+
656
+ response_text = ""
657
+ err_msg = ""
658
+
659
+ if image_path == "":
660
+
661
+ # 画像がない時はエラーとして返す
662
+ vision_result = {"answer" : "", "error_message" : "画像をセットして下さい。"}
663
+ return json.dumps(vision_result)
664
+
665
+ try:
666
+
667
+ # 画像をbase64に変換
668
+ image = encode_image(image_path)
669
+
670
+ # メッセージの作成
671
+ messages = [
672
+ {
673
+ "role": "user",
674
+ "content": [
675
+ {"type": "text", "text": prompt},
676
+ {
677
+ "type": "image_url",
678
+ "image_url": {
679
+ "url": f"data:image/jpeg;base64,{image}",
680
+ "detail": detail,
681
+ }
682
+ },
683
+ ],
684
+ }
685
+ ]
686
+
687
+ # gpt-4-visionに問い合わせて回答を表示
688
+ response = client.chat.completions.create(
689
+ model="gpt-4-vision-preview", # Visionはこのモデル指定
690
+ messages=messages,
691
+ max_tokens=max_tokens,
692
+ )
693
+
694
+ response_text = response.choices[0].message.content
695
+
696
+ print(response_text)
697
+
698
+ except BadRequestError as e:
699
+ print(e)
700
+ err_msg = "リクエストエラーです。画像がポリシー違反でないか確認して下さい。"
701
+ except Exception as e:
702
+ print(e)
703
+ err_msg = "その他のエラーが発生しました。"
704
+
705
+ finally:
706
+
707
+ # 結果をJSONで返す
708
+ vision_result = {
709
+ "answer" : response_text,
710
+ "error_message" : err_msg
711
+ }
712
+ return json.dumps(vision_result)
requirements.txt ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ # gradio==4.2.0
2
+ openai==1.2.4
sample_data/dummy.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ あああ