zoanhy commited on
Commit
fc5e70f
·
verified ·
1 Parent(s): 7573cc1

Upload 16 files

Browse files
app.py ADDED
@@ -0,0 +1,80 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+
3
+ from __future__ import annotations
4
+
5
+ import gradio as gr
6
+ import torch
7
+
8
+ torch.jit.script = lambda f: f
9
+ import spaces
10
+
11
+ from app_canny import create_demo as create_demo_canny
12
+ from app_depth import create_demo as create_demo_depth
13
+ from model import Model
14
+ from settings import ALLOW_CHANGING_BASE_MODEL, DEFAULT_MODEL_ID, SHOW_DUPLICATE_BUTTON
15
+ from transformers.utils.hub import move_cache
16
+
17
+ move_cache()
18
+
19
+ DESCRIPTION = "ControlNet"
20
+
21
+ if not torch.cuda.is_available():
22
+ DESCRIPTION += "\n<p>Running on CPU.</p>"
23
+
24
+ model = Model(base_model_id=DEFAULT_MODEL_ID, task_name="Canny")
25
+
26
+
27
+ with gr.Blocks(css="style.css") as demo:
28
+ gr.Markdown(DESCRIPTION)
29
+ gr.DuplicateButton(
30
+ value="Duplicate Space for private use",
31
+ elem_id="duplicate-button",
32
+ visible=SHOW_DUPLICATE_BUTTON,
33
+ )
34
+
35
+ with gr.Tabs():
36
+ with gr.TabItem("Depth"):
37
+ create_demo_depth(model.process_depth)
38
+ with gr.TabItem("Canny"):
39
+ create_demo_canny(model.process_canny)
40
+
41
+ with gr.Accordion(label="Base model", open=False):
42
+ with gr.Row():
43
+ with gr.Column(scale=5):
44
+ current_base_model = gr.Text(label="Current base model")
45
+ with gr.Column(scale=1):
46
+ check_base_model_button = gr.Button("Check current base model")
47
+ with gr.Row():
48
+ with gr.Column(scale=5):
49
+ new_base_model_id = gr.Text(
50
+ label="New base model",
51
+ max_lines=1,
52
+ placeholder="runwayml/stable-diffusion-v1-5",
53
+ info="The base model must be compatible with Stable Diffusion v1.5.",
54
+ interactive=ALLOW_CHANGING_BASE_MODEL,
55
+ )
56
+ with gr.Column(scale=1):
57
+ change_base_model_button = gr.Button(
58
+ "Change base model", interactive=ALLOW_CHANGING_BASE_MODEL
59
+ )
60
+ if not ALLOW_CHANGING_BASE_MODEL:
61
+ gr.Markdown(
62
+ """The base model is not allowed to be changed in this Space so as not to slow down the demo, but it can be changed if you duplicate the Space."""
63
+ )
64
+
65
+ check_base_model_button.click(
66
+ fn=lambda: model.base_model_id,
67
+ outputs=current_base_model,
68
+ queue=False,
69
+ api_name="check_base_model",
70
+ )
71
+ gr.on(
72
+ triggers=[new_base_model_id.submit, change_base_model_button.click],
73
+ fn=model.set_base_model,
74
+ inputs=new_base_model_id,
75
+ outputs=current_base_model,
76
+ api_name=False,
77
+ )
78
+
79
+ if __name__ == "__main__":
80
+ demo.queue(max_size=20).launch()
app_canny.py ADDED
@@ -0,0 +1,127 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+
3
+ import gradio as gr
4
+
5
+ from settings import (
6
+ DEFAULT_IMAGE_RESOLUTION,
7
+ DEFAULT_NUM_IMAGES,
8
+ MAX_IMAGE_RESOLUTION,
9
+ MAX_NUM_IMAGES,
10
+ MAX_SEED,
11
+ )
12
+ from utils import randomize_seed_fn
13
+
14
+
15
+ def create_demo(process):
16
+ with gr.Blocks() as demo:
17
+ with gr.Row():
18
+ with gr.Column():
19
+ image = gr.Image()
20
+ prompt = gr.Textbox(label="Prompt")
21
+ run_button = gr.Button("Run")
22
+ with gr.Accordion("Advanced options", open=False):
23
+ num_samples = gr.Slider(
24
+ label="Number of images",
25
+ minimum=1,
26
+ maximum=MAX_NUM_IMAGES,
27
+ value=DEFAULT_NUM_IMAGES,
28
+ step=1,
29
+ )
30
+ image_resolution = gr.Slider(
31
+ label="Image resolution",
32
+ minimum=256,
33
+ maximum=MAX_IMAGE_RESOLUTION,
34
+ value=DEFAULT_IMAGE_RESOLUTION,
35
+ step=256,
36
+ )
37
+ canny_low_threshold = gr.Slider(
38
+ label="Canny low threshold",
39
+ minimum=0,
40
+ maximum=1.0,
41
+ value=0.1,
42
+ step=0.05,
43
+ )
44
+ canny_high_threshold = gr.Slider(
45
+ label="Canny high threshold",
46
+ minimum=0,
47
+ maximum=1.0,
48
+ value=0.2,
49
+ step=0.05,
50
+ )
51
+ num_steps = gr.Slider(
52
+ label="Number of steps",
53
+ minimum=1,
54
+ maximum=100,
55
+ value=20,
56
+ step=1,
57
+ )
58
+ guidance_scale = gr.Slider(
59
+ label="Guidance scale",
60
+ minimum=0.1,
61
+ maximum=30.0,
62
+ value=7.5,
63
+ step=0.1,
64
+ )
65
+ seed = gr.Slider(
66
+ label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0
67
+ )
68
+ randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
69
+ a_prompt = gr.Textbox(
70
+ label="Additional prompt",
71
+ value="high-quality, extremely detailed, 4K",
72
+ )
73
+ n_prompt = gr.Textbox(
74
+ label="Negative prompt",
75
+ value="longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality",
76
+ )
77
+ with gr.Column():
78
+ result = gr.Gallery(
79
+ label="Output", show_label=False, columns=2, object_fit="scale-down"
80
+ )
81
+
82
+ inputs = [
83
+ image,
84
+ prompt,
85
+ a_prompt,
86
+ n_prompt,
87
+ num_samples,
88
+ image_resolution,
89
+ num_steps,
90
+ guidance_scale,
91
+ seed,
92
+ canny_low_threshold,
93
+ canny_high_threshold,
94
+ ]
95
+ prompt.submit(
96
+ fn=randomize_seed_fn,
97
+ inputs=[seed, randomize_seed],
98
+ outputs=seed,
99
+ queue=False,
100
+ api_name=False,
101
+ ).then(
102
+ fn=process,
103
+ inputs=inputs,
104
+ outputs=result,
105
+ api_name=False,
106
+ )
107
+ run_button.click(
108
+ fn=randomize_seed_fn,
109
+ inputs=[seed, randomize_seed],
110
+ outputs=seed,
111
+ queue=False,
112
+ api_name=False,
113
+ ).then(
114
+ fn=process,
115
+ inputs=inputs,
116
+ outputs=result,
117
+ api_name="canny",
118
+ )
119
+ return demo
120
+
121
+
122
+ if __name__ == "__main__":
123
+ from model import Model
124
+
125
+ model = Model(task_name="Canny")
126
+ demo = create_demo(model.process_canny)
127
+ demo.queue().launch()
app_depth.py ADDED
@@ -0,0 +1,126 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+
3
+ import gradio as gr
4
+
5
+ from settings import (
6
+ DEFAULT_IMAGE_RESOLUTION,
7
+ DEFAULT_NUM_IMAGES,
8
+ MAX_IMAGE_RESOLUTION,
9
+ MAX_NUM_IMAGES,
10
+ MAX_SEED,
11
+ )
12
+ from utils import randomize_seed_fn
13
+
14
+
15
+ def create_demo(process):
16
+ with gr.Blocks() as demo:
17
+ with gr.Row():
18
+ with gr.Column():
19
+ image = gr.Image()
20
+ prompt = gr.Textbox(label="Prompt")
21
+ run_button = gr.Button("Run")
22
+ with gr.Accordion("Advanced options", open=False):
23
+ preprocessor_name = gr.Radio(
24
+ label="Preprocessor",
25
+ choices=["Midas", "DPT", "None"],
26
+ type="value",
27
+ value="DPT",
28
+ )
29
+ num_samples = gr.Slider(
30
+ label="Number of images",
31
+ minimum=1,
32
+ maximum=MAX_NUM_IMAGES,
33
+ value=DEFAULT_NUM_IMAGES,
34
+ step=1,
35
+ )
36
+ image_resolution = gr.Slider(
37
+ label="Image resolution",
38
+ minimum=256,
39
+ maximum=MAX_IMAGE_RESOLUTION,
40
+ value=DEFAULT_IMAGE_RESOLUTION,
41
+ step=256,
42
+ )
43
+ preprocess_resolution = gr.Slider(
44
+ label="Preprocess resolution",
45
+ minimum=128,
46
+ maximum=512,
47
+ value=384,
48
+ step=1,
49
+ )
50
+ num_steps = gr.Slider(
51
+ label="Number of steps",
52
+ minimum=1,
53
+ maximum=100,
54
+ value=20,
55
+ step=1,
56
+ )
57
+ guidance_scale = gr.Slider(
58
+ label="Guidance scale",
59
+ minimum=0.1,
60
+ maximum=30.0,
61
+ value=7.5,
62
+ step=0.1,
63
+ )
64
+ seed = gr.Slider(
65
+ label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0
66
+ )
67
+ randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
68
+ a_prompt = gr.Textbox(
69
+ label="Additional prompt",
70
+ value="high-quality, extremely detailed, 4K",
71
+ )
72
+ n_prompt = gr.Textbox(
73
+ label="Negative prompt",
74
+ value="longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality",
75
+ )
76
+ with gr.Column():
77
+ result = gr.Gallery(
78
+ label="Output", show_label=False, columns=2, object_fit="scale-down"
79
+ )
80
+
81
+ inputs = [
82
+ image,
83
+ prompt,
84
+ a_prompt,
85
+ n_prompt,
86
+ num_samples,
87
+ image_resolution,
88
+ preprocess_resolution,
89
+ num_steps,
90
+ guidance_scale,
91
+ seed,
92
+ preprocessor_name,
93
+ ]
94
+ prompt.submit(
95
+ fn=randomize_seed_fn,
96
+ inputs=[seed, randomize_seed],
97
+ outputs=seed,
98
+ queue=False,
99
+ api_name=False,
100
+ ).then(
101
+ fn=process,
102
+ inputs=inputs,
103
+ outputs=result,
104
+ api_name=False,
105
+ )
106
+ run_button.click(
107
+ fn=randomize_seed_fn,
108
+ inputs=[seed, randomize_seed],
109
+ outputs=seed,
110
+ queue=False,
111
+ api_name=False,
112
+ ).then(
113
+ fn=process,
114
+ inputs=inputs,
115
+ outputs=result,
116
+ api_name="depth",
117
+ )
118
+ return demo
119
+
120
+
121
+ if __name__ == "__main__":
122
+ from model import Model
123
+
124
+ model = Model(task_name="depth")
125
+ demo = create_demo(model.process_depth)
126
+ demo.queue().launch()
checkpoints/canny/controlnet/config.json ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_class_name": "ControlNetModel",
3
+ "_diffusers_version": "0.26.3",
4
+ "_name_or_path": "work_dirs/finetune/MultiGen20M_canny/ft_controlnet_sd15_canny_res512_bs256_lr1e-5_warmup100_iter5k_fp16ft0-1000/checkpoint-5000",
5
+ "act_fn": "silu",
6
+ "addition_embed_type": null,
7
+ "addition_embed_type_num_heads": 64,
8
+ "addition_time_embed_dim": null,
9
+ "attention_head_dim": 8,
10
+ "block_out_channels": [
11
+ 320,
12
+ 640,
13
+ 1280,
14
+ 1280
15
+ ],
16
+ "class_embed_type": null,
17
+ "conditioning_channels": 3,
18
+ "conditioning_embedding_out_channels": [
19
+ 16,
20
+ 32,
21
+ 96,
22
+ 256
23
+ ],
24
+ "controlnet_conditioning_channel_order": "rgb",
25
+ "cross_attention_dim": 768,
26
+ "down_block_types": [
27
+ "CrossAttnDownBlock2D",
28
+ "CrossAttnDownBlock2D",
29
+ "CrossAttnDownBlock2D",
30
+ "DownBlock2D"
31
+ ],
32
+ "downsample_padding": 1,
33
+ "encoder_hid_dim": null,
34
+ "encoder_hid_dim_type": null,
35
+ "flip_sin_to_cos": true,
36
+ "freq_shift": 0,
37
+ "global_pool_conditions": false,
38
+ "in_channels": 4,
39
+ "layers_per_block": 2,
40
+ "mid_block_scale_factor": 1,
41
+ "mid_block_type": "UNetMidBlock2DCrossAttn",
42
+ "norm_eps": 1e-05,
43
+ "norm_num_groups": 32,
44
+ "num_attention_heads": null,
45
+ "num_class_embeds": null,
46
+ "only_cross_attention": false,
47
+ "projection_class_embeddings_input_dim": null,
48
+ "resnet_time_scale_shift": "default",
49
+ "transformer_layers_per_block": 1,
50
+ "upcast_attention": false,
51
+ "use_linear_projection": false
52
+ }
checkpoints/canny/controlnet/diffusion_pytorch_model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a3fd425077e65024addc5cf73c97195fcfd499b7a5e16868e4251b47cebb0d89
3
+ size 1445157120
checkpoints/depth/controlnet/config.json ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_class_name": "ControlNetModel",
3
+ "_diffusers_version": "0.26.3",
4
+ "_name_or_path": "work_dirs/finetune/MultiGen20M_depth/ft_controlnet_sd15_depth_res512_bs256_lr1e-5_warmup100_iter5k_fp16ft0-200/checkpoint-5000",
5
+ "act_fn": "silu",
6
+ "addition_embed_type": null,
7
+ "addition_embed_type_num_heads": 64,
8
+ "addition_time_embed_dim": null,
9
+ "attention_head_dim": 8,
10
+ "block_out_channels": [
11
+ 320,
12
+ 640,
13
+ 1280,
14
+ 1280
15
+ ],
16
+ "class_embed_type": null,
17
+ "conditioning_channels": 3,
18
+ "conditioning_embedding_out_channels": [
19
+ 16,
20
+ 32,
21
+ 96,
22
+ 256
23
+ ],
24
+ "controlnet_conditioning_channel_order": "rgb",
25
+ "cross_attention_dim": 768,
26
+ "down_block_types": [
27
+ "CrossAttnDownBlock2D",
28
+ "CrossAttnDownBlock2D",
29
+ "CrossAttnDownBlock2D",
30
+ "DownBlock2D"
31
+ ],
32
+ "downsample_padding": 1,
33
+ "encoder_hid_dim": null,
34
+ "encoder_hid_dim_type": null,
35
+ "flip_sin_to_cos": true,
36
+ "freq_shift": 0,
37
+ "global_pool_conditions": false,
38
+ "in_channels": 4,
39
+ "layers_per_block": 2,
40
+ "mid_block_scale_factor": 1,
41
+ "mid_block_type": "UNetMidBlock2DCrossAttn",
42
+ "norm_eps": 1e-05,
43
+ "norm_num_groups": 32,
44
+ "num_attention_heads": null,
45
+ "num_class_embeds": null,
46
+ "only_cross_attention": false,
47
+ "projection_class_embeddings_input_dim": null,
48
+ "resnet_time_scale_shift": "default",
49
+ "transformer_layers_per_block": 1,
50
+ "upcast_attention": false,
51
+ "use_linear_projection": false
52
+ }
checkpoints/depth/controlnet/diffusion_pytorch_model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d7450404d13ef888c9701433a3c17b2a86c021a6d042f9f5d2519602abd7f2f3
3
+ size 1445157120
cv_utils.py ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import cv2
2
+ import numpy as np
3
+
4
+
5
+ def resize_image(input_image, resolution, interpolation=None):
6
+ H, W, C = input_image.shape
7
+ H = float(H)
8
+ W = float(W)
9
+ k = float(resolution) / max(H, W)
10
+ H *= k
11
+ W *= k
12
+ H = int(np.round(H / 64.0)) * 64
13
+ W = int(np.round(W / 64.0)) * 64
14
+ if interpolation is None:
15
+ interpolation = cv2.INTER_LANCZOS4 if k > 1 else cv2.INTER_AREA
16
+ img = cv2.resize(input_image, (W, H), interpolation=interpolation)
17
+ return img
depth_estimator.py ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import PIL.Image
3
+ from controlnet_aux.util import HWC3
4
+ from transformers import pipeline
5
+
6
+ from cv_utils import resize_image
7
+
8
+
9
+ class DepthEstimator:
10
+ def __init__(self):
11
+ self.model = pipeline("depth-estimation")
12
+
13
+ def __call__(self, image: np.ndarray, **kwargs) -> PIL.Image.Image:
14
+ detect_resolution = kwargs.pop("detect_resolution", 512)
15
+ image_resolution = kwargs.pop("image_resolution", 512)
16
+ image = np.array(image)
17
+ image = HWC3(image)
18
+ image = resize_image(image, resolution=detect_resolution)
19
+ image = PIL.Image.fromarray(image)
20
+ image = self.model(image)
21
+ image = image["depth"]
22
+ image = np.array(image)
23
+ image = HWC3(image)
24
+ image = resize_image(image, resolution=image_resolution)
25
+ return PIL.Image.fromarray(image)
image_segmentor.py ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import cv2
2
+ import numpy as np
3
+ import PIL.Image
4
+ import torch
5
+ from controlnet_aux.util import HWC3, ade_palette
6
+ from transformers import AutoImageProcessor, UperNetForSemanticSegmentation
7
+
8
+ from cv_utils import resize_image
9
+
10
+
11
+ class ImageSegmentor:
12
+ def __init__(self):
13
+ self.image_processor = AutoImageProcessor.from_pretrained("openmmlab/upernet-convnext-small")
14
+ self.image_segmentor = UperNetForSemanticSegmentation.from_pretrained("openmmlab/upernet-convnext-small")
15
+
16
+ @torch.inference_mode()
17
+ def __call__(self, image: np.ndarray, **kwargs) -> PIL.Image.Image:
18
+ detect_resolution = kwargs.pop("detect_resolution", 512)
19
+ image_resolution = kwargs.pop("image_resolution", 512)
20
+ image = HWC3(image)
21
+ image = resize_image(image, resolution=detect_resolution)
22
+ image = PIL.Image.fromarray(image)
23
+
24
+ pixel_values = self.image_processor(image, return_tensors="pt").pixel_values
25
+ outputs = self.image_segmentor(pixel_values)
26
+ seg = self.image_processor.post_process_semantic_segmentation(outputs, target_sizes=[image.size[::-1]])[0]
27
+ color_seg = np.zeros((seg.shape[0], seg.shape[1], 3), dtype=np.uint8)
28
+ for label, color in enumerate(ade_palette()):
29
+ color_seg[seg == label, :] = color
30
+ color_seg = color_seg.astype(np.uint8)
31
+
32
+ color_seg = resize_image(color_seg, resolution=image_resolution, interpolation=cv2.INTER_NEAREST)
33
+ return PIL.Image.fromarray(color_seg)
model.py ADDED
@@ -0,0 +1,705 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import gc
4
+
5
+ import numpy as np
6
+ import PIL.Image
7
+ import torch
8
+ from controlnet_aux.util import HWC3
9
+ from diffusers import (
10
+ ControlNetModel,
11
+ DiffusionPipeline,
12
+ StableDiffusionControlNetPipeline,
13
+ UniPCMultistepScheduler,
14
+ )
15
+
16
+ from cv_utils import resize_image
17
+ from preprocessor import Preprocessor
18
+ from settings import MAX_IMAGE_RESOLUTION, MAX_NUM_IMAGES
19
+
20
+ CONTROLNET_MODEL_IDS = {
21
+ "Canny": "checkpoints/canny/controlnet",
22
+
23
+ "depth": "checkpoints/depth/controlnet",
24
+ }
25
+
26
+
27
+ def download_all_controlnet_weights() -> None:
28
+ for model_id in CONTROLNET_MODEL_IDS.values():
29
+ ControlNetModel.from_pretrained(model_id)
30
+
31
+
32
+ class Model:
33
+ def __init__(self, base_model_id: str = "runwayml/stable-diffusion-v1-5", task_name: str = "Canny"):
34
+ self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
35
+ self.base_model_id = ""
36
+ self.task_name = ""
37
+ self.pipe = self.load_pipe(base_model_id, task_name)
38
+ self.preprocessor = Preprocessor()
39
+
40
+ def load_pipe(self, base_model_id: str, task_name) -> DiffusionPipeline:
41
+ if (
42
+ base_model_id == self.base_model_id
43
+ and task_name == self.task_name
44
+ and hasattr(self, "pipe")
45
+ and self.pipe is not None
46
+ ):
47
+ return self.pipe
48
+ model_id = CONTROLNET_MODEL_IDS[task_name]
49
+ controlnet = ControlNetModel.from_pretrained(model_id, torch_dtype=torch.float16)
50
+ pipe = StableDiffusionControlNetPipeline.from_pretrained(
51
+ base_model_id, safety_checker=None, controlnet=controlnet, torch_dtype=torch.float16
52
+ )
53
+ pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
54
+ if self.device.type == "cuda":
55
+ pipe.disable_xformers_memory_efficient_attention()
56
+ pipe.to(self.device)
57
+ torch.cuda.empty_cache()
58
+ gc.collect()
59
+ self.base_model_id = base_model_id
60
+ self.task_name = task_name
61
+ return pipe
62
+
63
+ def set_base_model(self, base_model_id: str) -> str:
64
+ if not base_model_id or base_model_id == self.base_model_id:
65
+ return self.base_model_id
66
+ del self.pipe
67
+ torch.cuda.empty_cache()
68
+ gc.collect()
69
+ try:
70
+ self.pipe = self.load_pipe(base_model_id, self.task_name)
71
+ except Exception:
72
+ self.pipe = self.load_pipe(self.base_model_id, self.task_name)
73
+ return self.base_model_id
74
+
75
+ def load_controlnet_weight(self, task_name: str) -> None:
76
+ if task_name == self.task_name:
77
+ return
78
+ if self.pipe is not None and hasattr(self.pipe, "controlnet"):
79
+ del self.pipe.controlnet
80
+ torch.cuda.empty_cache()
81
+ gc.collect()
82
+ model_id = CONTROLNET_MODEL_IDS[task_name]
83
+ controlnet = ControlNetModel.from_pretrained(model_id, torch_dtype=torch.float16)
84
+ controlnet.to(self.device)
85
+ torch.cuda.empty_cache()
86
+ gc.collect()
87
+ self.pipe.controlnet = controlnet
88
+ self.task_name = task_name
89
+
90
+ def get_prompt(self, prompt: str, additional_prompt: str) -> str:
91
+ if not prompt:
92
+ prompt = additional_prompt
93
+ else:
94
+ prompt = f"{prompt}, {additional_prompt}"
95
+ return prompt
96
+
97
+ @torch.autocast("cuda")
98
+ def run_pipe(
99
+ self,
100
+ prompt: str,
101
+ negative_prompt: str,
102
+ control_image: PIL.Image.Image,
103
+ num_images: int,
104
+ num_steps: int,
105
+ guidance_scale: float,
106
+ seed: int,
107
+ ) -> list[PIL.Image.Image]:
108
+ generator = torch.Generator().manual_seed(seed)
109
+ return self.pipe(
110
+ prompt=prompt,
111
+ negative_prompt=negative_prompt,
112
+ guidance_scale=guidance_scale,
113
+ num_images_per_prompt=num_images,
114
+ num_inference_steps=num_steps,
115
+ generator=generator,
116
+ image=control_image,
117
+ ).images
118
+
119
+ @torch.inference_mode()
120
+ def process_canny(
121
+ self,
122
+ image: np.ndarray,
123
+ prompt: str,
124
+ additional_prompt: str,
125
+ negative_prompt: str,
126
+ num_images: int,
127
+ image_resolution: int,
128
+ num_steps: int,
129
+ guidance_scale: float,
130
+ seed: int,
131
+ low_threshold: int,
132
+ high_threshold: int,
133
+ ) -> list[PIL.Image.Image]:
134
+ if image is None:
135
+ raise ValueError
136
+ if image_resolution > MAX_IMAGE_RESOLUTION:
137
+ raise ValueError
138
+ if num_images > MAX_NUM_IMAGES:
139
+ raise ValueError
140
+
141
+ self.preprocessor.load("Canny")
142
+ control_image = self.preprocessor(
143
+ image=image, low_threshold=low_threshold, high_threshold=high_threshold, detect_resolution=image_resolution
144
+ )
145
+
146
+ self.load_controlnet_weight("Canny")
147
+ results = self.run_pipe(
148
+ prompt=self.get_prompt(prompt, additional_prompt),
149
+ negative_prompt=negative_prompt,
150
+ control_image=control_image,
151
+ num_images=num_images,
152
+ num_steps=num_steps,
153
+ guidance_scale=guidance_scale,
154
+ seed=seed,
155
+ )
156
+ conditions_of_generated_imgs = [
157
+ self.preprocessor(
158
+ image=x, low_threshold=low_threshold, high_threshold=high_threshold, detect_resolution=image_resolution
159
+ ) for x in results
160
+ ]
161
+ return [control_image] * num_images + results + conditions_of_generated_imgs
162
+
163
+ @torch.inference_mode()
164
+ def process_mlsd(
165
+ self,
166
+ image: np.ndarray,
167
+ prompt: str,
168
+ additional_prompt: str,
169
+ negative_prompt: str,
170
+ num_images: int,
171
+ image_resolution: int,
172
+ preprocess_resolution: int,
173
+ num_steps: int,
174
+ guidance_scale: float,
175
+ seed: int,
176
+ value_threshold: float,
177
+ distance_threshold: float,
178
+ ) -> list[PIL.Image.Image]:
179
+ if image is None:
180
+ raise ValueError
181
+ if image_resolution > MAX_IMAGE_RESOLUTION:
182
+ raise ValueError
183
+ if num_images > MAX_NUM_IMAGES:
184
+ raise ValueError
185
+
186
+ self.preprocessor.load("MLSD")
187
+ control_image = self.preprocessor(
188
+ image=image,
189
+ image_resolution=image_resolution,
190
+ detect_resolution=preprocess_resolution,
191
+ thr_v=value_threshold,
192
+ thr_d=distance_threshold,
193
+ )
194
+ self.load_controlnet_weight("MLSD")
195
+ results = self.run_pipe(
196
+ prompt=self.get_prompt(prompt, additional_prompt),
197
+ negative_prompt=negative_prompt,
198
+ control_image=control_image,
199
+ num_images=num_images,
200
+ num_steps=num_steps,
201
+ guidance_scale=guidance_scale,
202
+ seed=seed,
203
+ )
204
+ return [control_image] + results
205
+
206
+ @torch.inference_mode()
207
+ def process_scribble(
208
+ self,
209
+ image: np.ndarray,
210
+ prompt: str,
211
+ additional_prompt: str,
212
+ negative_prompt: str,
213
+ num_images: int,
214
+ image_resolution: int,
215
+ preprocess_resolution: int,
216
+ num_steps: int,
217
+ guidance_scale: float,
218
+ seed: int,
219
+ preprocessor_name: str,
220
+ ) -> list[PIL.Image.Image]:
221
+ if image is None:
222
+ raise ValueError
223
+ if image_resolution > MAX_IMAGE_RESOLUTION:
224
+ raise ValueError
225
+ if num_images > MAX_NUM_IMAGES:
226
+ raise ValueError
227
+
228
+ if preprocessor_name == "None":
229
+ image = HWC3(image)
230
+ image = resize_image(image, resolution=image_resolution)
231
+ control_image = PIL.Image.fromarray(image)
232
+ elif preprocessor_name == "HED":
233
+ self.preprocessor.load(preprocessor_name)
234
+ control_image = self.preprocessor(
235
+ image=image,
236
+ image_resolution=image_resolution,
237
+ detect_resolution=preprocess_resolution,
238
+ scribble=False,
239
+ )
240
+ elif preprocessor_name == "PidiNet":
241
+ self.preprocessor.load(preprocessor_name)
242
+ control_image = self.preprocessor(
243
+ image=image,
244
+ image_resolution=image_resolution,
245
+ detect_resolution=preprocess_resolution,
246
+ safe=False,
247
+ )
248
+ self.load_controlnet_weight("scribble")
249
+ results = self.run_pipe(
250
+ prompt=self.get_prompt(prompt, additional_prompt),
251
+ negative_prompt=negative_prompt,
252
+ control_image=control_image,
253
+ num_images=num_images,
254
+ num_steps=num_steps,
255
+ guidance_scale=guidance_scale,
256
+ seed=seed,
257
+ )
258
+ return [control_image] + results
259
+
260
+ @torch.inference_mode()
261
+ def process_scribble_interactive(
262
+ self,
263
+ image_and_mask: dict[str, np.ndarray],
264
+ prompt: str,
265
+ additional_prompt: str,
266
+ negative_prompt: str,
267
+ num_images: int,
268
+ image_resolution: int,
269
+ num_steps: int,
270
+ guidance_scale: float,
271
+ seed: int,
272
+ ) -> list[PIL.Image.Image]:
273
+ if image_and_mask is None:
274
+ raise ValueError
275
+ if image_resolution > MAX_IMAGE_RESOLUTION:
276
+ raise ValueError
277
+ if num_images > MAX_NUM_IMAGES:
278
+ raise ValueError
279
+
280
+ image = image_and_mask["mask"]
281
+ image = HWC3(image)
282
+ image = resize_image(image, resolution=image_resolution)
283
+ control_image = PIL.Image.fromarray(image)
284
+
285
+ self.load_controlnet_weight("scribble")
286
+ results = self.run_pipe(
287
+ prompt=self.get_prompt(prompt, additional_prompt),
288
+ negative_prompt=negative_prompt,
289
+ control_image=control_image,
290
+ num_images=num_images,
291
+ num_steps=num_steps,
292
+ guidance_scale=guidance_scale,
293
+ seed=seed,
294
+ )
295
+ return [control_image] + results
296
+
297
+ @torch.inference_mode()
298
+ def process_softedge(
299
+ self,
300
+ image: np.ndarray,
301
+ prompt: str,
302
+ additional_prompt: str,
303
+ negative_prompt: str,
304
+ num_images: int,
305
+ image_resolution: int,
306
+ preprocess_resolution: int,
307
+ num_steps: int,
308
+ guidance_scale: float,
309
+ seed: int,
310
+ preprocessor_name: str,
311
+ ) -> list[PIL.Image.Image]:
312
+ if image is None:
313
+ raise ValueError
314
+ if image_resolution > MAX_IMAGE_RESOLUTION:
315
+ raise ValueError
316
+ if num_images > MAX_NUM_IMAGES:
317
+ raise ValueError
318
+
319
+ if preprocessor_name == "None":
320
+ image = HWC3(image)
321
+ image = resize_image(image, resolution=image_resolution)
322
+ control_image = PIL.Image.fromarray(image)
323
+ elif preprocessor_name in ["HED", "HED safe"]:
324
+ safe = "safe" in preprocessor_name
325
+ self.preprocessor.load("HED")
326
+ control_image = self.preprocessor(
327
+ image=image,
328
+ image_resolution=image_resolution,
329
+ detect_resolution=preprocess_resolution,
330
+ scribble=safe,
331
+ )
332
+ elif preprocessor_name in ["PidiNet", "PidiNet safe"]:
333
+ safe = "safe" in preprocessor_name
334
+ self.preprocessor.load("PidiNet")
335
+ control_image = self.preprocessor(
336
+ image=image,
337
+ image_resolution=image_resolution,
338
+ detect_resolution=preprocess_resolution,
339
+ safe=safe,
340
+ )
341
+ else:
342
+ raise ValueError
343
+ self.load_controlnet_weight("softedge")
344
+ results = self.run_pipe(
345
+ prompt=self.get_prompt(prompt, additional_prompt),
346
+ negative_prompt=negative_prompt,
347
+ control_image=control_image,
348
+ num_images=num_images,
349
+ num_steps=num_steps,
350
+ guidance_scale=guidance_scale,
351
+ seed=seed,
352
+ )
353
+ conditions_of_generated_imgs = [
354
+ self.preprocessor(
355
+ image=x,
356
+ image_resolution=image_resolution,
357
+ detect_resolution=preprocess_resolution,
358
+ scribble=safe,
359
+ ) for x in results
360
+ ]
361
+ return [control_image] * num_images + results + conditions_of_generated_imgs
362
+
363
+ @torch.inference_mode()
364
+ def process_openpose(
365
+ self,
366
+ image: np.ndarray,
367
+ prompt: str,
368
+ additional_prompt: str,
369
+ negative_prompt: str,
370
+ num_images: int,
371
+ image_resolution: int,
372
+ preprocess_resolution: int,
373
+ num_steps: int,
374
+ guidance_scale: float,
375
+ seed: int,
376
+ preprocessor_name: str,
377
+ ) -> list[PIL.Image.Image]:
378
+ if image is None:
379
+ raise ValueError
380
+ if image_resolution > MAX_IMAGE_RESOLUTION:
381
+ raise ValueError
382
+ if num_images > MAX_NUM_IMAGES:
383
+ raise ValueError
384
+
385
+ if preprocessor_name == "None":
386
+ image = HWC3(image)
387
+ image = resize_image(image, resolution=image_resolution)
388
+ control_image = PIL.Image.fromarray(image)
389
+ else:
390
+ self.preprocessor.load("Openpose")
391
+ control_image = self.preprocessor(
392
+ image=image,
393
+ image_resolution=image_resolution,
394
+ detect_resolution=preprocess_resolution,
395
+ hand_and_face=True,
396
+ )
397
+ self.load_controlnet_weight("Openpose")
398
+ results = self.run_pipe(
399
+ prompt=self.get_prompt(prompt, additional_prompt),
400
+ negative_prompt=negative_prompt,
401
+ control_image=control_image,
402
+ num_images=num_images,
403
+ num_steps=num_steps,
404
+ guidance_scale=guidance_scale,
405
+ seed=seed,
406
+ )
407
+ return [control_image] + results
408
+
409
+ @torch.inference_mode()
410
+ def process_segmentation(
411
+ self,
412
+ image: np.ndarray,
413
+ prompt: str,
414
+ additional_prompt: str,
415
+ negative_prompt: str,
416
+ num_images: int,
417
+ image_resolution: int,
418
+ preprocess_resolution: int,
419
+ num_steps: int,
420
+ guidance_scale: float,
421
+ seed: int,
422
+ preprocessor_name: str,
423
+ ) -> list[PIL.Image.Image]:
424
+ if image is None:
425
+ raise ValueError
426
+ if image_resolution > MAX_IMAGE_RESOLUTION:
427
+ raise ValueError
428
+ if num_images > MAX_NUM_IMAGES:
429
+ raise ValueError
430
+
431
+ if preprocessor_name == "None":
432
+ image = HWC3(image)
433
+ image = resize_image(image, resolution=image_resolution)
434
+ control_image = PIL.Image.fromarray(image)
435
+ else:
436
+ self.preprocessor.load(preprocessor_name)
437
+ control_image = self.preprocessor(
438
+ image=image,
439
+ image_resolution=image_resolution,
440
+ detect_resolution=preprocess_resolution,
441
+ )
442
+ self.load_controlnet_weight("segmentation")
443
+ results = self.run_pipe(
444
+ prompt=self.get_prompt(prompt, additional_prompt),
445
+ negative_prompt=negative_prompt,
446
+ control_image=control_image,
447
+ num_images=num_images,
448
+ num_steps=num_steps,
449
+ guidance_scale=guidance_scale,
450
+ seed=seed,
451
+ )
452
+ self.preprocessor.load('UPerNet')
453
+ conditions_of_generated_imgs = [
454
+ self.preprocessor(
455
+ image=np.array(x),
456
+ image_resolution=image_resolution,
457
+ detect_resolution=preprocess_resolution,
458
+ ) for x in results
459
+ ]
460
+ return [control_image] * num_images + results + conditions_of_generated_imgs
461
+
462
+ @torch.inference_mode()
463
+ def process_depth(
464
+ self,
465
+ image: np.ndarray,
466
+ prompt: str,
467
+ additional_prompt: str,
468
+ negative_prompt: str,
469
+ num_images: int,
470
+ image_resolution: int,
471
+ preprocess_resolution: int,
472
+ num_steps: int,
473
+ guidance_scale: float,
474
+ seed: int,
475
+ preprocessor_name: str,
476
+ ) -> list[PIL.Image.Image]:
477
+ if image is None:
478
+ raise ValueError
479
+ if image_resolution > MAX_IMAGE_RESOLUTION:
480
+ raise ValueError
481
+ if num_images > MAX_NUM_IMAGES:
482
+ raise ValueError
483
+
484
+ if preprocessor_name == "None":
485
+ image = HWC3(image)
486
+ image = resize_image(image, resolution=image_resolution)
487
+ control_image = PIL.Image.fromarray(image)
488
+ else:
489
+ self.preprocessor.load(preprocessor_name)
490
+ control_image = self.preprocessor(
491
+ image=image,
492
+ image_resolution=image_resolution,
493
+ detect_resolution=preprocess_resolution,
494
+ )
495
+ self.load_controlnet_weight("depth")
496
+ results = self.run_pipe(
497
+ prompt=self.get_prompt(prompt, additional_prompt),
498
+ negative_prompt=negative_prompt,
499
+ control_image=control_image,
500
+ num_images=num_images,
501
+ num_steps=num_steps,
502
+ guidance_scale=guidance_scale,
503
+ seed=seed,
504
+ )
505
+ conditions_of_generated_imgs = [
506
+ self.preprocessor(
507
+ image=x,
508
+ image_resolution=image_resolution,
509
+ detect_resolution=preprocess_resolution,
510
+ ) for x in results
511
+ ]
512
+ return [control_image] * num_images + results + conditions_of_generated_imgs
513
+
514
+ @torch.inference_mode()
515
+ def process_normal(
516
+ self,
517
+ image: np.ndarray,
518
+ prompt: str,
519
+ additional_prompt: str,
520
+ negative_prompt: str,
521
+ num_images: int,
522
+ image_resolution: int,
523
+ preprocess_resolution: int,
524
+ num_steps: int,
525
+ guidance_scale: float,
526
+ seed: int,
527
+ preprocessor_name: str,
528
+ ) -> list[PIL.Image.Image]:
529
+ if image is None:
530
+ raise ValueError
531
+ if image_resolution > MAX_IMAGE_RESOLUTION:
532
+ raise ValueError
533
+ if num_images > MAX_NUM_IMAGES:
534
+ raise ValueError
535
+
536
+ if preprocessor_name == "None":
537
+ image = HWC3(image)
538
+ image = resize_image(image, resolution=image_resolution)
539
+ control_image = PIL.Image.fromarray(image)
540
+ else:
541
+ self.preprocessor.load("NormalBae")
542
+ control_image = self.preprocessor(
543
+ image=image,
544
+ image_resolution=image_resolution,
545
+ detect_resolution=preprocess_resolution,
546
+ )
547
+ self.load_controlnet_weight("NormalBae")
548
+ results = self.run_pipe(
549
+ prompt=self.get_prompt(prompt, additional_prompt),
550
+ negative_prompt=negative_prompt,
551
+ control_image=control_image,
552
+ num_images=num_images,
553
+ num_steps=num_steps,
554
+ guidance_scale=guidance_scale,
555
+ seed=seed,
556
+ )
557
+ return [control_image] + results
558
+
559
+ @torch.inference_mode()
560
+ def process_lineart(
561
+ self,
562
+ image: np.ndarray,
563
+ prompt: str,
564
+ additional_prompt: str,
565
+ negative_prompt: str,
566
+ num_images: int,
567
+ image_resolution: int,
568
+ preprocess_resolution: int,
569
+ num_steps: int,
570
+ guidance_scale: float,
571
+ seed: int,
572
+ preprocessor_name: str,
573
+ ) -> list[PIL.Image.Image]:
574
+ if image is None:
575
+ raise ValueError
576
+ if image_resolution > MAX_IMAGE_RESOLUTION:
577
+ raise ValueError
578
+ if num_images > MAX_NUM_IMAGES:
579
+ raise ValueError
580
+
581
+ if preprocessor_name in ["None", "None (anime)"]:
582
+ image = 255 - HWC3(image)
583
+ image = resize_image(image, resolution=image_resolution)
584
+ control_image = PIL.Image.fromarray(image)
585
+ elif preprocessor_name in ["Lineart", "Lineart coarse"]:
586
+ coarse = "coarse" in preprocessor_name
587
+ self.preprocessor.load("Lineart")
588
+ control_image = self.preprocessor(
589
+ image=image,
590
+ image_resolution=image_resolution,
591
+ detect_resolution=preprocess_resolution,
592
+ coarse=coarse,
593
+ )
594
+ elif preprocessor_name == "Lineart (anime)":
595
+ self.preprocessor.load("LineartAnime")
596
+ control_image = self.preprocessor(
597
+ image=image,
598
+ image_resolution=image_resolution,
599
+ detect_resolution=preprocess_resolution,
600
+ )
601
+ # NOTE: We still use the general lineart model
602
+ if "anime" in preprocessor_name:
603
+ self.load_controlnet_weight("lineart_anime")
604
+ else:
605
+ self.load_controlnet_weight("lineart")
606
+ results = self.run_pipe(
607
+ prompt=self.get_prompt(prompt, additional_prompt),
608
+ negative_prompt=negative_prompt,
609
+ control_image=control_image,
610
+ num_images=num_images,
611
+ num_steps=num_steps,
612
+ guidance_scale=guidance_scale,
613
+ seed=seed,
614
+ )
615
+ self.preprocessor.load("Lineart")
616
+ conditions_of_generated_imgs = [
617
+ self.preprocessor(
618
+ image=x,
619
+ image_resolution=image_resolution,
620
+ detect_resolution=preprocess_resolution,
621
+ ) for x in results
622
+ ]
623
+
624
+ control_image = PIL.Image.fromarray((255 - np.array(control_image)).astype(np.uint8))
625
+ conditions_of_generated_imgs = [PIL.Image.fromarray((255 - np.array(x)).astype(np.uint8)) for x in conditions_of_generated_imgs]
626
+
627
+ return [control_image] * num_images + results + conditions_of_generated_imgs
628
+
629
+ @torch.inference_mode()
630
+ def process_shuffle(
631
+ self,
632
+ image: np.ndarray,
633
+ prompt: str,
634
+ additional_prompt: str,
635
+ negative_prompt: str,
636
+ num_images: int,
637
+ image_resolution: int,
638
+ num_steps: int,
639
+ guidance_scale: float,
640
+ seed: int,
641
+ preprocessor_name: str,
642
+ ) -> list[PIL.Image.Image]:
643
+ if image is None:
644
+ raise ValueError
645
+ if image_resolution > MAX_IMAGE_RESOLUTION:
646
+ raise ValueError
647
+ if num_images > MAX_NUM_IMAGES:
648
+ raise ValueError
649
+
650
+ if preprocessor_name == "None":
651
+ image = HWC3(image)
652
+ image = resize_image(image, resolution=image_resolution)
653
+ control_image = PIL.Image.fromarray(image)
654
+ else:
655
+ self.preprocessor.load(preprocessor_name)
656
+ control_image = self.preprocessor(
657
+ image=image,
658
+ image_resolution=image_resolution,
659
+ )
660
+ self.load_controlnet_weight("shuffle")
661
+ results = self.run_pipe(
662
+ prompt=self.get_prompt(prompt, additional_prompt),
663
+ negative_prompt=negative_prompt,
664
+ control_image=control_image,
665
+ num_images=num_images,
666
+ num_steps=num_steps,
667
+ guidance_scale=guidance_scale,
668
+ seed=seed,
669
+ )
670
+ return [control_image] + results
671
+
672
+ @torch.inference_mode()
673
+ def process_ip2p(
674
+ self,
675
+ image: np.ndarray,
676
+ prompt: str,
677
+ additional_prompt: str,
678
+ negative_prompt: str,
679
+ num_images: int,
680
+ image_resolution: int,
681
+ num_steps: int,
682
+ guidance_scale: float,
683
+ seed: int,
684
+ ) -> list[PIL.Image.Image]:
685
+ if image is None:
686
+ raise ValueError
687
+ if image_resolution > MAX_IMAGE_RESOLUTION:
688
+ raise ValueError
689
+ if num_images > MAX_NUM_IMAGES:
690
+ raise ValueError
691
+
692
+ image = HWC3(image)
693
+ image = resize_image(image, resolution=image_resolution)
694
+ control_image = PIL.Image.fromarray(image)
695
+ self.load_controlnet_weight("ip2p")
696
+ results = self.run_pipe(
697
+ prompt=self.get_prompt(prompt, additional_prompt),
698
+ negative_prompt=negative_prompt,
699
+ control_image=control_image,
700
+ num_images=num_images,
701
+ num_steps=num_steps,
702
+ guidance_scale=guidance_scale,
703
+ seed=seed,
704
+ )
705
+ return [control_image] + results
preprocessor.py ADDED
@@ -0,0 +1,92 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gc
2
+
3
+ import numpy as np
4
+ import PIL.Image
5
+ import torch
6
+ import torchvision
7
+ from controlnet_aux import (
8
+ CannyDetector,
9
+ ContentShuffleDetector,
10
+ HEDdetector,
11
+ LineartAnimeDetector,
12
+ LineartDetector,
13
+ MidasDetector,
14
+ MLSDdetector,
15
+ NormalBaeDetector,
16
+ OpenposeDetector,
17
+ PidiNetDetector,
18
+ )
19
+ from controlnet_aux.util import HWC3
20
+
21
+ from cv_utils import resize_image
22
+ from depth_estimator import DepthEstimator
23
+ from image_segmentor import ImageSegmentor
24
+
25
+ from kornia.core import Tensor
26
+ from kornia.filters import canny
27
+
28
+
29
+ class Canny:
30
+
31
+ def __call__(
32
+ self,
33
+ images: np.array,
34
+ low_threshold: float = 0.1,
35
+ high_threshold: float = 0.2,
36
+ kernel_size: tuple[int, int] | int = (5, 5),
37
+ sigma: tuple[float, float] | Tensor = (1, 1),
38
+ hysteresis: bool = True,
39
+ eps: float = 1e-6
40
+ ) -> torch.Tensor:
41
+
42
+ assert low_threshold is not None, "low_threshold must be provided"
43
+ assert high_threshold is not None, "high_threshold must be provided"
44
+
45
+ images = torch.from_numpy(images).permute(2, 0, 1).unsqueeze(0) / 255.0
46
+
47
+ images_tensor = canny(images, low_threshold, high_threshold, kernel_size, sigma, hysteresis, eps)[1]
48
+ images_tensor = (images_tensor[0][0].numpy() * 255).astype(np.uint8)
49
+ return images_tensor
50
+
51
+
52
+ class Preprocessor:
53
+ MODEL_ID = "lllyasviel/Annotators"
54
+
55
+ def __init__(self):
56
+ self.model = None
57
+ self.name = ""
58
+
59
+ def load(self, name: str) -> None:
60
+ if name == self.name:
61
+ return
62
+ if name == "Canny":
63
+ self.model = Canny()
64
+ elif name == "DPT":
65
+ self.model = DepthEstimator()
66
+ else:
67
+ raise ValueError
68
+ torch.cuda.empty_cache()
69
+ gc.collect()
70
+ self.name = name
71
+
72
+ def __call__(self, image: PIL.Image.Image, **kwargs) -> PIL.Image.Image:
73
+ if self.name == "Canny":
74
+ if "detect_resolution" in kwargs:
75
+ detect_resolution = kwargs.pop("detect_resolution")
76
+ image = np.array(image)
77
+ image = HWC3(image)
78
+ image = resize_image(image, resolution=detect_resolution)
79
+ image = self.model(image, **kwargs)
80
+ return PIL.Image.fromarray(image).convert('RGB')
81
+ elif self.name == "Midas":
82
+ detect_resolution = kwargs.pop("detect_resolution", 512)
83
+ image_resolution = kwargs.pop("image_resolution", 512)
84
+ image = np.array(image)
85
+ image = HWC3(image)
86
+ image = resize_image(image, resolution=detect_resolution)
87
+ image = self.model(image, **kwargs)
88
+ image = HWC3(image)
89
+ image = resize_image(image, resolution=image_resolution)
90
+ return PIL.Image.fromarray(image)
91
+ else:
92
+ return self.model(image, **kwargs)
requirements.txt ADDED
@@ -0,0 +1,117 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ absl-py==2.1.0
2
+ accelerate==0.21.0
3
+ aiofiles==23.2.1
4
+ altair==5.3.0
5
+ annotated-types==0.6.0
6
+ anyio==4.3.0
7
+ attrs==23.2.0
8
+ certifi==2024.2.2
9
+ cffi==1.16.0
10
+ charset-normalizer==3.3.2
11
+ click==8.1.7
12
+ cmake==3.29.2
13
+ contourpy==1.2.1
14
+ controlnet-aux==0.0.6
15
+ cycler==0.12.1
16
+ diffusers==0.26.3
17
+ einops==0.6.1
18
+ exceptiongroup==1.2.0
19
+ fastapi==0.110.1
20
+ ffmpy==0.3.2
21
+ filelock==3.13.4
22
+ flatbuffers==24.3.25
23
+ fonttools==4.51.0
24
+ fsspec==2024.3.1
25
+ gradio==4.26.0
26
+ gradio_client==0.15.1
27
+ h11==0.14.0
28
+ httpcore==1.0.5
29
+ httpx==0.27.0
30
+ huggingface-hub==0.21.2
31
+ idna==3.7
32
+ imageio==2.34.0
33
+ importlib_metadata==7.1.0
34
+ importlib_resources==6.4.0
35
+ Jinja2==3.1.3
36
+ jsonschema==4.21.1
37
+ jsonschema-specifications==2023.12.1
38
+ kiwisolver==1.4.5
39
+ kornia==0.7.0
40
+ lazy_loader==0.4
41
+ lit==18.1.3
42
+ markdown-it-py==3.0.0
43
+ MarkupSafe==2.1.5
44
+ matplotlib==3.8.4
45
+ mdurl==0.1.2
46
+ mediapipe==0.10.1
47
+ mpmath==1.3.0
48
+ mypy-extensions==1.0.0
49
+ networkx==3.3
50
+ numpy==1.26.4
51
+ nvidia-cublas-cu11==11.10.3.66
52
+ nvidia-cuda-cupti-cu11==11.7.101
53
+ nvidia-cuda-nvrtc-cu11==11.7.99
54
+ nvidia-cuda-runtime-cu11==11.7.99
55
+ nvidia-cudnn-cu11==8.5.0.96
56
+ nvidia-cufft-cu11==10.9.0.58
57
+ nvidia-curand-cu11==10.2.10.91
58
+ nvidia-cusolver-cu11==11.4.0.1
59
+ nvidia-cusparse-cu11==11.7.4.91
60
+ nvidia-nccl-cu11==2.14.3
61
+ nvidia-nvtx-cu11==11.7.91
62
+ opencv-contrib-python==4.9.0.80
63
+ opencv-python==4.9.0.80
64
+ opencv-python-headless==4.8.0.74
65
+ orjson==3.10.0
66
+ packaging==24.0
67
+ pandas==2.2.2
68
+ pillow==10.3.0
69
+ protobuf==3.20.3
70
+ psutil==5.9.8
71
+ pycparser==2.22
72
+ pydantic==2.7.0
73
+ pydantic_core==2.18.1
74
+ pydub==0.25.1
75
+ Pygments==2.17.2
76
+ pyparsing==3.1.2
77
+ pyre-extensions==0.0.29
78
+ python-dateutil==2.9.0.post0
79
+ python-multipart==0.0.9
80
+ pytz==2024.1
81
+ PyYAML==6.0.1
82
+ referencing==0.34.0
83
+ regex==2023.12.25
84
+ requests==2.31.0
85
+ rich==13.7.1
86
+ rpds-py==0.18.0
87
+ ruff==0.3.7
88
+ safetensors==0.4.1
89
+ scikit-image==0.23.1
90
+ scipy==1.13.0
91
+ semantic-version==2.10.0
92
+ shellingham==1.5.4
93
+ six==1.16.0
94
+ sniffio==1.3.1
95
+ sounddevice==0.4.6
96
+ spaces==0.26.0
97
+ starlette==0.37.2
98
+ sympy==1.12
99
+ tifffile==2024.2.12
100
+ timm==0.9.16
101
+ tokenizers==0.15.2
102
+ tomlkit==0.12.0
103
+ toolz==0.12.1
104
+ torch==2.0.1
105
+ torchvision==0.15.2
106
+ tqdm==4.66.2
107
+ transformers==4.38.1
108
+ triton==2.0.0
109
+ typer==0.12.3
110
+ typing-inspect==0.9.0
111
+ typing_extensions==4.11.0
112
+ tzdata==2024.1
113
+ urllib3==2.2.1
114
+ uvicorn==0.29.0
115
+ websockets==11.0.3
116
+ xformers==0.0.20
117
+ zipp==3.18.1
settings.py ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ import numpy as np
4
+
5
+ DEFAULT_MODEL_ID = os.getenv("DEFAULT_MODEL_ID", "runwayml/stable-diffusion-v1-5")
6
+
7
+ MAX_NUM_IMAGES = int(os.getenv("MAX_NUM_IMAGES", "4"))
8
+ DEFAULT_NUM_IMAGES = min(MAX_NUM_IMAGES, int(os.getenv("DEFAULT_NUM_IMAGES", "2")))
9
+ MAX_IMAGE_RESOLUTION = int(os.getenv("MAX_IMAGE_RESOLUTION", "768"))
10
+ DEFAULT_IMAGE_RESOLUTION = min(MAX_IMAGE_RESOLUTION, int(os.getenv("DEFAULT_IMAGE_RESOLUTION", "512")))
11
+
12
+ ALLOW_CHANGING_BASE_MODEL = os.getenv("SPACE_ID") != "hysts/ControlNet-v1-1"
13
+ SHOW_DUPLICATE_BUTTON = os.getenv("SHOW_DUPLICATE_BUTTON") == "1"
14
+
15
+ MAX_SEED = np.iinfo(np.int32).max
style.css ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ h1 {
2
+ text-align: center;
3
+ }
4
+
5
+ #duplicate-button {
6
+ margin: auto;
7
+ color: #fff;
8
+ background: #1565c0;
9
+ border-radius: 100vh;
10
+ }
utils.py ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ import random
2
+
3
+ from settings import MAX_SEED
4
+
5
+
6
+ def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
7
+ if randomize_seed:
8
+ seed = random.randint(0, MAX_SEED)
9
+ return seed