Spaces:
Runtime error
Runtime error
import pathlib | |
import os | |
import gradio as gr | |
import torch | |
from diffusers import StableDiffusionPipeline | |
import utils | |
# Hugging Face access token passed to `from_pretrained` below to download the
# model weights (presumably configured as a Space secret -- confirm).
try:
    use_auth_token = os.environ["HF_AUTH_TOKEN"]
except KeyError:
    # Fail fast with an actionable message instead of a bare KeyError.
    raise RuntimeError(
        "The HF_AUTH_TOKEN environment variable is not set; it is required "
        "to download CompVis/stable-diffusion-v1-4."
    ) from None

# Instantiate the pipeline: on a CUDA GPU load the "fp16" revision of the
# model repo in half precision; otherwise fall back to the "main" revision in
# full precision on the CPU (half precision is only used on CUDA).
device, revision, torch_dtype = (
    ("cuda", "fp16", torch.float16)
    if torch.cuda.is_available()
    else ("cpu", "main", torch.float32)
)
pipeline = StableDiffusionPipeline.from_pretrained(
    pretrained_model_name_or_path="CompVis/stable-diffusion-v1-4",
    use_auth_token=use_auth_token,
    revision=revision,
    torch_dtype=torch_dtype,
).to(device)
# Load in the new concepts. Per the (currently disabled) loader below, this is
# expected to map each concept token to a dict holding "initializer_tokens"
# and "learned_embeddings" -- confirm against the file contents.
CONCEPT_PATH = pathlib.Path("learned_embeddings_dict.pt")
# map_location remaps stored tensors onto the device chosen above, so a
# checkpoint saved from a GPU session still loads on a CPU-only host (without
# it torch.load raises when the original device is unavailable).
# NOTE(security): torch.load unpickles arbitrary objects; only point this at a
# trusted file.
learned_embeddings_dict = torch.load(CONCEPT_PATH, map_location=device)
# | |
# concept_to_dummy_tokens_map = {} | |
# for concept_token, embedding_dict in learned_embeddings_dict.items(): | |
# initializer_tokens = embedding_dict["initializer_tokens"] | |
# learned_embeddings = embedding_dict["learned_embeddings"] | |
# ( | |
# initializer_ids, | |
# dummy_placeholder_ids, | |
# dummy_placeholder_tokens, | |
# ) = utils.add_new_tokens_to_tokenizer( | |
# concept_token=concept_token, | |
# initializer_tokens=initializer_tokens, | |
# tokenizer=pipeline.tokenizer, | |
# ) | |
# pipeline.text_encoder.resize_token_embeddings(len(pipeline.tokenizer)) | |
# token_embeddings = pipeline.text_encoder.get_input_embeddings().weight.data | |
# for d_id, tensor in zip(dummy_placeholder_ids, learned_embeddings): | |
# token_embeddings[d_id] = tensor | |
# concept_to_dummy_tokens_map[concept_token] = dummy_placeholder_tokens | |
# | |
# | |
# def replace_concept_tokens(text: str): | |
# for concept_token, dummy_tokens in concept_to_dummy_tokens_map.items(): | |
# text = text.replace(concept_token, dummy_tokens) | |
# return text | |
# def inference( | |
# prompt: str, num_inference_steps: int = 50, guidance_scale: int = 3.0 | |
# ): | |
# prompt = replace_concept_tokens(prompt) | |
# for _ in range(3): | |
# img_list = pipeline( | |
# prompt=prompt, | |
# num_inference_steps=num_inference_steps, | |
# guidance_scale=guidance_scale, | |
# ) | |
# if not img_list["nsfw_content_detected"]: | |
# break | |
# return img_list["sample"] | |
# Example prompt shown as the prompt textbox placeholder; it contains the
# learned "<det-logo>" concept token.
DEFAULT_PROMPT = (
    "A watercolor painting on textured paper of a <det-logo> "
    "using soft strokes, pastel colors, incredible composition, masterpiece"
)
def white_imgs(prompt: str, guidance_scale: float, num_inference_steps: int, seed: int):
    """Return two blank white 512x512 RGB images.

    Placeholder generator wired to the generate button: the prompt, guidance
    scale, step count and seed are accepted so the signature matches the UI
    inputs, but none of them are used.

    Returns:
        A list of two independent NumPy arrays of shape (512, 512, 3) filled
        with 1.0, in torch's default dtype (float32 unless changed globally).
    """
    shape = (512, 512, 3)
    return [torch.ones(shape).numpy(), torch.ones(shape).numpy()]
def _preview_prompt(text: str) -> str:
    """Return *text* after concept-token substitution, when that is available.

    ``replace_concept_tokens`` only exists while the concept-loading code in
    this module is enabled; while it is commented out, calling it directly
    from the "test" button raised ``NameError`` on click. In that case the
    prompt is echoed back unchanged.
    """
    replace = globals().get("replace_concept_tokens")
    return text if replace is None else replace(text)


with gr.Blocks() as demo:
    prompt = gr.Textbox(
        label="Prompt including the token '<det-logo>'",
        placeholder=DEFAULT_PROMPT,
        interactive=True,
    )
    guidance_scale = gr.Slider(
        minimum=1.0, maximum=10.0, value=3.0, label="Guidance Scale", interactive=True
    )
    num_inference_steps = gr.Slider(
        minimum=25,
        maximum=60,
        value=40,
        label="Num Inference Steps",
        interactive=True,
        step=1,
    )
    seed = gr.Slider(
        minimum=2147483147,
        maximum=2147483647,
        value=2147483397,
        label="Seed",
        interactive=True,
    )
    # Output for the "test" button. Never render any part of the auth token
    # here: the UI is public, so even a prefix of the secret would leak.
    output = gr.Textbox(label="output", interactive=False)
    gr.Button("test").click(_preview_prompt, inputs=[prompt], outputs=output)

    # A Button's caption is its `value` (first positional argument); passing
    # `label=` does not set the displayed text.
    generate_btn = gr.Button("Generate")
    gallery = gr.Gallery(
        label="Generated Images",
        value=[torch.zeros(512, 512, 3).numpy() for _ in range(2)],
    ).style(height="auto")
    generate_btn.click(
        white_imgs,
        inputs=[prompt, guidance_scale, num_inference_steps, seed],
        outputs=gallery,
    )

demo.launch()