# Hugging Face Spaces page banner captured with this file: "Running on Zero"
import spaces | |
import gradio as gr | |
import os | |
from transformers import pipeline | |
from collections import defaultdict | |
import torch | |
from typing import Iterable | |
from gradio.themes.base import Base | |
from gradio.themes.utils import colors, fonts, sizes | |
from PIL import Image | |
import datetime | |
from diffusers import DiffusionPipeline | |
import random | |
import numpy as np | |
from huggingface_hub import login | |
# Authenticate against the Hugging Face Hub using the token stored in the
# TOKEN environment variable.
login(token=os.environ.get("TOKEN"))

# Use the GPU when torch reports one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
dtype = torch.float16

print('script starting up...')
print('loading model...')

# token=True makes the download reuse the credentials stored by login().
# NOTE(review): trust_remote_code=True lets code shipped in the model
# repository run in this process; keep it only for repositories you trust.
auth_token = True
pipe = DiffusionPipeline.from_pretrained(
    "Mitsua/mitsua-likes",
    token=auth_token,
    trust_remote_code=True,
)
pipe = pipe.to(device, dtype=dtype)
# Banned prompt substrings, read as a comma-separated list from the BAD_WORDS
# environment variable.
# - An unset variable yields an empty list instead of raising AttributeError
#   on None.
# - Blank entries are dropped: an empty string is a substring of every prompt,
#   so keeping one (e.g. from a trailing comma) would reject all prompts.
bad_words = [
    word
    for word in os.environ.get("BAD_WORDS", "").split(",")
    if word.strip()
]
# Japanese names of public fictional characters; infer_impl searches the
# prompt for each of these as a substring.
PUBLIC_CHARACTER_NAMES_JA = {
    '絵藍ミツア',
    'つくよみちゃん',
    'ずんだもん',
    '紡ネン',
    '東北ずん子',
    '東北イタコ',
    '東北きりたん',
    '四国めたん',
    '中国うさぎ',
    '原型ずんだもん',
    '九州そら',
    '大江戸ちゃんこ',
    'フィーちゃん',
    '19歳つくよみちゃん',
}
# Romanised / English names of public fictional characters; infer_impl
# searches the prompt for each of these as a (case-sensitive) substring.
PUBLIC_CHARACTER_NAMES_EN = {
    'elanmitsua',
    'tsukuyomichan',
    'zundamon',
    'tsumugi nen',
    'touhoku zunko',
    'touhoku itako',
    'touhoku kiritan',
    'shikoku metan',
    'chugoku usagi',
    'zundamon_original',
    'kyushu_sora',
    'kansai_shinobi',
    'ccd-0500',
    '19sai_tsukuyomichan',
}
# Guideline pages shared by several characters.
_ZUNKO_URL = "https://zunko.jp/"
_TSUKUYOMI_URL = "https://tyc.rei-yumesaki.net/"

