import os
import pathlib

import gradio as gr
import torch
from diffusers import StableDiffusionPipeline

import utils

use_auth_token = os.environ["HF_AUTH_TOKEN"]

# Instantiate the pipeline.
device, revision, torch_dtype = (
    ("cuda", "fp16", torch.float16)
    if torch.cuda.is_available()
    else ("cpu", "main", torch.float32)
)
pipeline = StableDiffusionPipeline.from_pretrained(
    pretrained_model_name_or_path="CompVis/stable-diffusion-v1-4",
    use_auth_token=use_auth_token,
    revision=revision,
    torch_dtype=torch_dtype,
).to(device)

# Load in the new concepts.
CONCEPT_PATH = pathlib.Path("learned_embeddings_dict.pt")
learned_embeddings_dict = torch.load(CONCEPT_PATH)
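
# Map each human-readable concept token to the dummy placeholder tokens whose
# embeddings are copied into the text encoder below.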
concept_to_dummy_tokens_map = {}
for concept_token, embedding_dict in learned_embeddings_dict.items():
    initializer_tokens = embedding_dict["initializer_tokens"]
    learned_embeddings = embedding_dict["learned_embeddings"]
    (
        initializer_ids,
        dummy_placeholder_ids,
        dummy_placeholder_tokens,
    ) = utils.add_new_tokens_to_tokenizer(
        concept_token=concept_token,
        initializer_tokens=initializer_tokens,
        tokenizer=pipeline.tokenizer,
    )
    # Grow the text encoder's embedding matrix to cover the newly added tokens,
    # then copy the learned embeddings into the rows for the dummy placeholders.
    pipeline.text_encoder.resize_token_embeddings(len(pipeline.tokenizer))
    token_embeddings = pipeline.text_encoder.get_input_embeddings().weight.data
    for d_id, tensor in zip(dummy_placeholder_ids, learned_embeddings):
        token_embeddings[d_id] = tensor
    concept_to_dummy_tokens_map[concept_token] = dummy_placeholder_tokens
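

# Users type the human-readable concept token (e.g. "<det-logo>") in their prompts;
# swap it for the dummy placeholder tokens the tokenizer and text encoder actually know.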
def replace_concept_tokens(text: str) -> str:
    for concept_token, dummy_tokens in concept_to_dummy_tokens_map.items():
        text = text.replace(concept_token, dummy_tokens)
    return text
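

# Generate two images for the concept-substituted prompt with a user-chosen seed.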
def inference(
    prompt: str, guidance_scale: float, num_inference_steps: int, seed: int
):
    prompt = replace_concept_tokens(prompt)
    generator = torch.Generator(device=device).manual_seed(seed)
    out = pipeline(
        prompt=[prompt] * 2,
        num_inference_steps=num_inference_steps,
        guidance_scale=guidance_scale,
        generator=generator,
    )
    # The pipeline returns a StableDiffusionPipelineOutput; its PIL images are in `.images`.
    return out.images
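

# Example prompt shown as the placeholder text in the prompt box.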
DEFAULT_PROMPT = (
    "A watercolor painting on textured paper of a <det-logo> using soft strokes,"
    " pastel colors, incredible composition, masterpiece"
)
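
# Build the Gradio UI: prompt plus sampling controls, a "test" path that previews
# the token substitution, and a gallery for the generated images.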
with gr.Blocks() as demo:
    prompt = gr.Textbox(
        label="Prompt including the token '<det-logo>'",
        placeholder=DEFAULT_PROMPT,
        interactive=True,
    )
    guidance_scale = gr.Slider(
        minimum=1.0, maximum=10.0, value=3.0, label="Guidance Scale", interactive=True
    )
    num_inference_steps = gr.Slider(
        minimum=25,
        maximum=60,
        value=40,
        label="Num Inference Steps",
        interactive=True,
        step=1,
    )
    seed = gr.Slider(
        minimum=2147483147,
        maximum=2147483647,
        value=2147483397,
        label="Seed",
        interactive=True,
    )
    # Shows the prompt after concept-token substitution when "test" is clicked.
    output = gr.Textbox(
        label="output", placeholder=use_auth_token[:5], interactive=False
    )
    gr.Button("test").click(
        lambda s: replace_concept_tokens(s), inputs=[prompt], outputs=output
    )
    generate_btn = gr.Button("Generate")
    gallery = gr.Gallery(
        label="Generated Images",
        value=[torch.ones(512, 512, 3).numpy() for _ in range(2)],
    ).style(height="auto")
    generate_btn.click(
        inference,
        inputs=[prompt, guidance_scale, num_inference_steps, seed],
        outputs=gallery,
    )

demo.launch()