Vishnu Sarukkai commited on
Commit
44549ff
·
1 Parent(s): bcf4913

First commit

Browse files
Files changed (1) hide show
  1. app.py +96 -0
app.py ADDED
@@ -0,0 +1,96 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from diffusers import StableDiffusionControlNetPipeline, ControlNetModel, EulerAncestralDiscreteScheduler
3
+ import torch
4
+ from diffusers.utils import load_image
5
+ import numpy as np
6
+ from controlnet_aux import HEDdetector
7
+ import PIL
8
+ from PIL import Image, ImageFilter
9
+ import matplotlib.pyplot as plt
10
+
11
+ negative_prompt = ""
12
+ device = 'cpu'
13
+ controlnet = ControlNetModel.from_pretrained("vsanimator/sketchcomplete", torch_dtype=torch.float16).to(device)
14
+ pipe = StableDiffusionControlNetPipeline.from_pretrained(
15
+ "runwayml/stable-diffusion-v1-5",
16
+ controlnet=controlnet, torch_dtype=torch.float16
17
+ ).to(device)
18
+ pipe.safety_checker = None
19
+ pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(pipe.scheduler.config)
20
+ threshold = 250
21
+ hed = HEDdetector.from_pretrained('lllyasviel/ControlNet')
22
+
23
+ num_images = 3
24
+
25
+ with gr.Blocks() as demo:
26
+ start_state = []
27
+ for k in range(num_images):
28
+ start_state.append([None, None])
29
+ sketch_states = gr.State(start_state)
30
+ checkbox_state = gr.State(False)
31
+ with gr.Row():
32
+ with gr.Column(scale = 1):
33
+ with gr.Tabs(shape=(768, 768),min_width=512):
34
+ with gr.TabItem("Draw", shape=(512, 512),min_width=512):
35
+ i = gr.Image(source="canvas", shape=(512, 512), tool="color-sketch",
36
+ min_width=512, brush_radius = 2).style(width=600, height=600)
37
+ with gr.TabItem("ShadowDraw", shape=(512, 512),min_width=512):
38
+ i_sketch = gr.Image(shape=(512, 512),min_width=512).style(width=600, height=600)
39
+ prompt_box = gr.Textbox(label="Prompt")
40
+ with gr.Row():
41
+ btn = gr.Button("Render").style(width=100, height=80)
42
+ checkbox = gr.Checkbox(label = "ShadowDraw", value=False)
43
+ btn2 = gr.Button("Reset").style(width=100, height=80)
44
+ i_prev = gr.Image(shape=(512, 512),
45
+ min_width=512).style(width=768, height=768)
46
+ with gr.Column(scale = 1):
47
+ o_list = [gr.Image().style(width=512, height=512) for _ in range(num_images)]
48
+
49
+ def sketch(curr_sketch, prev_sketch, prompt, negative_prompt, seed, num_steps):
50
+ print("Sketching")
51
+ if curr_sketch is None:
52
+ return None, None
53
+ if prev_sketch is None:
54
+ prev_sketch = curr_sketch
55
+ generator = torch.Generator(device=device)
56
+ generator.manual_seed(seed)
57
+ curr_sketch_image = Image.fromarray(curr_sketch.astype(np.uint8)).convert("L")
58
+
59
+ # Run function call
60
+ images = pipe(prompt, curr_sketch_image.convert("RGB").point( lambda p: 256 if p > 128 else 0), negative_prompt = negative_prompt, num_inference_steps=num_steps, generator=generator, controlnet_conditioning_scale = 1.0).images
61
+
62
+ return images[0]
63
+
64
+ def run_sketching(prompt, curr_sketch, prev_sketch, sketch_states, shadow_draw):
65
+ to_return = []
66
+ for k in range(num_images):
67
+ seed = sketch_states[k][1]
68
+ if seed is None:
69
+ seed = np.random.randint(1000)
70
+ sketch_states[k][1] = seed
71
+ new_image = sketch(curr_sketch, prev_sketch, prompt,
72
+ negative_prompt, seed = seed, num_steps = 20)
73
+ to_return.append(new_image)
74
+ prev_sketch = curr_sketch
75
+ if shadow_draw:
76
+ hed_images = []
77
+ for image in to_return:
78
+ hed_images.append(hed(image, scribble=False))
79
+ avg_hed = np.mean([np.array(image) for image in hed_images], axis = 0)
80
+ curr_sketch = np.array(curr_sketch).astype(float) / 255.
81
+ curr_sketch = Image.fromarray(np.uint8(1.0*((0.0*curr_sketch + 1. - 1.*(avg_hed / 255.))) * 255.))
82
+ else:
83
+ curr_sketch = None
84
+ return to_return + [curr_sketch, prev_sketch, sketch_states]
85
+
86
+ def reset(sketch_states):
87
+ for k in range(num_images):
88
+ sketch_states[k] = [None, None]
89
+ return None, None, sketch_states
90
+
91
+ btn.click(run_sketching, [prompt_box, i, i_prev, sketch_states, checkbox_state], o_list + [i_sketch, i_prev, sketch_states])
92
+ btn2.click(reset, sketch_states, [i, i_prev, sketch_states])
93
+ checkbox.change(lambda i: i, inputs=[checkbox], outputs=[checkbox_state])
94
+
95
+ demo.launch()#share = True)
96
+