# Character name -> page linked from the license notice. Only the Japanese
# names are keyed here; names without an entry are shown as plain text.
PUBLIC_CHARACTER_LINKS = {
    '絵藍ミツア': "https://elanmitsua.com/",
    '原型ずんだもん': _ZUNKO_URL,
    '九州そら': _ZUNKO_URL,
    '大江戸ちゃんこ': _ZUNKO_URL,
    'フィーちゃん': "https://u-stella.co.jp/gallery/ccd-0500/",
    '19歳つくよみちゃん': _TSUKUYOMI_URL,
    'つくよみちゃん': _TSUKUYOMI_URL,
    'ずんだもん': _ZUNKO_URL,
    '紡ネン': "https://tsumuginen.com/guideline",
    '東北ずん子': _ZUNKO_URL,
    '東北イタコ': _ZUNKO_URL,
    '東北きりたん': _ZUNKO_URL,
    '四国めたん': _ZUNKO_URL,
    '中国うさぎ': _ZUNKO_URL,
}
# Base license notice in Japanese (Markdown). infer_impl appends
# character-specific restriction notices to a copy of this text and returns
# it when the browser language starts with "ja".
license_str_ja_orig = """
# 画像ライセンス : [Mitsua Likes 表示-非営利](https://elanmitsua.notion.site/Mitsua-Likes-15baa85a9b278005bba5f30866a35f48)
## クレジット表示必須
- 生成物にMitsua Likesのクレジットを合理的な方法で表示しなければなりません。クレジットは以下のいずれかを指します。
- Generated by Mitsua Likes
- 画像生成:Mitsua Likes
## 非商用限定
- 生成物の商用利用不可 (自身の創造的目的のための個人商用利用を除く)
- 企業商用利用については[お問い合わせ](https://abstractengine.ltd/#contact)までお問い合わせください
## 禁止事項(抜粋)
- 差別・誹謗中傷・侮辱・名誉棄損
- 第三者の知的財産権・プライバシーの侵害
- 虚偽の情報の流布
- 営利目的で素材として販売する行為
- 機械学習を目的とした、データセット作成行為
- その他、法令、公序良俗に違反する/おそれがある反社会的行為
"""
# Base license notice in English (Markdown). infer_impl appends
# character-specific restriction notices to a copy of this text and returns
# it whenever the browser language does not start with "ja".
license_str_en_orig = """
# Image License : [Mitsua Likes BY-NC](https://elanmitsua.notion.site/Mitsua-Likes-Attribution-NonCommercial-License-15baa85a9b278038be5dc7f47a9c26cc)
## Attribution required
- You must give appropriate credit of "Mitsua Likes" for sharing generated result. “Credits for Mitsua Likes” means displaying one of the following statements:
- Generated by Mitsua Likes
- 画像生成:Mitsua Likes
## Non-Commercial use only
- Non-Commercial use only, except for individual commercial use of creative purpose.
- For corporation commercial use, please contact at [this contact form](https://abstractengine.ltd/en/#contact)
## Prohibited Acts
- Acts that discriminate against, defame, or insult MITSUA Project or third parties, damaging their honor or credibility.
- Acts that infringe or may infringe on the intellectual property rights or privacy of MITSUA Project or third parties.
- Disseminating information or content that unjustly harms the interests of MITSUA Project or third parties.
- Disseminating false information or content.
- Distributing or selling the Mitsua Likes Generated Data, etc., as materials for commercial purposes.
- Other antisocial acts that violate or may violate laws, regulations, or public order and morals.
- Creating a dataset composed primarily of Generated Data for the purpose of machine learning.
"""
# Text prepended to the user prompt for each entry of the style dropdown.
_STYLE_PROMPTS = {
    "sensei art / 先生アート": "先生アートsensei artイラスト",
    "digital illustration / デジタルイラスト": "デジタルイラスト、digital illustration",
    "analog illustration / アナログイラスト": "アナログ風、illustration",
    "3d cg": "3d cg",
    "artworks / 芸術作品": "芸術作品artworks, paintings",
}

# Text appended to the negative prompt for each entry of the style dropdown.
_STYLE_NEGATIVE_PROMPTS = {
    "sensei art / 先生アート": "photo",
    "digital illustration / デジタルイラスト": "photo",
    "analog illustration / アナログイラスト": "vrm, cg, photo",
    "3d cg": "photo",
    "artworks / 芸術作品": "vrm, cg, photo",
}

# Output resolution (width, height) for each aspect-ratio label.
_SIZE_BY_ASPECT_RATIO = {
    '16:9': (1024, 576),
    '4:3': (896, 672),
    '1:1': (768, 768),
    '3:4': (672, 896),
    '9:16': (576, 1024),
}

_INVALID_PROMPT_MESSAGE = (
    "## <span style='color:orangered'>ERROR: Invalid prompt / 不適切なプロンプト </span>"
)


def _character_link(name):
    """Return a Markdown link for *name*, or the bare name if no URL is known."""
    if name in PUBLIC_CHARACTER_LINKS:
        return f"[{name}]({PUBLIC_CHARACTER_LINKS[name]})"
    return name


