# detsd_demo / app.py
# Author: Garrett Goon — commit 4811b12 ("cleanup"), 3.57 kB
# (Header reconstructed as a comment; the original lines were Hugging Face
# file-page chrome — "raw", "history blame" — that would break parsing.)
import pathlib
import os
from PIL import Image
import gradio as gr
import torch
from diffusers import StableDiffusionPipeline
import utils

# Hugging Face token used to download the gated Stable Diffusion weights.
# Raises KeyError at startup if HF_AUTH_TOKEN is not set (fail fast).
use_auth_token = os.environ["HF_AUTH_TOKEN"]
# Placeholder shown instead of any generation flagged by the safety checker.
NSFW_IMAGE = Image.open("nsfw.png")
# Instantiate the pipeline.
# On GPU, use the fp16 checkpoint revision and half precision for speed and
# memory; on CPU fall back to the main revision in full float32.
device, revision, torch_dtype = (
    ("cuda", "fp16", torch.float16)
    if torch.cuda.is_available()
    else ("cpu", "main", torch.float32)
)
pipeline = StableDiffusionPipeline.from_pretrained(
    pretrained_model_name_or_path="CompVis/stable-diffusion-v1-4",
    use_auth_token=use_auth_token,
    revision=revision,
    torch_dtype=torch_dtype,
).to(device)
# Load in the new concepts (textual-inversion embeddings trained elsewhere).
CONCEPT_PATH = pathlib.Path("learned_embeddings_dict.pt")
learned_embeddings_dict = torch.load(CONCEPT_PATH)

# Maps each human-readable concept token (e.g. "<det-logo>") to the dummy
# placeholder token string that stands in for it in prompts at inference time.
concept_to_dummy_tokens_map = {}
for concept_token, embedding_dict in learned_embeddings_dict.items():
    initializer_tokens = embedding_dict["initializer_tokens"]
    learned_embeddings = embedding_dict["learned_embeddings"]
    # Register fresh placeholder tokens for this concept with the tokenizer.
    # NOTE(review): assumes the checkpoint's dict entries carry these exact
    # keys and that utils.add_new_tokens_to_tokenizer returns ids aligned
    # one-to-one with `learned_embeddings` — confirm against training code.
    (
        initializer_ids,
        dummy_placeholder_ids,
        dummy_placeholder_tokens,
    ) = utils.add_new_tokens_to_tokenizer(
        concept_token=concept_token,
        initializer_tokens=initializer_tokens,
        tokenizer=pipeline.tokenizer,
    )
    # Grow the text encoder's embedding matrix to cover the new token ids...
    pipeline.text_encoder.resize_token_embeddings(len(pipeline.tokenizer))
    token_embeddings = pipeline.text_encoder.get_input_embeddings().weight.data
    # ...then copy each learned embedding vector into its new row in place.
    for d_id, tensor in zip(dummy_placeholder_ids, learned_embeddings):
        token_embeddings[d_id] = tensor
    concept_to_dummy_tokens_map[concept_token] = dummy_placeholder_tokens
def replace_concept_tokens(text: str):
    """Return *text* with every known concept token swapped for its dummy tokens.

    Iterates the module-level ``concept_to_dummy_tokens_map`` and performs a
    plain substring replacement for each entry.
    """
    result = text
    for concept, substitution in concept_to_dummy_tokens_map.items():
        result = result.replace(concept, substitution)
    return result
# Running gallery history. NOTE(review): this module-level list is shared by
# every user/session of the Space; per-session history would need gr.State.
all_generated_images = []


def inference(
    prompt: str, guidance_scale: float, num_inference_steps: int, seed: int
):
    """Generate two images for ``prompt`` and return the full gallery history.

    Args:
        prompt: Text prompt; concept tokens are first replaced with their
            learned dummy tokens via ``replace_concept_tokens``.
        guidance_scale: Classifier-free guidance weight (slider is
            float-valued, 1.0-20.0).
        num_inference_steps: Number of denoising steps.
        seed: RNG seed for reproducible sampling.

    Returns:
        The list of all images generated so far in this process; any output
        flagged by the safety checker is replaced with ``NSFW_IMAGE``.
    """
    prompt = replace_concept_tokens(prompt)
    # Gradio sliders deliver floats unless step= is set, and
    # torch.Generator.manual_seed rejects non-int seeds — coerce explicitly.
    generator = torch.Generator(device=device).manual_seed(int(seed))
    output = pipeline(
        prompt=[prompt] * 2,  # two samples per click
        num_inference_steps=int(num_inference_steps),
        guidance_scale=guidance_scale,
        generator=generator,
    )
    img_list, nsfw_list = output.images, output.nsfw_content_detected
    for img, nsfw in zip(img_list, nsfw_list):
        # Swap in the placeholder for anything the safety checker flagged.
        all_generated_images.append(NSFW_IMAGE if nsfw else img)
    return all_generated_images
# Example prompt shown as the textbox placeholder; includes the learned
# "<det-logo>" concept token that the pipeline was extended with above.
DEFAULT_PROMPT = (
    "A watercolor painting on textured paper of a <det-logo> using soft strokes,"
    " pastel colors, incredible composition, masterpiece"
)
# Build the Gradio UI: prompt + sampling controls feeding `inference`,
# whose returned image list is rendered in a gallery.
with gr.Blocks() as demo:
    # NOTE(review): DEFAULT_PROMPT is only a placeholder hint, not the initial
    # value — submitting without typing sends an empty prompt to the pipeline;
    # consider value=DEFAULT_PROMPT instead.
    prompt = gr.Textbox(
        label="Prompt including the token '<det-logo>'",
        placeholder=DEFAULT_PROMPT,
        interactive=True,
    )
    guidance_scale = gr.Slider(
        minimum=1.0, maximum=20.0, value=3.0, label="Guidance Scale", interactive=True
    )
    num_inference_steps = gr.Slider(
        minimum=25,
        maximum=75,
        value=40,
        label="Num Inference Steps",
        interactive=True,
        step=1,
    )
    # Randomized on page load. NOTE(review): no step= is set, so the slider
    # emits floats, and the range covers only the top ~500 ints below 2**31 —
    # confirm this narrow window is intentional.
    seed = gr.Slider(
        minimum=2147483147,
        maximum=2147483647,
        label="Seed",
        interactive=True,
        randomize=True
    )
    generate_btn = gr.Button(value="Generate")
    gallery = gr.Gallery(
        label="Generated Images",
        value=[],
    ).style(height="auto")
    # Wire the button: run `inference` and display its returned image list.
    generate_btn.click(
        inference,
        inputs=[prompt, guidance_scale, num_inference_steps, seed],
        outputs=gallery,
    )

demo.launch()