File size: 1,727 Bytes
b8d3c8f
 
 
 
22175fc
 
 
 
 
 
 
b8d3c8f
efcf16b
22175fc
b8d3c8f
 
 
 
 
 
22175fc
b8d3c8f
 
22175fc
 
 
 
b8d3c8f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
aea07e3
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
import gradio as gr
import torch
from diffusers import StableDiffusionControlNetPipeline, ControlNetModel, UniPCMultistepScheduler

# Pick the weight precision from the available hardware: float16 when CUDA
# is present, float32 otherwise.
use_cuda = torch.cuda.is_available()
dtype = torch.float16 if use_cuda else torch.float32

# ControlNet weights are loaded first so they can be handed to the pipeline.
controlnet = ControlNetModel.from_pretrained(
    "williamberman/controlnet-fill50k",
    torch_dtype=dtype,
)

# Base Stable Diffusion pipeline wired to the ControlNet above; the safety
# checker is explicitly disabled (safety_checker=None).
pipe = StableDiffusionControlNetPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    controlnet=controlnet,
    safety_checker=None,
    torch_dtype=dtype,
)

# Swap the pipeline's default scheduler for UniPC, reusing its existing config.
pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)

# These two pipeline options are only turned on when a CUDA device exists.
if use_cuda:
    pipe.enable_xformers_memory_efficient_attention()
    pipe.enable_model_cpu_offload()

def inference(prompt, image, seed=-1):
    """Run the ControlNet pipeline once and return the generated image.

    Args:
        prompt: Text prompt forwarded to the pipeline.
        image: Conditioning image forwarded to the pipeline.
        seed: Seed for the random generator. ``-1`` or ``None`` (for example
            an emptied numeric input) means "no fixed seed": no generator is
            passed and the pipeline uses its own randomness.

    Returns:
        The first image of the pipeline output (``.images[0]``).
    """
    # Guard against None before seeding: manual_seed(None) raises TypeError.
    if seed is None or seed == -1:
        generator = None
    else:
        # manual_seed only accepts integers, so coerce integral floats too.
        generator = torch.Generator().manual_seed(int(seed))

    result = pipe(prompt, image, num_inference_steps=20, generator=generator)
    return result.images[0]

# Prompts for the bundled examples; the prompt at position i is paired with
# the conditioning image "images/<i>.png" and a seed of 0.
example_prompts = [
    "red circle with blue background",
    "cyan circle with brown floral background",
    "light coral circle with white background",
    "cornflower blue circle with light golden rod yellow background",
    "light slate gray circle with blue background",
    "light golden rod yellow circle with turquoise background",
]

# UI wiring: three inputs (prompt, conditioning image, seed) feed `inference`,
# whose result is shown as a single image.
io = gr.Interface(
    fn=inference,
    inputs=[
        gr.Textbox(lines=3, label="Prompt"),
        gr.Image(label="Controlnet conditioning", type="pil"),
        gr.Number(value=-1, label="Seed", precision=0),
    ],
    outputs=[gr.Image(type="pil")],
    examples=[
        [text, f"images/{index}.png", 0]
        for index, text in enumerate(example_prompts)
    ],
    title="fill50k controlnet",
    cache_examples=True,
)
io.launch()