Spaces:
Running
Running
import os | |
from openai import OpenAI, BadRequestError | |
import gradio as gr | |
from datetime import datetime | |
from zoneinfo import ZoneInfo | |
import shutil | |
from io import BytesIO | |
import base64 | |
from PIL import Image | |
PLACEHOLDER = "ここに画像を生成するための文章を入力してください…" | |
# 設定用リスト | |
size_list = ["1024x1024" ,"1024x1792" ,"1792x1024"] | |
quality_list = ["standard" ,"hd"] | |
style_list = ["vivid", "natural"] | |
# サンプルプロンプト用 | |
examples = ["1980s anime girl with straight bob-cut in school uniform, roughly drawn drawing" | |
, "a minimalisit logo for a sporting goods company"] | |
def set_state(state, openai_key, size, quality, style): | |
state["openai_key"]= openai_key | |
state["size"] = size | |
state["quality"] = quality | |
state["style"] = style | |
return state | |
def request_dalle(client, prompt, size, quality, style, image_path): | |
err_msg = "" | |
revised_prompt = "" | |
try: | |
response = client.images.generate( | |
model="dall-e-3", | |
prompt=prompt, | |
size=size, | |
quality=quality, | |
style=style, | |
n=1, | |
response_format="b64_json" | |
) | |
# データを受け取りデコード | |
image_data_json = response.data[0].b64_json | |
image_data = base64.b64decode(image_data_json) | |
# 画像として扱えるように保存 | |
image_stream = BytesIO(image_data) | |
image = Image.open(image_stream) | |
image.save(image_path) | |
# dalle内部のプロンプト | |
revised_prompt = response.data[0].revised_prompt | |
except BadRequestError as e: | |
print(e) | |
out_image_path = "" | |
err_msg = "リクエストエラーです。著作権侵害などプロンプトを確認して下さい。" | |
except Exception as e: | |
print(e) | |
out_image_path = "" | |
err_msg = "その他のエラーが発生しました。" | |
finally: | |
return err_msg, revised_prompt | |
def create_image(state, text): | |
err_msg = "" | |
user_id = state["user_id"] | |
client = state["client"] | |
size = state["size"] | |
quality = state["quality"] | |
style = state["style"] | |
# OpenAIキーチェック | |
if state["openai_key"] == "": | |
err_msg = "OpenAIキーを入力してください。(設定タブ)" | |
return text, None, "", err_msg | |
# 入力チェック | |
if text.strip() == "": | |
err_msg = "プロンプトを入力してください。" | |
return text, None, "", err_msg | |
if user_id == "": | |
# IDとして現在時刻をセット | |
dt = datetime.now(ZoneInfo("Asia/Tokyo")) | |
user_id = dt.strftime("%Y%m%d%H%M%S") | |
# ユーザIDでフォルダ作成 | |
os.makedirs(user_id, exist_ok=True) | |
state["user_id"] = user_id | |
# ファイル名は現在時刻に | |
dt = datetime.now(ZoneInfo("Asia/Tokyo")) | |
image_name = dt.strftime("%Y%m%d%H%M%S") + ".png" | |
# ファイルパスは手動設定(誤りがないように) | |
image_path = user_id + "/" + image_name | |
if client is None: | |
os.environ["OPENAI_API_KEY"] = state["openai_key"] | |
os.environ["OPENAI_API_KEY"] = os.environ["TEST_API_KEY"] | |
# クライアント作成 | |
client = OpenAI() | |
# client作成後は消す | |
os.environ["OPENAI_API_KEY"] = "" | |
state["client"] = client | |
# テスト用 | |
# image_path = "cat.png" | |
return_msg, prompt = request_dalle(client, text, size, quality, style, image_path) | |
if return_msg == "": | |
return "", image_path, prompt, "" | |
else: | |
err_msg = "画像の作成に失敗しました。\n" + return_msg | |
return text, None, "", err_msg | |
def make_archive(state): | |
""" 画像のZIP化・一括ダウンロード用関数 """ | |
dir = state["user_id"] | |
if dir is None or dir == "": | |
return None, "" | |
if len(os.listdir(dir)) == 0: | |
return None, "" | |
shutil.make_archive(dir, format='zip', root_dir=dir) | |
return dir + ".zip", "下部の出力ファイルからダウンロードして下さい。" | |
with gr.Blocks() as demo: | |
gr.Markdown("設定タブでOpenAIキーを入力してください。") | |
# セッションの宣言 | |
state = gr.State({ | |
"openai_key" : "", | |
"client" : None, | |
"user_id" : "", | |
"size" : "", | |
"quality" : "", | |
"style" : "", | |
}) | |
with gr.Tab("whisperを利用する") as maintab: | |
# 出力画像 | |
out_image = gr.Image(label="生成画像", type="filepath", interactive = False) | |
# 各コンポーネント定義 | |
text = gr.Textbox(label="プロンプト", lines=4, placeholder=PLACEHOLDER) | |
# ボタン類 | |
with gr.Row(): | |
btn = gr.Button("画像作成") | |
btn_dl = gr.Button(value="画像の一括ダウンロード") | |
# btn_clear = gr.ClearButton(value="リセット", components=[chatbot, text_msg]) | |
out_prompt = gr.Text(label="Dalle-3内部プロンプト") | |
sys_msg = gr.Text(label="システムメッセージ") | |
out_file = gr.File(label="出力ファイル", type="filepath",interactive = False) | |
btn.click(create_image, [state, text], [text, out_image, out_prompt, sys_msg]) | |
btn_dl.click(make_archive, state, [out_file, sys_msg]) | |
with gr.Tab("設定") as settab: | |
openai_key = gr.Textbox(label="OpenAI API Key", interactive = True) | |
size = gr.Dropdown(label="サイズ", choices=size_list, value = "1024x1024", interactive = True) | |
quality = gr.Dropdown(label="クオリティ", choices=quality_list, value = "standard", interactive = True) | |
style = gr.Dropdown(label="スタイル", choices=style_list, value = "vivid", interactive = True) | |
# 設定変更時 | |
maintab.select(set_state, [state, openai_key, size, quality, style], state) | |
with gr.Tab("利用上の注意"): | |
caution = "・1枚当たりの料金はサイズ:1024x1024で0.04ドル、それ以外のサイズは0.080ドルです。<br>" | |
caution += "・こちらはクオリティがStandardの場合で、hdの場合は1.5~2倍とかかります。" | |
caution += "詳細→https://openai.com/pricing" | |
gr.Markdown("<h3>" + caution + "</h3>") | |
if __name__ == '__main__': | |
demo.queue() | |
demo.launch(debug=True) | |