TalHach61 committed on
Commit 93e5c88 · verified · 1 Parent(s): 5001402

Create app.py

Files changed (1)
  1. app.py +265 -0
app.py ADDED
@@ -0,0 +1,265 @@
import sys
sys.path.append('./')

import gradio as gr
import spaces
import os
import subprocess
import numpy as np
from PIL import Image
import cv2
import torch
import random

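# Install the bundled controlnet_aux package (editable) so OpenposeDetector can be imported below.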
os.system("pip install -e ./controlnet_aux")

from controlnet_aux import OpenposeDetector  # , CannyDetector
from depth_anything_v2.dpt import DepthAnythingV2

from huggingface_hub import hf_hub_download, login

hf_token = os.environ.get("HF_TOKEN")
login(token=hf_token)

MAX_SEED = np.iinfo(np.int32).max

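# Resolve the script directory, then fetch the Bria pipeline/ControlNet implementation files there
# so they can be imported as local modules.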
try:
    local_dir = os.path.dirname(__file__)
except NameError:
    local_dir = '.'

hf_hub_download(repo_id="briaai/BRIA-3.1", filename='pipeline_bria.py', local_dir=local_dir)
hf_hub_download(repo_id="briaai/BRIA-3.1", filename='transformer_bria.py', local_dir=local_dir)
hf_hub_download(repo_id="briaai/BRIA-3.1", filename='bria_utils.py', local_dir=local_dir)
hf_hub_download(repo_id="briaai/BRIA-3.1-ControlNet-Union", filename='pipeline_bria_controlnet.py', local_dir=local_dir)
hf_hub_download(repo_id="briaai/BRIA-3.1-ControlNet-Union", filename='controlnet_bria.py', local_dir=local_dir)

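# Draw a fresh random seed when the "Randomize seed" checkbox is enabled; otherwise keep the user-provided one.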
def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    return seed

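# Depth-Anything-V2 encoder configurations and the resolution buckets (aspect ratio -> width/height)
# used by resize_img() below.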
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
model_configs = {
    'vits': {'encoder': 'vits', 'features': 64, 'out_channels': [48, 96, 192, 384]},
    'vitb': {'encoder': 'vitb', 'features': 128, 'out_channels': [96, 192, 384, 768]},
    'vitl': {'encoder': 'vitl', 'features': 256, 'out_channels': [256, 512, 1024, 1024]},
    'vitg': {'encoder': 'vitg', 'features': 384, 'out_channels': [1536, 1536, 1536, 1536]},
}

RATIO_CONFIGS_1024 = {
    0.6666666666666666: {"width": 832, "height": 1248},
    0.7432432432432432: {"width": 880, "height": 1184},
    0.8028169014084507: {"width": 912, "height": 1136},
    1.0: {"width": 1024, "height": 1024},
    1.2456140350877194: {"width": 1136, "height": 912},
    1.3454545454545455: {"width": 1184, "height": 880},
    1.4339622641509433: {"width": 1216, "height": 848},
    1.5: {"width": 1248, "height": 832},
    1.5490196078431373: {"width": 1264, "height": 816},
    1.62: {"width": 1296, "height": 800},
    1.7708333333333333: {"width": 1360, "height": 768},
}

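# Load the ViT-L Depth-Anything-V2 checkpoint used for depth-map extraction.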
encoder = 'vitl'
model = DepthAnythingV2(**model_configs[encoder])
filepath = hf_hub_download(repo_id="depth-anything/Depth-Anything-V2-Large", filename="depth_anything_v2_vitl.pth", repo_type="model")
state_dict = torch.load(filepath, map_location="cpu")
model.load_state_dict(state_dict)
model = model.to(DEVICE).eval()

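# Build the BRIA-3.1 ControlNet-Union pipeline in bfloat16 from the locally downloaded modules.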
from diffusers.utils import load_image
from controlnet_bria import BriaControlNetModel, BriaMultiControlNetModel
from pipeline_bria_controlnet import BriaControlNetPipeline

base_model = 'briaai/BRIA-3.1'
controlnet_model = 'briaai/BRIA-3.1-ControlNet-Union'
controlnet = BriaControlNetModel.from_pretrained(controlnet_model, torch_dtype=torch.bfloat16)
pipe = BriaControlNetPipeline.from_pretrained(base_model, controlnet=controlnet, torch_dtype=torch.bfloat16, trust_remote_code=True)
pipe = pipe.to(device="cuda", dtype=torch.bfloat16)

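# Control-mode indices expected by ControlNet-Union, per-mode default strengths, and the OpenPose annotator.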
mode_mapping = {
    "depth": 0,
    "canny": 1,
    "colorgrid": 2,
    "recolor": 3,
    "tile": 4,
    "pose": 5,
}
strength_mapping = {
    "depth": 1.0,
    "canny": 1.0,
    "colorgrid": 1.0,
    "recolor": 1.0,
    "tile": 1.0,
    "pose": 1.0,
}
open_pose = OpenposeDetector.from_pretrained("lllyasviel/Annotators")

torch.backends.cuda.matmul.allow_tf32 = True
pipe.enable_model_cpu_offload()  # for saving memory

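# Image helpers and condition extractors (depth, pose, canny, recolor, tile/colorgrid).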
def convert_from_image_to_cv2(img: Image.Image) -> np.ndarray:
    return cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR)

def convert_from_cv2_to_image(img: np.ndarray) -> Image.Image:
    return Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))

def extract_depth(image):
    image = np.asarray(image)
    depth = model.infer_image(image[:, :, ::-1])
    depth = (depth - depth.min()) / (depth.max() - depth.min()) * 255.0
    depth = depth.astype(np.uint8)
    gray_depth = Image.fromarray(depth).convert('RGB')
    return gray_depth

def extract_openpose(img):
    processed_image_open_pose = open_pose(img, hand_and_face=True)
    processed_image_open_pose = processed_image_open_pose.resize(img.size)
    return processed_image_open_pose

def extract_canny(input_image):
    image = np.array(input_image)
    image = cv2.Canny(image, 100, 200)
    image = image[:, :, None]
    image = np.concatenate([image, image, image], axis=2)
    canny_image = Image.fromarray(image)
    return canny_image

def convert_to_grayscale(image):
    gray_image = image.convert('L').convert('RGB')
    return gray_image

def tile(downscale_factor, input_image):
    control_image = input_image.resize(
        (input_image.size[0] // downscale_factor, input_image.size[1] // downscale_factor)
    ).resize(input_image.size, Image.NEAREST)
    return control_image

def resize_img(control_image):
    image_ratio = control_image.width / control_image.height
    ratio = min(RATIO_CONFIGS_1024.keys(), key=lambda k: abs(k - image_ratio))
    to_height = RATIO_CONFIGS_1024[ratio]["height"]
    to_width = RATIO_CONFIGS_1024[ratio]["width"]
    resized_image = control_image.resize((to_width, to_height), resample=Image.Resampling.LANCZOS)
    return resized_image

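# Main generation entry point: preprocess the reference image for the selected mode, then run the ControlNet pipeline.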
@spaces.GPU(duration=180)
def infer(image_in, prompt, inference_steps, guidance_scale, control_mode, control_strength, seed, progress=gr.Progress(track_tqdm=True)):
    control_mode_num = mode_mapping[control_mode]

    if image_in is not None:
        image_in = resize_img(load_image(image_in))
        if control_mode == "canny":
            control_image = extract_canny(image_in)
        elif control_mode == "depth":
            control_image = extract_depth(image_in)
        elif control_mode == "pose":
            control_image = extract_openpose(image_in)
        elif control_mode == "colorgrid":
            control_image = tile(64, image_in)
        elif control_mode == "recolor":
            control_image = convert_to_grayscale(image_in)
        elif control_mode == "tile":
            control_image = tile(16, image_in)
    else:
        raise gr.Error("Please upload a reference image to extract the condition from.")

    control_image = resize_img(control_image)

    width, height = control_image.size

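    # Run the ControlNet-Union pipeline at the control image's bucketed resolution with a fixed negative prompt.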
    image = pipe(
        prompt,
        control_image=control_image,
        control_mode=control_mode_num,
        width=width,
        height=height,
        controlnet_conditioning_scale=control_strength,
        num_inference_steps=inference_steps,
        guidance_scale=guidance_scale,
        generator=torch.manual_seed(seed),
        max_sequence_length=128,
        negative_prompt="Logo,Watermark,Text,Ugly,Morbid,Extra fingers,Poorly drawn hands,Mutation,Blurry,Extra limbs,Gross proportions,Missing arms,Mutated hands,Long neck,Duplicate",
    ).images[0]

    torch.cuda.empty_cache()

    return image, control_image

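# Gradio UI: reference image and prompt on the left, generated result and preprocessed condition on the right.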
css = """
#col-container{
    margin: 0 auto;
    max-width: 1080px;
}
"""

with gr.Blocks(css=css) as demo:
    with gr.Column(elem_id="col-container"):
        gr.Markdown("""
        # BRIA-3.1-ControlNet-Union
        A unified ControlNet for the BRIA-3.1 model from Bria.ai.<br />
        """)

        with gr.Column():
            with gr.Row():
                with gr.Column():
                    # with gr.Row(equal_height=True):
                    #     cond_in = gr.Image(label="Upload a processed control image", sources=["upload"], type="filepath")
                    image_in = gr.Image(label="Extract condition from a reference image (Optional)", sources=["upload"], type="filepath")

                    prompt = gr.Textbox(label="Prompt", value="best quality")

                    with gr.Accordion("ControlNet"):
                        control_mode = gr.Radio(
                            ["depth", "canny", "colorgrid", "recolor", "tile", "pose"],
                            label="Mode",
                            value="canny",
                            info="Select the control mode; a single ControlNet handles all modes",
                        )

                        control_strength = gr.Slider(
                            label="Control strength",
                            minimum=0,
                            maximum=1.0,
                            step=0.05,
                            value=0.9,
                        )

                        seed = gr.Slider(
                            label="Seed",
                            minimum=0,
                            maximum=MAX_SEED,
                            step=1,
                            value=555,
                        )
                        randomize_seed = gr.Checkbox(label="Randomize seed", value=True)

                    with gr.Accordion("Advanced settings", open=False):
                        with gr.Column():
                            with gr.Row():
                                inference_steps = gr.Slider(label="Inference steps", minimum=1, maximum=50, step=1, value=50)
                                guidance_scale = gr.Slider(label="Guidance scale", minimum=1.0, maximum=10.0, step=0.1, value=5.0)

                    submit_btn = gr.Button("Submit")

                with gr.Column():
                    result = gr.Image(label="Result")
                    processed_cond = gr.Image(label="Preprocessed Cond")

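            # On submit: optionally randomize the seed, then run inference and show the result and the preprocessed condition.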
            submit_btn.click(
                fn=randomize_seed_fn,
                inputs=[seed, randomize_seed],
                outputs=seed,
                queue=False,
                api_name=False,
            ).then(
                fn=infer,
                inputs=[image_in, prompt, inference_steps, guidance_scale, control_mode, control_strength, seed],
                outputs=[result, processed_cond],
                show_api=False,
            )

demo.queue(api_open=False)
demo.launch()