dvir-bria commited on
Commit
1babe47
·
verified ·
1 Parent(s): 4a09384

Create model.py

Browse files
Files changed (1) hide show
  1. model.py +286 -0
model.py ADDED
@@ -0,0 +1,286 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import gc
4
+
5
+ import numpy as np
6
+ import PIL.Image
7
+ import torch
8
+
9
+ from diffusers import (
10
+ ControlNetModel, StableDiffusionXLControlNetPipeline, AutoencoderKL, EulerAncestralDiscreteScheduler
11
+ )
12
+
13
+ from cv_utils import resize_image
14
+ from preprocessor import Preprocessor
15
+ from settings import MAX_IMAGE_RESOLUTION, MAX_NUM_IMAGES
16
+
17
+ CONTROLNET_MODEL_IDS = {
18
+ "Canny": "briaai/BRIA-2.2-ControlNet-Canny",
19
+ "Depth": "briaai/BRIA-2.2-ControlNet-Depth",
20
+ "Recoloring": "briaai/BRIA-2.2-ControlNet-Recoloring",
21
+ }
22
+
23
+
24
+ def download_all_controlnet_weights() -> None:
25
+ for model_id in CONTROLNET_MODEL_IDS.values():
26
+ ControlNetModel.from_pretrained(model_id)
27
+
28
+
29
+ class Model:
30
+ def __init__(self, base_model_id: str = "briaai/BRIA-2.2", task_name: str = "Canny"):
31
+ self.device = torch.device("cuda:0")
32
+ self.base_model_id = ""
33
+ self.task_name = ""
34
+ self.pipe = self.load_pipe(base_model_id, task_name)
35
+ self.preprocessor = Preprocessor()
36
+
37
+ def load_pipe(self, base_model_id: str, task_name) -> DiffusionPipeline:
38
+ if (
39
+ base_model_id == self.base_model_id
40
+ and task_name == self.task_name
41
+ and hasattr(self, "pipe")
42
+ and self.pipe is not None
43
+ ):
44
+ return self.pipe
45
+ model_id = CONTROLNET_MODEL_IDS[task_name]
46
+ controlnet = ControlNetModel.from_pretrained(model_id, torch_dtype=torch.float16).to('cuda')
47
+ pipe = StableDiffusionXLControlNetPipeline.from_pretrained(
48
+ base_model_id,
49
+ controlnet=controlnet,
50
+ torch_dtype=torch.float16,
51
+ device_map='auto',
52
+ low_cpu_mem_usage=True,
53
+ offload_state_dict=True,
54
+ ).to('cuda')
55
+ pipe.scheduler = EulerAncestralDiscreteScheduler(
56
+ beta_start=0.00085,
57
+ beta_end=0.012,
58
+ beta_schedule="scaled_linear",
59
+ num_train_timesteps=1000,
60
+ steps_offset=1
61
+ )
62
+ # pipe.enable_freeu(b1=1.1, b2=1.1, s1=0.5, s2=0.7)
63
+ pipe.enable_xformers_memory_efficient_attention()
64
+ pipe.force_zeros_for_empty_prompt = False
65
+
66
+ torch.cuda.empty_cache()
67
+ gc.collect()
68
+ self.base_model_id = base_model_id
69
+ self.task_name = task_name
70
+ return pipe
71
+
72
+ def set_base_model(self, base_model_id: str) -> str:
73
+ if not base_model_id or base_model_id == self.base_model_id:
74
+ return self.base_model_id
75
+ del self.pipe
76
+ torch.cuda.empty_cache()
77
+ gc.collect()
78
+ try:
79
+ self.pipe = self.load_pipe(base_model_id, self.task_name)
80
+ except Exception:
81
+ self.pipe = self.load_pipe(self.base_model_id, self.task_name)
82
+ return self.base_model_id
83
+
84
+ def load_controlnet_weight(self, task_name: str) -> None:
85
+ if task_name == self.task_name:
86
+ return
87
+ if self.pipe is not None and hasattr(self.pipe, "controlnet"):
88
+ del self.pipe.controlnet
89
+ torch.cuda.empty_cache()
90
+ gc.collect()
91
+ model_id = CONTROLNET_MODEL_IDS[task_name]
92
+ controlnet = ControlNetModel.from_pretrained(model_id, torch_dtype=torch.float16)
93
+ controlnet.to(self.device)
94
+ torch.cuda.empty_cache()
95
+ gc.collect()
96
+ self.pipe.controlnet = controlnet
97
+ self.task_name = task_name
98
+
99
+ def get_prompt(self, prompt: str, additional_prompt: str) -> str:
100
+ if not prompt:
101
+ prompt = additional_prompt
102
+ else:
103
+ prompt = f"{prompt}, {additional_prompt}"
104
+ return prompt
105
+
106
+ @torch.autocast("cuda")
107
+ def run_pipe(
108
+ self,
109
+ prompt: str,
110
+ negative_prompt: str,
111
+ control_image: PIL.Image.Image,
112
+ num_images: int,
113
+ num_steps: int,
114
+ controlnet_conditioning_scale: float,
115
+ seed: int,
116
+ ) -> list[PIL.Image.Image]:
117
+ generator = torch.Generator().manual_seed(seed)
118
+ return self.pipe(
119
+ prompt=prompt,
120
+ negative_prompt=negative_prompt,
121
+ controlnet_conditioning_scale=controlnet_conditioning_scale,
122
+ num_images_per_prompt=num_images,
123
+ num_inference_steps=num_steps,
124
+ generator=generator,
125
+ image=control_image,
126
+ ).images
127
+
128
+
129
+ def resize_image(image):
130
+ image = image.convert('RGB')
131
+ current_size = image.size
132
+ if current_size[0] > current_size[1]:
133
+ center_cropped_image = transforms.functional.center_crop(image, (current_size[1], current_size[1]))
134
+ else:
135
+ center_cropped_image = transforms.functional.center_crop(image, (current_size[0], current_size[0]))
136
+ resized_image = transforms.functional.resize(center_cropped_image, (1024, 1024))
137
+ return resized_image
138
+
139
+ def get_canny_filter(image):
140
+ low_threshold = 100
141
+ high_threshold = 200
142
+
143
+ if not isinstance(image, np.ndarray):
144
+ image = np.array(image)
145
+
146
+ image = cv2.Canny(image, low_threshold, high_threshold)
147
+ image = image[:, :, None]
148
+ image = np.concatenate([image, image, image], axis=2)
149
+ canny_image = Image.fromarray(image)
150
+ return canny_image
151
+
152
+
153
+
154
+ @torch.inference_mode()
155
+ def process_canny(
156
+ self,
157
+ image: np.ndarray,
158
+ prompt: str,
159
+ negative_prompt: str,
160
+ image_resolution: int,
161
+ num_steps: int,
162
+ controlnet_conditioning_scale: float,
163
+ seed: int,
164
+ ) -> list[PIL.Image.Image]:
165
+
166
+ # resize input_image to 1024x1024
167
+ input_image = resize_image(image)
168
+
169
+ canny_image = get_canny_filter(input_image)
170
+
171
+ self.load_controlnet_weight("Canny")
172
+ results = self.run_pipe(
173
+ prompt, negative_prompt=negative_prompt, image=canny_image, num_inference_steps=num_steps, controlnet_conditioning_scale=float(controlnet_conditioning_scale)
174
+ )
175
+ return [control_image] + results
176
+
177
+
178
+
179
+
180
+
181
+
182
+
183
+ ----------------------------------------------------------------------------
184
+
185
+
186
+
187
+ # from diffusers import ControlNetModel, StableDiffusionXLControlNetPipeline, AutoencoderKL
188
+ # from diffusers.utils import load_image
189
+ # from PIL import Image
190
+ # import torch
191
+ # import numpy as np
192
+ # import cv2
193
+ # import gradio as gr
194
+ # from torchvision import transforms
195
+
196
+ # controlnet = ControlNetModel.from_pretrained(
197
+ # "briaai/BRIA-2.2-ControlNet-Canny",
198
+ # torch_dtype=torch.float16
199
+ # ).to('cuda')
200
+
201
+ # pipe = StableDiffusionXLControlNetPipeline.from_pretrained(
202
+ # "briaai/BRIA-2.2",
203
+ # controlnet=controlnet,
204
+ # torch_dtype=torch.float16,
205
+ # device_map='auto',
206
+ # low_cpu_mem_usage=True,
207
+ # offload_state_dict=True,
208
+ # ).to('cuda')
209
+ # pipe.scheduler = EulerAncestralDiscreteScheduler(
210
+ # beta_start=0.00085,
211
+ # beta_end=0.012,
212
+ # beta_schedule="scaled_linear",
213
+ # num_train_timesteps=1000,
214
+ # steps_offset=1
215
+ # )
216
+ # # pipe.enable_freeu(b1=1.1, b2=1.1, s1=0.5, s2=0.7)
217
+ # pipe.enable_xformers_memory_efficient_attention()
218
+ # pipe.force_zeros_for_empty_prompt = False
219
+
220
+ # low_threshold = 100
221
+ # high_threshold = 200
222
+
223
+ # def resize_image(image):
224
+ # image = image.convert('RGB')
225
+ # current_size = image.size
226
+ # if current_size[0] > current_size[1]:
227
+ # center_cropped_image = transforms.functional.center_crop(image, (current_size[1], current_size[1]))
228
+ # else:
229
+ # center_cropped_image = transforms.functional.center_crop(image, (current_size[0], current_size[0]))
230
+ # resized_image = transforms.functional.resize(center_cropped_image, (1024, 1024))
231
+ # return resized_image
232
+
233
+ # def get_canny_filter(image):
234
+
235
+ # if not isinstance(image, np.ndarray):
236
+ # image = np.array(image)
237
+
238
+ # image = cv2.Canny(image, low_threshold, high_threshold)
239
+ # image = image[:, :, None]
240
+ # image = np.concatenate([image, image, image], axis=2)
241
+ # canny_image = Image.fromarray(image)
242
+ # return canny_image
243
+
244
+ # def process(input_image, prompt, negative_prompt, num_steps, controlnet_conditioning_scale, seed):
245
+ # generator = torch.manual_seed(seed)
246
+
247
+ # # resize input_image to 1024x1024
248
+ # input_image = resize_image(input_image)
249
+
250
+ # canny_image = get_canny_filter(input_image)
251
+
252
+ # images = pipe(
253
+ # prompt, negative_prompt=negative_prompt, image=canny_image, num_inference_steps=num_steps, controlnet_conditioning_scale=float(controlnet_conditioning_scale),
254
+ # generator=generator,
255
+ # ).images
256
+
257
+ # return [canny_image,images[0]]
258
+
259
+ # block = gr.Blocks().queue()
260
+
261
+ # with block:
262
+ # gr.Markdown("## BRIA 2.2 ControlNet Canny")
263
+ # gr.HTML('''
264
+ # <p style="margin-bottom: 10px; font-size: 94%">
265
+ # This is a demo for ControlNet Canny that using
266
+ # <a href="https://huggingface.co/briaai/BRIA-2.2" target="_blank">BRIA 2.2 text-to-image model</a> as backbone.
267
+ # Trained on licensed data, BRIA 2.2 provide full legal liability coverage for copyright and privacy infringement.
268
+ # </p>
269
+ # ''')
270
+ # with gr.Row():
271
+ # with gr.Column():
272
+ # input_image = gr.Image(sources=None, type="pil") # None for upload, ctrl+v and webcam
273
+ # prompt = gr.Textbox(label="Prompt")
274
+ # negative_prompt = gr.Textbox(label="Negative prompt", value="Logo,Watermark,Text,Ugly,Morbid,Extra fingers,Poorly drawn hands,Mutation,Blurry,Extra limbs,Gross proportions,Missing arms,Mutated hands,Long neck,Duplicate,Mutilated,Mutilated hands,Poorly drawn face,Deformed,Bad anatomy,Cloned face,Malformed limbs,Missing legs,Too many fingers")
275
+ # num_steps = gr.Slider(label="Number of steps", minimum=25, maximum=100, value=50, step=1)
276
+ # controlnet_conditioning_scale = gr.Slider(label="ControlNet conditioning scale", minimum=0.1, maximum=2.0, value=1.0, step=0.05)
277
+ # seed = gr.Slider(label="Seed", minimum=0, maximum=2147483647, step=1, randomize=True,)
278
+ # run_button = gr.Button(value="Run")
279
+
280
+
281
+ # with gr.Column():
282
+ # result_gallery = gr.Gallery(label='Output', show_label=False, elem_id="gallery", columns=[2], height='auto')
283
+ # ips = [input_image, prompt, negative_prompt, num_steps, controlnet_conditioning_scale, seed]
284
+ # run_button.click(fn=process, inputs=ips, outputs=[result_gallery])
285
+
286
+ # block.launch(debug = True)