def _build_license_texts(prompt, detected, detected_info):
    """Return the ``(japanese, english)`` license texts with character notices.

    ``detected`` holds the character names the pipeline flagged, and
    ``detected_info`` is the mapping it reported alongside them.
    """
    ja_parts = [license_str_ja_orig]
    en_parts = [license_str_en_orig]

    def add_terms_bullet(name):
        link = _character_link(name)
        ja_parts.append(f"- 「{link}」の利用規約または二次創作ガイドラインに従う必要があります\n")
        en_parts.append(f"- Abide by the terms or the derivative guideline of \"{link}\" required.\n")

    # 1) Characters the pipeline flagged as similar to the generated image.
    if len(detected) > 0:
        joined = ','.join(detected)
        ja_parts.append(f"## <span style='color:orangered'>類似による追加の制約: {joined}</span>\n")
        en_parts.append(f"## <span style='color:orangered'>Similarity Restriction: {joined}</span>\n")
        for name in detected_info:
            if name in detected:
                add_terms_bullet(name)

    # 2) Characters named in the prompt but not already covered by (1).
    #    Japanese and English name tables each get their own section.
    for known_names in (PUBLIC_CHARACTER_NAMES_JA, PUBLIC_CHARACTER_NAMES_EN):
        prompted = [n for n in known_names if n in prompt and n not in detected]
        if not prompted:
            continue
        joined = ','.join(prompted)
        ja_parts.append(f"## <span style='color:orangered'>プロンプトによる追加の制約: {joined}</span>\n")
        en_parts.append(f"## <span style='color:orangered'>Prompt Restriction: {joined}</span>\n")
        for name in prompted:
            add_terms_bullet(name)

    # 3) Characters reported in the info mapping that were neither flagged
    #    nor named in the prompt: only a softer "possible similarity" notice.
    possible = [n for n in detected_info if n not in detected and n not in prompt]
    if possible:
        joined = ','.join(possible)
        ja_parts.append(f"## <span style='color:darkorange'>類似可能性の通知: {joined}</span>\n")
        en_parts.append(f"## <span style='color:darkorange'>Possible Similarity Notice: {joined}</span>\n")
        for name in possible:
            link = _character_link(name)
            ja_parts.append(
                f"- 「{link}」に類似している可能性があります。公式キャラクターデザインをご確認の上、「{link}」の利用規約または二次創作ガイドラインに従うかを検討してください\n"
            )
            en_parts.append(
                f"- Image might look like \"{link}\". Please confirm their character design and consider to abide by their derivative guideline.\n"
            )

    return "".join(ja_parts), "".join(en_parts)


def _add_logos(image):
    """Return a copy of *image* with the corner logo overlaid and a footer below."""
    # Open both assets in a `with` so their file handles are released as soon
    # as the pixels have been pasted.
    with Image.open("LikesHFFooter.png") as footer, \
            Image.open("MitsuaLikesLogoWhite.png") as corner:
        canvas = Image.new(
            "RGB",
            (image.width, image.height + footer.height),
            (255, 255, 255),
        )
        canvas.paste(image)
        # The corner logo is also passed as the paste mask.
        canvas.paste(corner, (image.width - corner.width, 0), corner)
        canvas.paste(footer, (0, image.height))
    return canvas


def infer_impl(
    prompt,
    style,
    lang='ja',
    negative_prompt="elan doodle",
    seed=42,
    randomize_seed=True,
    ar="1:1",
    guidance_scale=5.0,
    num_inference_steps=40,
    progress=gr.Progress(track_tqdm=True),
):
    """Generate one image for *prompt* and build the matching license notice.

    Args:
        prompt: User prompt text.
        style: Style dropdown value; selects text added to the prompt and to
            the negative prompt. An unrecognised value adds nothing.
        lang: Browser language tag; values starting with ``"ja"`` select the
            Japanese license text, anything else (including ``None``) the
            English one.
        negative_prompt: User negative prompt.
        seed: Seed used when ``randomize_seed`` is false.
        randomize_seed: When true, a fresh random seed replaces ``seed``.
        ar: Aspect-ratio label; unknown labels fall back to 768x768.
        guidance_scale: Forwarded to the pipeline.
        num_inference_steps: Forwarded to the pipeline.
        progress: Gradio progress tracker (not used directly here).

    Returns:
        ``(image_with_logos, license_markdown, character_info, seed)``.
        If the prompt contains a banned word the result is
        ``(None, error_markdown, None, seed)`` and nothing is generated.
    """
    if any(word in prompt for word in bad_words):
        return None, _INVALID_PROMPT_MESSAGE, None, seed

    if randomize_seed:
        seed = random.randint(0, np.iinfo(np.int32).max)
    generator = torch.Generator(device=device).manual_seed(seed)

    style_prompt = _STYLE_PROMPTS.get(style, "")
    if style_prompt != "":
        prompt = style_prompt + ", " + prompt
    style_negative = _STYLE_NEGATIVE_PROMPTS.get(style)
    if style_negative is not None:
        negative_prompt = negative_prompt + ", " + style_negative

    width, height = _SIZE_BY_ASPECT_RATIO.get(ar, (768, 768))

    ret = pipe(
        prompt=prompt,
        negative_prompt=negative_prompt,
        guidance_scale=guidance_scale,
        guidance_rescale=0.7,
        num_inference_steps=num_inference_steps,
        width=width,
        height=height,
        generator=generator,
    )
    image = ret.images[0]
    detected = ret.detected_public_fictional_characters[0]
    detected_info = ret.detected_public_fictional_characters_info[0]

    # The name search runs on the prompt as sent to the pipeline, i.e.
    # including the style prefix.
    license_ja, license_en = _build_license_texts(prompt, detected, detected_info)
    license_str = license_ja if (lang or '').startswith('ja') else license_en

    return _add_logos(image), license_str, detected_info, seed
