Spaces:
Runtime error
Runtime error
File size: 3,753 Bytes
ad18a5b d1e787c 4cd9bad 384ec64 d1e787c 4cd9bad ad18a5b 4cd9bad ad18a5b 384ec64 ad18a5b 384ec64 ad18a5b 384ec64 d1e787c 9c4025a 384ec64 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 |
import pathlib
import os
import gradio as gr
import torch
from diffusers import StableDiffusionPipeline
import utils
# Hugging Face access token passed to `from_pretrained` below; reading it with
# [] means the app fails at startup (KeyError) when HF_AUTH_TOKEN is unset.
use_auth_token = os.environ["HF_AUTH_TOKEN"]

# Instantiate the pipeline: the "fp16" revision in float16 on a GPU, otherwise
# the "main" revision in float32 on the CPU.
if torch.cuda.is_available():
    device, revision, torch_dtype = "cuda", "fp16", torch.float16
else:
    device, revision, torch_dtype = "cpu", "main", torch.float32

pipeline = StableDiffusionPipeline.from_pretrained(
    pretrained_model_name_or_path="CompVis/stable-diffusion-v1-4",
    use_auth_token=use_auth_token,
    revision=revision,
    torch_dtype=torch_dtype,
)
pipeline = pipeline.to(device)
# Load in the new concepts.
CONCEPT_PATH = pathlib.Path("learned_embeddings_dict.pt")
# map_location="cpu" lets a file saved from a GPU session load on a CPU-only
# host as well; the tensors are copied into the text encoder's weights below.
# NOTE: torch.load unpickles the file, so only point it at trusted, bundled files.
learned_embeddings_dict = torch.load(CONCEPT_PATH, map_location="cpu")


def _register_concepts(embeddings_by_concept, pipe):
    """Add each learned concept to ``pipe``'s tokenizer and text encoder.

    Each entry maps a concept token to a dict with ``"initializer_tokens"`` and
    ``"learned_embeddings"``. For every entry, dummy placeholder tokens are
    added to the tokenizer via ``utils.add_new_tokens_to_tokenizer``, the text
    encoder's embedding table is resized to the new vocabulary size, and the
    learned vectors are written into the rows of the new dummy token ids.

    Returns:
        Mapping from each concept token to the dummy-token text that should
        replace it in prompts.
    """
    concept_to_dummy = {}
    for concept_token, embedding_dict in embeddings_by_concept.items():
        initializer_tokens = embedding_dict["initializer_tokens"]
        learned_embeddings = embedding_dict["learned_embeddings"]
        (
            _initializer_ids,
            dummy_placeholder_ids,
            dummy_placeholder_tokens,
        ) = utils.add_new_tokens_to_tokenizer(
            concept_token=concept_token,
            initializer_tokens=initializer_tokens,
            tokenizer=pipe.tokenizer,
        )
        # The tokenizer just grew, so the embedding table must grow with it
        # before the new rows can be written.
        pipe.text_encoder.resize_token_embeddings(len(pipe.tokenizer))
        token_embeddings = pipe.text_encoder.get_input_embeddings().weight.data
        for dummy_id, tensor in zip(dummy_placeholder_ids, learned_embeddings):
            token_embeddings[dummy_id] = tensor
        concept_to_dummy[concept_token] = dummy_placeholder_tokens
    return concept_to_dummy


concept_to_dummy_tokens_map = _register_concepts(learned_embeddings_dict, pipeline)


def replace_concept_tokens(text: str) -> str:
    """Return ``text`` with every learned concept token replaced by its dummy tokens.

    The dummy tokens are assumed to be a string (they are passed to
    ``str.replace``) -- TODO confirm against ``utils.add_new_tokens_to_tokenizer``.
    """
    for concept_token, dummy_tokens in concept_to_dummy_tokens_map.items():
        text = text.replace(concept_token, dummy_tokens)
    return text
# NOTE(review): disabled inference path kept for reference; the Generate button
# is currently wired to `white_imgs` instead. Before re-enabling:
#   - `guidance_scale` is annotated `int` but defaults to 3.0; use `float`.
#   - `nsfw_content_detected` presumably holds one flag per image (TODO confirm
#     for the installed diffusers version). A list such as `[False]` is truthy,
#     so `not [...]` is always False and the retry loop never exits early;
#     `not any(...)` is likely what was intended.
#   - the Generate callback passes (prompt, guidance_scale, num_inference_steps,
#     seed), which does not match this signature's parameter order.
# def inference(
# prompt: str, num_inference_steps: int = 50, guidance_scale: int = 3.0
# ):
# prompt = replace_concept_tokens(prompt)
# for _ in range(3):
# img_list = pipeline(
# prompt=prompt,
# num_inference_steps=num_inference_steps,
# guidance_scale=guidance_scale,
# )
# if not img_list["nsfw_content_detected"]:
# break
# return img_list["sample"]
# Example prompt shown as the prompt textbox placeholder; it uses the
# "<det-logo>" token the UI asks users to include.
DEFAULT_PROMPT = ", ".join(
    [
        "A watercolor painting on textured paper of a <det-logo> using soft strokes",
        "pastel colors",
        "incredible composition",
        "masterpiece",
    ]
)
def white_imgs(prompt: str, guidance_scale: float, num_inference_steps: int, seed: int):
    """Return two placeholder 512x512x3 images filled with ones.

    The parameters mirror the Generate button's inputs but are ignored. Each
    image is a separate NumPy array produced from ``torch.ones``.
    """
    height, width, channels = 512, 512, 3
    images = []
    for _ in range(2):
        images.append(torch.ones(height, width, channels).numpy())
    return images
with gr.Blocks() as demo:
    prompt = gr.Textbox(
        label="Prompt including the token '<det-logo>'",
        placeholder=DEFAULT_PROMPT,
        interactive=True,
    )
    guidance_scale = gr.Slider(
        minimum=1.0, maximum=10.0, value=3.0, label="Guidance Scale", interactive=True
    )
    num_inference_steps = gr.Slider(
        minimum=25,
        maximum=60,
        value=40,
        label="Num Inference Steps",
        interactive=True,
        step=1,
    )
    # Explicit step=1, like the steps slider, so the seed stays an integer
    # instead of relying on gradio's default step for this range.
    seed = gr.Slider(
        minimum=2147483147,
        maximum=2147483647,
        value=2147483397,
        label="Seed",
        interactive=True,
        step=1,
    )
    # Debug output for the "test" button. Never render credentials such as
    # HF_AUTH_TOKEN (or any slice of it) in UI text like placeholders.
    output = gr.Textbox(label="output", interactive=False)
    # The lambda defers looking up `replace_concept_tokens` until click time.
    gr.Button("test").click(
        lambda s: replace_concept_tokens(s), inputs=[prompt], outputs=output
    )
    # The button text is gr.Button's first positional argument (`value`), as with
    # the "test" button above; a `label=` keyword does not set it.
    generate_btn = gr.Button("Generate")
    gallery = gr.Gallery(
        label="Generated Images",
        value=[torch.zeros(512, 512, 3).numpy() for _ in range(2)],
    ).style(height="auto")
    generate_btn.click(
        white_imgs,
        inputs=[prompt, guidance_scale, num_inference_steps, seed],
        outputs=gallery,
    )
demo.launch()
|