import os
import pathlib

from PIL import Image
import gradio as gr
import torch
from diffusers import StableDiffusionPipeline

import utils

# A Hugging Face token is needed to download the model weights.
use_auth_token = os.environ["HF_AUTH_TOKEN"]
# Placeholder image shown in place of any output flagged by the safety checker.
NSFW_IMAGE = Image.open("nsfw.png")

# Instantiate the pipeline, preferring the fp16 weights when a GPU is available.
device, revision, torch_dtype = (
    ("cuda", "fp16", torch.float16)
    if torch.cuda.is_available()
    else ("cpu", "main", torch.float32)
)
pipeline = StableDiffusionPipeline.from_pretrained(
    pretrained_model_name_or_path="CompVis/stable-diffusion-v1-4",
    use_auth_token=use_auth_token,
    revision=revision,
    torch_dtype=torch_dtype,
).to(device)

# Load the learned textual-inversion embeddings and register each concept's
# placeholder ("dummy") tokens with the tokenizer and text encoder.
CONCEPT_PATH = pathlib.Path("learned_embeddings_dict.pt")
learned_embeddings_dict = torch.load(CONCEPT_PATH)
concept_to_dummy_tokens_map = {}
for concept_token, embedding_dict in learned_embeddings_dict.items():
    initializer_tokens = embedding_dict["initializer_tokens"]
    learned_embeddings = embedding_dict["learned_embeddings"]
    (
        initializer_ids,
        dummy_placeholder_ids,
        dummy_placeholder_tokens,
    ) = utils.add_new_tokens_to_tokenizer(
        concept_token=concept_token,
        initializer_tokens=initializer_tokens,
        tokenizer=pipeline.tokenizer,
    )
    # Grow the embedding matrix to cover the new tokens, then copy each learned
    # vector into the row for its dummy placeholder id.
    pipeline.text_encoder.resize_token_embeddings(len(pipeline.tokenizer))
    token_embeddings = pipeline.text_encoder.get_input_embeddings().weight.data
    for d_id, tensor in zip(dummy_placeholder_ids, learned_embeddings):
        token_embeddings[d_id] = tensor
    concept_to_dummy_tokens_map[concept_token] = dummy_placeholder_tokens
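
# For orientation, a sketch of the assumed checkpoint layout and the resulting
# map; the concrete token names below are illustrative, not taken from the
# shipped checkpoint:
#
#   learned_embeddings_dict = {
#       "<det-logo>": {
#           "initializer_tokens": [...],   # tokens that seeded training
#           "learned_embeddings": [...],   # one tensor per dummy placeholder
#       },
#   }
#   concept_to_dummy_tokens_map = {"<det-logo>": "<det-logo-0> <det-logo-1>"}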


def replace_concept_tokens(text: str) -> str:
    """Swap each concept token in the prompt for its trained dummy tokens."""
    for concept_token, dummy_tokens in concept_to_dummy_tokens_map.items():
        text = text.replace(concept_token, dummy_tokens)
    return text
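
# Illustrative example, assuming the dummy tokens sketched above:
#   replace_concept_tokens("A sticker of a <det-logo>")
#   -> "A sticker of a <det-logo-0> <det-logo-1>"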


# Accumulates every image generated so far, so the gallery shows a history.
all_generated_images = []


def inference(prompt: str, guidance_scale: float, num_inference_steps: int, seed: int):
    prompt = replace_concept_tokens(prompt)
    # The slider may deliver the seed as a float; manual_seed requires an int.
    generator = torch.Generator(device=device).manual_seed(int(seed))
    # Generate two images per request.
    output = pipeline(
        prompt=[prompt] * 2,
        num_inference_steps=num_inference_steps,
        guidance_scale=guidance_scale,
        generator=generator,
    )
    img_list, nsfw_list = output.images, output.nsfw_content_detected
    for img, nsfw in zip(img_list, nsfw_list):
        # Swap in the placeholder image for any output the safety checker flags.
        if nsfw:
            all_generated_images.append(NSFW_IMAGE)
        else:
            all_generated_images.append(img)
    return all_generated_images


DEFAULT_PROMPT = (
    "A watercolor painting on textured paper of a <det-logo> using soft strokes,"
    " pastel colors, incredible composition, masterpiece"
)
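
# Quick smoke test without the UI; assumes the embeddings file above is present
# and a valid HF_AUTH_TOKEN (the seed value is arbitrary):
#
#   images = inference(DEFAULT_PROMPT, guidance_scale=3.0,
#                      num_inference_steps=40, seed=2147483647)
#   images[0].save("sample.png")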

with gr.Blocks() as demo:
    prompt = gr.Textbox(
        label="Prompt including the token '<det-logo>'",
        placeholder=DEFAULT_PROMPT,
        interactive=True,
    )
    guidance_scale = gr.Slider(
        minimum=1.0, maximum=20.0, value=3.0, label="Guidance Scale", interactive=True
    )
    num_inference_steps = gr.Slider(
        minimum=25,
        maximum=75,
        value=40,
        label="Num Inference Steps",
        interactive=True,
        step=1,
    )
    seed = gr.Slider(
        minimum=2147483147,
        maximum=2147483647,
        label="Seed",
        interactive=True,
        randomize=True,
        step=1,
    )
    generate_btn = gr.Button(value="Generate")
    gallery = gr.Gallery(
        label="Generated Images",
        value=[],
    ).style(height="auto")
    generate_btn.click(
        inference,
        inputs=[prompt, guidance_scale, num_inference_steps, seed],
        outputs=gallery,
    )

demo.launch()