nekoniii3 commited on
Commit
127bd92
1 Parent(s): da933c0
Files changed (2) hide show
  1. app.py +211 -0
  2. requirements.txt +2 -0
app.py ADDED
@@ -0,0 +1,211 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from openai import OpenAI, BadRequestError
3
+ import gradio as gr
4
+ from datetime import datetime
5
+ from zoneinfo import ZoneInfo
6
+ import shutil
7
+ from io import BytesIO
8
+ import base64
9
+ from PIL import Image
10
+
11
+ PLACEHOLDER = "ここに画像を生成するための文章を入力してください…"
12
+
13
+ # 設定用リスト
14
+ size_list = ["1024x1024" ,"1024x1792" ,"1792x1024"]
15
+ quality_list = ["standard" ,"hd"]
16
+ style_list = ["vivid", "natural"]
17
+
18
+ # サンプルプロンプト用
19
+ examples = ["1980s anime girl with straight bob-cut in school uniform, roughly drawn drawing"
20
+ , "a minimalisit logo for a sporting goods company"]
21
+
22
+ def set_state(state, openai_key, size, quality, style):
23
+
24
+ state["openai_key"]= openai_key
25
+ state["size"] = size
26
+ state["quality"] = quality
27
+ state["style"] = style
28
+
29
+ return state
30
+
31
+ def request_dalle(client, prompt, size, quality, style, image_path):
32
+
33
+ err_msg = ""
34
+ revised_prompt = ""
35
+
36
+ try:
37
+
38
+ response = client.images.generate(
39
+ model="dall-e-3",
40
+ prompt=prompt,
41
+ size=size,
42
+ quality=quality,
43
+ style=style,
44
+ n=1,
45
+ response_format="b64_json"
46
+ )
47
+
48
+ # データを受け取りデコード
49
+ image_data_json = response.data[0].b64_json
50
+ image_data = base64.b64decode(image_data_json)
51
+
52
+ # 画像として扱えるように保存
53
+ image_stream = BytesIO(image_data)
54
+ image = Image.open(image_stream)
55
+ image.save(image_path)
56
+
57
+ # dalle内部のプロンプト
58
+ revised_prompt = response.data[0].revised_prompt
59
+
60
+ except BadRequestError as e:
61
+ print(e)
62
+ out_image_path = ""
63
+ err_msg = "リクエストエラーです。著作権侵害などプロンプトを確認して下さい。"
64
+ except Exception as e:
65
+ print(e)
66
+ out_image_path = ""
67
+ err_msg = "その他のエラーが発生しました。"
68
+ finally:
69
+ return err_msg, revised_prompt
70
+
71
+ def create_image(state, text):
72
+
73
+ err_msg = ""
74
+
75
+ user_id = state["user_id"]
76
+ client = state["client"]
77
+ size = state["size"]
78
+ quality = state["quality"]
79
+ style = state["style"]
80
+
81
+ # OpenAIキーチェック
82
+ if state["openai_key"] == "":
83
+
84
+ err_msg = "OpenAIキーを入力してください。(設定タブ)"
85
+
86
+ return text, None, "", err_msg
87
+
88
+ # 入力チェック
89
+ if text.strip() == "":
90
+
91
+ err_msg = "プロンプトを入力してください。"
92
+
93
+ return text, None, "", err_msg
94
+
95
+ if user_id == "":
96
+
97
+ # IDとして現在時刻をセット
98
+ dt = datetime.now(ZoneInfo("Asia/Tokyo"))
99
+ user_id = dt.strftime("%Y%m%d%H%M%S")
100
+
101
+ # ユーザIDでフォルダ作成
102
+ os.makedirs(user_id, exist_ok=True)
103
+
104
+ state["user_id"] = user_id
105
+
106
+ # ファイル名は現在時刻に
107
+ dt = datetime.now(ZoneInfo("Asia/Tokyo"))
108
+ image_name = dt.strftime("%Y%m%d%H%M%S") + ".png"
109
+
110
+ # ファイルパスは手動設定(誤りがないように)
111
+ image_path = user_id + "/" + image_name
112
+
113
+ if client is None:
114
+
115
+ os.environ["OPENAI_API_KEY"] = state["openai_key"]
116
+ os.environ["OPENAI_API_KEY"] = os.environ["TEST_API_KEY"]
117
+ # クライアント作成
118
+ client = OpenAI()
119
+
120
+ # client作成後は消す
121
+ os.environ["OPENAI_API_KEY"] = ""
122
+
123
+ state["client"] = client
124
+
125
+ # テスト用
126
+ # image_path = "cat.png"
127
+
128
+ return_msg, prompt = request_dalle(client, text, size, quality, style, image_path)
129
+
130
+ if return_msg == "":
131
+
132
+ return "", image_path, prompt, ""
133
+
134
+ else:
135
+ err_msg = "画像の作成に失敗しました。\n" + return_msg
136
+
137
+ return text, None, "", err_msg
138
+
139
+ def make_archive(state):
140
+ """ 画像のZIP化・一括ダウンロード用関数 """
141
+
142
+ dir = state["user_id"]
143
+
144
+ if dir is None or dir == "":
145
+
146
+ return None, ""
147
+
148
+ if len(os.listdir(dir)) == 0:
149
+
150
+ return None, ""
151
+
152
+ shutil.make_archive(dir, format='zip', root_dir=dir)
153
+
154
+ return dir + ".zip", "下部の出力ファイルからダウンロードして下さい。"
155
+
156
+ with gr.Blocks() as demo:
157
+
158
+ gr.Markdown("設定タブでOpenAIキーを入力してください。")
159
+
160
+ # セッションの宣言
161
+ state = gr.State({
162
+ "openai_key" : "",
163
+ "client" : None,
164
+ "user_id" : "",
165
+ "size" : "",
166
+ "quality" : "",
167
+ "style" : "",
168
+ })
169
+
170
+ with gr.Tab("whisperを利用する") as maintab:
171
+
172
+ # 出力画像
173
+ out_image = gr.Image(label="生成画像", type="filepath", interactive = False)
174
+
175
+ # 各コンポーネント定義
176
+ text = gr.Textbox(label="プロンプト", lines=4, placeholder=PLACEHOLDER)
177
+
178
+ # ボタン類
179
+ with gr.Row():
180
+ btn = gr.Button("画像作成")
181
+ btn_dl = gr.Button(value="画像の一括ダウンロード")
182
+ # btn_clear = gr.ClearButton(value="リセット", components=[chatbot, text_msg])
183
+
184
+ out_prompt = gr.Text(label="Dalle-3内部プロンプト")
185
+ sys_msg = gr.Text(label="システムメッセージ")
186
+ out_file = gr.File(label="出力ファイル", type="filepath",interactive = False)
187
+
188
+ btn.click(create_image, [state, text], [text, out_image, out_prompt, sys_msg])
189
+ btn_dl.click(make_archive, state, [out_file, sys_msg])
190
+
191
+ with gr.Tab("設定") as settab:
192
+ openai_key = gr.Textbox(label="OpenAI API Key", interactive = True)
193
+ size = gr.Dropdown(label="サイズ", choices=size_list, value = "1024x1024", interactive = True)
194
+ quality = gr.Dropdown(label="クオリティ", choices=quality_list, value = "standard", interactive = True)
195
+ style = gr.Dropdown(label="スタイル", choices=style_list, value = "vivid", interactive = True)
196
+
197
+ # 設定変更時
198
+ maintab.select(set_state, [state, openai_key, size, quality, style], state)
199
+
200
+ with gr.Tab("利用上の注意"):
201
+
202
+ caution = "・1枚当たりの料金はサイズ:1024x1024で0.04ドル、それ以外のサイズは0.080ドルです。<br>"
203
+ caution += "・こちらはクオリティがStandardの場合で、hdの場合は1.5~2倍とかかります。"
204
+ caution += "詳細→https://openai.com/pricing"
205
+ gr.Markdown("<h3>" + caution + "</h3>")
206
+
207
+
208
+ if __name__ == '__main__':
209
+
210
+ demo.queue()
211
+ demo.launch(debug=True)
requirements.txt ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ # gradio==4.0.2
2
+ openai==1.6.1