"""Gradio demo UI for prompting with learned concept tokens.

The Stable Diffusion pipeline is currently disabled (see the commented-out
code below); the "Generate" button is wired to ``white_imgs``, which returns
placeholder images so the UI can be exercised without loading any model.
"""

import gradio as gr
import torch

# pipeline = StableDiffusionPipeline.from_pretrained(
#     pretrained_model_name_or_path="weights", torch_dtype=torch.float16
# )
# pipeline.to('cuda')

# Mapping of concept token (str) -> dummy-token string (str) that
# ``replace_concept_tokens`` substitutes into prompts.
# NOTE(review): torch.load unpickles the file; only ever point this at the
# bundled, trusted asset — never at a user-supplied path.
concept_to_dummy_tokens_map = torch.load("concept_to_dummy_tokens_map.pt")


def replace_concept_tokens(text: str) -> str:
    """Return *text* with every concept token replaced by its dummy tokens.

    Replacements are applied one after another in the map's iteration order,
    so a dummy-token string that itself contains a concept token handled later
    in that order may be expanded again.
    """
    for concept_token, dummy_tokens in concept_to_dummy_tokens_map.items():
        text = text.replace(concept_token, dummy_tokens)
    return text


# def inference(
#     prompt: str, num_inference_steps: int = 50, guidance_scale: float = 3.0
# ):
#     prompt = replace_concept_tokens(prompt)
#     for _ in range(3):
#         img_list = pipeline(
#             prompt=prompt,
#             num_inference_steps=num_inference_steps,
#             guidance_scale=guidance_scale,
#         )
#         if not img_list["nsfw_content_detected"]:
#             break
#     return img_list["sample"]

DEFAULT_PROMPT = (
    # NOTE(review): a concept token appears to be missing between "a" and
    # "using" (possibly stripped as an HTML-like tag) — restore it.
    "A watercolor painting on textured paper of a using soft strokes,"
    " pastel colors, incredible composition, masterpiece"
)


def white_imgs(prompt: str, guidance_scale: float, num_inference_steps: int, seed: int):
    """Placeholder for the real inference: return two all-ones 512x512x3 arrays.

    Every argument is ignored; the signature only mirrors the inputs wired to
    the "Generate" button so the real inference function can be dropped in.
    """
    return [torch.ones(512, 512, 3).numpy() for _ in range(2)]


# Human-readable list of the concept tokens the prompt is expected to contain,
# e.g. "'<tok-a>', '<tok-b>'"; shown in the prompt textbox label.
_CONCEPT_TOKENS_LABEL = ", ".join(f"'{token}'" for token in concept_to_dummy_tokens_map)

with gr.Blocks() as demo:
    prompt = gr.Textbox(
        label=f"Prompt including the token {_CONCEPT_TOKENS_LABEL}",
        placeholder=DEFAULT_PROMPT,
        interactive=True,
    )
    guidance_scale = gr.Slider(
        minimum=1.0, maximum=10.0, value=3.0, label="Guidance Scale", interactive=True
    )
    num_inference_steps = gr.Slider(
        minimum=25,
        maximum=60,
        value=40,
        label="Num Inference Steps",
        interactive=True,
        step=1,
    )
    # step=1 keeps the seed integral, matching ``seed: int`` in the handler.
    seed = gr.Slider(
        minimum=2147483147,
        maximum=2147483647,
        value=2147483397,
        label="Seed",
        interactive=True,
        step=1,
    )

    # Debug aid: shows the prompt after concept-token expansion.
    output = gr.Textbox(label="output", placeholder="output", interactive=False)
    gr.Button("test").click(replace_concept_tokens, inputs=[prompt], outputs=output)

    # A Button's caption is its first positional argument (``value``);
    # passing ``label=`` instead leaves the default "Run" caption.
    generate_btn = gr.Button("Generate")
    gallery = gr.Gallery(
        label="Generated Images",
        value=[torch.zeros(512, 512, 3).numpy() for _ in range(2)],
    ).style(height="auto")
    generate_btn.click(
        white_imgs,
        inputs=[prompt, guidance_scale, num_inference_steps, seed],
        outputs=gallery,
    )

if __name__ == "__main__":
    demo.launch()