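# Gradio demo: generate images with Stable Diffusion v1-4 extended with learned
# embeddings for the custom "<det-logo>" concept token.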
import pathlib
import os

from PIL import Image
import gradio as gr
import torch
from diffusers import StableDiffusionPipeline

import utils

# A Hugging Face access token is needed to download the Stable Diffusion weights.
use_auth_token = os.environ["HF_AUTH_TOKEN"]
# Placeholder image shown in place of any output flagged by the safety checker.
NSFW_IMAGE = Image.open("nsfw.png")

# Instantiate the pipeline.
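# Use the fp16 weights on a GPU; fall back to full precision on the CPU.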
device, revision, torch_dtype = (
    ("cuda", "fp16", torch.float16)
    if torch.cuda.is_available()
    else ("cpu", "main", torch.float32)
)
pipeline = StableDiffusionPipeline.from_pretrained(
    pretrained_model_name_or_path="CompVis/stable-diffusion-v1-4",
    use_auth_token=use_auth_token,
    revision=revision,
    torch_dtype=torch_dtype,
).to(device)

# Load in the new concepts.
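# learned_embeddings_dict maps each concept token (e.g. "<det-logo>") to the
# initializer tokens it was seeded from and the embedding vectors learned for it.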
CONCEPT_PATH = pathlib.Path("learned_embeddings_dict.pt")
learned_embeddings_dict = torch.load(CONCEPT_PATH)
concept_to_dummy_tokens_map = {}
for concept_token, embedding_dict in learned_embeddings_dict.items():
    initializer_tokens = embedding_dict["initializer_tokens"]
    learned_embeddings = embedding_dict["learned_embeddings"]
    # Register dummy placeholder tokens for this concept with the tokenizer.
    (
        initializer_ids,
        dummy_placeholder_ids,
        dummy_placeholder_tokens,
    ) = utils.add_new_tokens_to_tokenizer(
        concept_token=concept_token,
        initializer_tokens=initializer_tokens,
        tokenizer=pipeline.tokenizer,
    )
    # Grow the text encoder's embedding matrix and copy in the learned vectors.
    pipeline.text_encoder.resize_token_embeddings(len(pipeline.tokenizer))
    token_embeddings = pipeline.text_encoder.get_input_embeddings().weight.data
    for d_id, tensor in zip(dummy_placeholder_ids, learned_embeddings):
        token_embeddings[d_id] = tensor
    concept_to_dummy_tokens_map[concept_token] = dummy_placeholder_tokens
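

# Prompts use the human-readable concept token; swap it for the dummy tokens that
# actually carry the learned embeddings before handing the prompt to the pipeline.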
def replace_concept_tokens(text: str):
    for concept_token, dummy_tokens in concept_to_dummy_tokens_map.items():
        text = text.replace(concept_token, dummy_tokens)
    return text
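

# Every image generated while the app is running is kept in this module-level list,
# so the gallery shows a running history rather than only the latest batch.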
all_generated_images = []


def inference(
    prompt: str, guidance_scale: float, num_inference_steps: int, seed: int
):
    prompt = replace_concept_tokens(prompt)
    # Seed the generator explicitly so results are reproducible for a given seed;
    # cast in case the slider delivers the value as a float.
    generator = torch.Generator(device=device).manual_seed(int(seed))
    # Generate two images per click.
    output = pipeline(
        prompt=[prompt] * 2,
        num_inference_steps=num_inference_steps,
        guidance_scale=guidance_scale,
        generator=generator,
    )
    img_list, nsfw_list = output.images, output.nsfw_content_detected
    for img, nsfw in zip(img_list, nsfw_list):
        # Replace anything the safety checker flags with the placeholder image.
        if nsfw:
            all_generated_images.append(NSFW_IMAGE)
        else:
            all_generated_images.append(img)
    return all_generated_images
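

# Placeholder prompt for the textbox; prompts should include the trained
# "<det-logo>" concept token so the learned embedding is actually used.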
DEFAULT_PROMPT = (
    "A watercolor painting on textured paper of a <det-logo> using soft strokes,"
    " pastel colors, incredible composition, masterpiece"
)
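
# Build the Gradio UI: a prompt box, sampler controls, and a gallery wired to
# inference() via the "Generate" button.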
with gr.Blocks() as demo:
    prompt = gr.Textbox(
        label="Prompt including the token '<det-logo>'",
        placeholder=DEFAULT_PROMPT,
        interactive=True,
    )
    guidance_scale = gr.Slider(
        minimum=1.0, maximum=20.0, value=3.0, label="Guidance Scale", interactive=True
    )
    num_inference_steps = gr.Slider(
        minimum=25,
        maximum=75,
        value=40,
        label="Num Inference Steps",
        interactive=True,
        step=1,
    )
    seed = gr.Slider(
        minimum=2147483147,
        maximum=2147483647,
        label="Seed",
        interactive=True,
        randomize=True,
        step=1,
    )
    generate_btn = gr.Button(value="Generate")
    gallery = gr.Gallery(
        label="Generated Images",
        value=[],
    ).style(height="auto")
    generate_btn.click(
        inference,
        inputs=[prompt, guidance_scale, num_inference_steps, seed],
        outputs=gallery,
    )

demo.launch()
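
# To try the demo outside of Spaces (assuming this file is saved as app.py and the
# repository's learned_embeddings_dict.pt, nsfw.png, and utils.py sit next to it):
#   HF_AUTH_TOKEN=<your-token> python app.py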