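# Scraped file-viewer header (not part of the program):
#   author: Prgckwb
#   commit: 59df77b (":memo: Change")
#   viewer links: raw, history, blame
#   size: 6.42 kB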
import dataclasses
import gradio as gr
import spaces
import torch
from PIL import Image
from diffusers import DiffusionPipeline
from diffusers.utils import make_image_grid
# Model identifiers that are handed to ``from_pretrained`` unchanged.
DIFFUSERS_MODEL_IDS = [
    # SD Models
    "stabilityai/stable-diffusion-3-medium-diffusers",
    "stabilityai/stable-diffusion-xl-base-1.0",
    "stabilityai/stable-diffusion-2-1",
    "runwayml/stable-diffusion-v1-5",
    # Other Models
    "Prgckwb/trpfrog-diffusion",
]

# Display name -> local checkpoint directory for models outside the list above.
EXTERNAL_MODEL_MAPPING = {
    "Beautiful Realistic Asians": "checkpoints/diffusers/Beautiful Realistic Asians v7",
}

# Everything selectable in the UI: the direct ids first, then the display names.
MODEL_CHOICES = [*DIFFUSERS_MODEL_IDS, *EXTERNAL_MODEL_MAPPING]

# Module-level state that ``inference`` rebinds through ``global``.
pipe = None
current_model_id = DIFFUSERS_MODEL_IDS[0]

# Use the GPU whenever torch reports one, otherwise stay on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
@dataclasses.dataclass
class Input:
    """One row of values for the demo's input widgets.

    The fields are declared in the same order as the UI ``inputs`` list
    (prompt, model, negative prompt, size, guidance, steps, image count,
    the two checkboxes, seed), so ``to_list()`` yields a row with one value
    per input component.
    """

    prompt: str
    model_id: str = "stabilityai/stable-diffusion-3-medium-diffusers"
    negative_prompt: str = ''
    width: int = 1024
    height: int = 1024
    guidance_scale: float = 7.5
    num_inference_step: int = 28
    num_images: int = 4
    use_safety_checker: bool = True
    use_model_offload: bool = False
    # Last input component of the UI; without it every example row was one
    # value short of the ``inputs`` list. Default matches the seed widget.
    seed: int = 8888

    def to_list(self):
        """Return the field values as a flat list, in declaration order.

        The list is derived from the dataclass fields so that adding a field
        can never leave the row out of sync with the class definition.
        """
        return [getattr(self, spec.name) for spec in dataclasses.fields(self)]
# Preset rows for the UI's example table: each ``Input`` below is flattened
# with ``to_list()``; fields not given here keep the ``Input`` defaults.
EXAMPLES = [
    preset.to_list()
    for preset in (
        Input(prompt='A cat holding a sign that says Hello world'),
        Input(
            prompt='Beautiful pixel art of a Wizard with hovering text "Achivement unlocked: Diffusion models can spell now"',
        ),
        Input(prompt='A corgi wearing sunglasses says "U-Net is OVER!!"'),
        # The only preset that targets the externally mapped checkpoint.
        Input(
            prompt='Cinematic Photo of a beautiful korean fashion model bokeh train',
            model_id='Beautiful Realistic Asians',
            negative_prompt='(worst_quality:2.0) (MajicNegative_V2:0.8) BadNegAnatomyV1-neg bradhands cartoon, cgi, render, illustration, painting, drawing',
            width=512,
            height=512,
            guidance_scale=5.0,
            num_inference_step=50,
        ),
    )
]
@spaces.GPU(duration=120)
@torch.inference_mode()
def inference(
    prompt: str,
    model_id: str = "stabilityai/stable-diffusion-3-medium-diffusers",
    negative_prompt: str = "",
    width: int = 512,
    height: int = 512,
    guidance_scale: float = 7.5,
    num_inference_steps: int = 50,
    num_images: int = 4,
    safety_checker: bool = True,
    use_model_offload: bool = False,
    seed: int = 8888,
    progress=gr.Progress(track_tqdm=True),
) -> Image.Image:
    """Generate ``num_images`` images for ``prompt`` and return them as one grid.

    Args:
        prompt: Text prompt passed to the pipeline.
        model_id: Either an entry of ``DIFFUSERS_MODEL_IDS`` or a key of
            ``EXTERNAL_MODEL_MAPPING`` (which is translated to its checkpoint path).
        negative_prompt: Negative prompt passed to the pipeline.
        width: Output width passed to the pipeline.
        height: Output height passed to the pipeline.
        guidance_scale: Guidance scale passed to the pipeline.
        num_inference_steps: Number of denoising steps passed to the pipeline.
        num_images: Number of images generated for the prompt (at least 1).
        safety_checker: When false, the pipeline's ``safety_checker`` is set to ``None``.
        use_model_offload: When true, ``enable_model_cpu_offload()`` is used
            instead of moving the whole pipeline to ``device``.
        seed: Seed for the ``torch.Generator`` handed to the pipeline.
        progress: Gradio progress tracker.

    Returns:
        A single image holding all generated images: a one-column strip when
        ``num_images`` is odd, otherwise two rows of ``num_images // 2``.

    Raises:
        KeyError: ``model_id`` is neither in ``DIFFUSERS_MODEL_IDS`` nor a key
            of ``EXTERNAL_MODEL_MAPPING``.
        ValueError: ``num_images`` is smaller than 1.
    """
    global current_model_id, pipe

    progress(0, "Starting inference...")

    # The values come from UI widgets; coerce the integer-valued ones so that
    # ``manual_seed`` and the ``//`` / ``%`` grid arithmetic below always see ints.
    # NOTE(review): assumes the widgets may deliver floats — confirm for the Gradio version in use.
    width = int(width)
    height = int(height)
    num_inference_steps = int(num_inference_steps)
    num_images = int(num_images)
    seed = int(seed)
    if num_images < 1:
        raise ValueError(f"num_images must be at least 1, got {num_images}")

    progress(0.1, 'Loading pipeline...')

    # Decide once whether this is an externally mapped model: ``model_id`` is
    # overwritten with the checkpoint path right below.
    is_external = model_id not in DIFFUSERS_MODEL_IDS
    if is_external:
        model_id = EXTERNAL_MODEL_MAPPING[model_id]

    # The pipeline is rebuilt on every call; ``current_model_id`` only records
    # what was loaded last.
    pipe = DiffusionPipeline.from_pretrained(
        model_id,
        # Half precision is kept for the GPU path; on CPU fall back to float32.
        torch_dtype=torch.float16 if device == "cuda" else torch.float32,
    )
    current_model_id = model_id

    if not safety_checker:
        pipe.safety_checker = None

    if is_external:
        progress(0.3, 'Loading Textual Inversion...')
        # Load Textual Inversion
        pipe.load_textual_inversion(
            "checkpoints/embeddings/BadNegAnatomyV1 neg.pt", token='BadNegAnatomyV1-neg'
        )

    # Generation
    progress(0.4, 'Generating images...')
    if use_model_offload:
        pipe.enable_model_cpu_offload()
    else:
        # Follow the module-level ``device`` instead of a hard-coded 'cuda',
        # so the pipeline and the generator below live on the same device.
        pipe = pipe.to(device)

    generator = torch.Generator(device=device).manual_seed(seed)

    images = pipe(
        prompt,
        negative_prompt=negative_prompt,
        width=width,
        height=height,
        guidance_scale=guidance_scale,
        num_inference_steps=num_inference_steps,
        num_images_per_prompt=num_images,
        generator=generator,
    ).images

    # Odd counts cannot be split into two equal rows, so stack them vertically.
    if num_images % 2 == 1:
        image = make_image_grid(images, rows=num_images, cols=1)
    else:
        image = make_image_grid(images, rows=2, cols=num_images // 2)

    return image
if __name__ == "__main__":
    # Script entry point: build the Blocks UI, wire it to ``inference``, launch it.
    # NOTE(review): the nesting below was reconstructed from a copy of the file
    # without indentation — confirm which widgets share the first accordion row.
    emerald_theme = gr.themes.Default(primary_hue=gr.themes.colors.emerald)

    with gr.Blocks(theme=emerald_theme) as demo:
        gr.Markdown("# Stable Diffusion Demo")

        with gr.Row():
            # Left column: prompt, model choice and the collapsible settings.
            with gr.Column():
                prompt = gr.Text(label="Prompt", placeholder="Enter a prompt here")
                model_id = gr.Dropdown(
                    label="Model ID",
                    choices=MODEL_CHOICES,
                    value="stabilityai/stable-diffusion-3-medium-diffusers",
                )

                # Additional Input Settings
                with gr.Accordion("Additional Settings", open=False):
                    negative_prompt = gr.Text(label="Negative Prompt", value="")

                    with gr.Row():
                        width = gr.Number(
                            label="Width", value=512, step=64, minimum=64, maximum=2048
                        )
                        height = gr.Number(
                            label="Height", value=512, step=64, minimum=64, maximum=2048
                        )
                        num_images = gr.Number(
                            label="Num Images", value=4, minimum=1, maximum=10, step=1
                        )
                        seed = gr.Number(label="Seed", value=8888, step=1)

                    guidance_scale = gr.Slider(
                        label="Guidance Scale", value=7.5, step=0.5, minimum=0, maximum=10
                    )
                    num_inference_step = gr.Slider(
                        label="Num Inference Steps", value=50, minimum=1, maximum=100, step=2
                    )

                    with gr.Row():
                        use_safety_checker = gr.Checkbox(value=True, label='Use Safety Checker')
                        use_model_offload = gr.Checkbox(value=False, label='Use Model Offload')

            # Right column: the generated image grid.
            with gr.Column():
                output_image = gr.Image(label="Image", type="pil")

        # Order matters: the values are passed positionally to ``inference``.
        inputs = [
            prompt,
            model_id,
            negative_prompt,
            width,
            height,
            guidance_scale,
            num_inference_step,
            num_images,
            use_safety_checker,
            use_model_offload,
            seed,
        ]

        generate_button = gr.Button("Generate")
        generate_button.click(fn=inference, inputs=inputs, outputs=output_image)

        gr.Examples(examples=EXAMPLES, inputs=inputs)

    demo.queue()
    demo.launch()