def infer(
    prompt,
    style,
    lang='ja',
    negative_prompt="elan doodle",
    seed=42,
    randomize_seed=True,
    ar="1:1",
    guidance_scale=5.0,
    num_inference_steps=40,
    progress=gr.Progress(track_tqdm=True),
):
    """Gradio event handler: clear the outputs, then stream the generated result.

    This is a generator. It first yields an empty result so the previous
    image and notice are cleared, then yields whatever ``infer_impl`` returns.
    Each yielded value is a 4-tuple
    ``(image, license_markdown, character_info, seed)``.

    If the prompt contains a banned word, a single error tuple is yielded and
    ``infer_impl`` is not called.
    """
    if any(word in prompt for word in bad_words):
        # A generator's `return <value>` never reaches the consumer (it only
        # ends up on StopIteration), so the error has to be yielded for the
        # UI to display it.
        yield None, "## <span style='color:orangered'>ERROR: Invalid prompt / 不適切なプロンプト </span>", None, seed
        return

    yield None, None, None, seed
    yield infer_impl(
        prompt,
        style,
        lang,
        negative_prompt,
        seed,
        randomize_seed,
        ar,
        guidance_scale,
        num_inference_steps,
        progress,
    )
# Seafoam theme based on
# https://huggingface.co/spaces/gradio/seafoam
# https://github.com/gradio-app/gradio
class ModifiedSeafoam(Base):
    """Gradio theme with cyan/sky hues and gradient backgrounds and buttons.

    All constructor arguments are keyword-only and are forwarded unchanged to
    ``Base.__init__``; the style overrides below are then applied with
    ``Base.set``.
    """

    def __init__(
        self,
        *,
        primary_hue: colors.Color | str = colors.cyan,
        secondary_hue: colors.Color | str = colors.sky,
        neutral_hue: colors.Color | str = colors.sky,
        spacing_size: sizes.Size | str = sizes.spacing_sm,
        radius_size: sizes.Size | str = sizes.radius_lg,
        text_size: sizes.Size | str = sizes.text_lg,
        font: fonts.Font | str | Iterable[fonts.Font | str] = (
            fonts.GoogleFont("Quicksand"),
            "ui-sans-serif",
            "sans-serif",
        ),
        font_mono: fonts.Font | str | Iterable[fonts.Font | str] = (
            fonts.GoogleFont("IBM Plex Mono"),
            "ui-monospace",
            "monospace",
        ),
    ):
        super().__init__(
            primary_hue=primary_hue,
            secondary_hue=secondary_hue,
            neutral_hue=neutral_hue,
            spacing_size=spacing_size,
            radius_size=radius_size,
            text_size=text_size,
            font=font,
            font_mono=font_mono,
        )

        # Theme variable overrides, grouped by the part of the UI they style.
        overrides = {
            # Page background
            "body_background_fill": "linear-gradient(0deg, *primary_200, *primary_50)",
            "body_background_fill_dark": "linear-gradient(0deg, *secondary_900, *secondary_950)",
            # Primary button
            "button_primary_background_fill": "linear-gradient(0deg, *primary_400, *secondary_400)",
            "button_primary_background_fill_hover": "linear-gradient(0deg, *primary_400, *secondary_200)",
            "button_primary_background_fill_dark": "linear-gradient(0deg, *primary_800, *secondary_800)",
            "button_primary_background_fill_hover_dark": "linear-gradient(0deg, *primary_800, *secondary_600)",
            "button_primary_text_color": "*primary_900",
            "button_primary_text_color_hover": "*primary_800",
            # Sliders
            "slider_color": "*secondary_300",
            "slider_color_dark": "*primary_600",
            # Blocks
            "block_title_text_weight": "600",
            "block_border_width": "3px",
            "block_shadow": "*shadow_drop_lg",
        }
        super().set(**overrides)
if __name__ == "__main__":
    import argparse

    # Command-line options. --share is forwarded to launch() below.
    parser = argparse.ArgumentParser(description='Mitsua Likes Image Generator')
    parser.add_argument('--share', action="store_true")
    args = parser.parse_args()

    # Only 'title' is read below; the remaining keys are currently unused.
    ui_text = {
        'title': 'Mitsua Likes Demo',
        'description': "Let's enjoy image generation!",
        'dropdown_label': 'Translation Direction',
        'textbox_placeholder': 'Enter text here',
        'checkbox_label': 'Translate by sentence'
    }

    # Page-level JavaScript handed to gr.Blocks.
    with open('app.js', mode='r', encoding='utf-8') as jsfp:
        js = jsfp.read()

    theme = ModifiedSeafoam()

    # Centre and cap the size of the result image, and hide the Gradio footer.
    css = """
#output_image {
display: block !important;
margin-left: auto !important;
margin-right: auto !important;
max-width: 640px !important;
max-height: 640px !important;
overflow: hidden !important;
position: relative !important;
}
#output_image img {
object-fit: contain !important;
max-width: 100% !important;
max-height: 635px !important;
width: auto !important;
height: auto !important;
}
footer {
display:none !important
}
"""

    with gr.Blocks(title=ui_text['title'], js=js, css=css, theme=theme) as demo:
        gr.Markdown("""
# Title
### ミツアちゃんと楽しく画像生成🔖
Demo for [Mitsua Likes model](https://huggingface.co/Mitsua/mitsua-likes) licensed under [Mitsua Likes Attribution Non-Commercial License](https://elanmitsua.notion.site/Mitsua-Likes-Attribution-NonCommercial-License-15baa85a9b278038be5dc7f47a9c26cc)
""", elem_id='title')

        # Description list rendered at the bottom of the page.
        li = """
- Description
- Notes
- Notes
- Notes
- Notes
- Disclaimer
- Copyright
"""

        # Hidden textbox to store the language; its value is replaced with the
        # browser language by the `js` hook on the generate event.
        lang_input = gr.Textbox(value='', elem_id='lang_input', visible=False, interactive=False)

        with gr.Row():
            prompt = gr.Textbox(
                label="Prompt",
                lines=1, max_lines=3,
                placeholder="Enter prompt here",
                elem_id="text_input",
                container=False,
                min_width=500,
            )
            btn = gr.Button("Generate", elem_id='start_btn', variant="primary", size="lg", min_width=240, scale=0)

        with gr.Accordion("Advanced Settings", open=False):
            # NOTE(review): this reuses elem_id "text_input" from the prompt
            # box, producing duplicate DOM ids. Left unchanged because app.js
            # or CSS outside this file may rely on it; confirm before renaming.
            negative_prompt = gr.Textbox(
                label="Negative Prompt",
                lines=1, max_lines=3,
                placeholder="Enter negative prompt here",
                value="elan doodle, lowres",
                elem_id="text_input",
            )
            with gr.Row():
                with gr.Column(scale=1):
                    styles = gr.Dropdown(
                        ["sensei art / 先生アート", "digital illustration / デジタルイラスト", "analog illustration / アナログイラスト", "3d cg", "artworks / 芸術作品"],
                        value="digital illustration / デジタルイラスト",
                        label="Style",
                    )
                with gr.Column(scale=1):
                    ar = gr.Dropdown(
                        ["16:9", "4:3", "1:1", "3:4", "9:16"],
                        value="3:4",
                        label="Aspect Ratio",
                    )
                with gr.Column(scale=1):
                    seed = gr.Slider(
                        label="Seed",
                        minimum=0,
                        maximum=np.iinfo(np.int32).max,
                        step=1,
                        value=0,
                    )
                    randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
            guidance_scale = gr.Slider(
                label="Guidance Scale",
                minimum=1,
                maximum=10,
                step=0.1,
                value=5.0,
            )

        with gr.Row():
            with gr.Column(scale=1):
                result = gr.Image(label="Result", show_label=False, elem_id="output_image", min_width=500)
            with gr.Column(scale=1):
                license_md = gr.Markdown("", elem_id='license_str')
                character_similarities = gr.Label(label="Character Similarity Measure")

        gr.Markdown(li, elem_id='description')

        # Run `infer` on button click or Enter in the prompt box. The `js`
        # hook runs in the browser first and replaces the third input (the
        # hidden language textbox) with the browser language.
        gr.on(
            triggers=[btn.click, prompt.submit],
            fn=infer,
            inputs=[prompt, styles, lang_input, negative_prompt, seed, randomize_seed, ar, guidance_scale],
            outputs=[result, license_md, character_similarities, seed],
            js="""
function(x, z, y, a, b, c, d, e){
var userLang = navigator.language || navigator.userLanguage;
let langInputElement = document.querySelector("#lang_input textarea");
return [x,z,userLang,a,b,c,d,e];
}
""",
        )

    # Honour the --share flag (it was previously parsed but ignored).
    demo.queue().launch(show_api=False, share=args.share)