Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
@@ -1,3 +1,4 @@
|
|
|
|
1 |
import contextlib
|
2 |
import gc
|
3 |
import json
|
@@ -179,6 +180,7 @@ image_encoder = CLIPVisionModelWithProjection.from_pretrained(os.path.join(model
|
|
179 |
global pipeline
|
180 |
global MultiResNetModel
|
181 |
|
|
|
182 |
def load_ckpt():
|
183 |
global pipeline
|
184 |
global MultiResNetModel
|
@@ -293,6 +295,7 @@ def load_ckpt():
|
|
293 |
|
294 |
global cur_style
|
295 |
cur_style = 'line + shadow'
|
|
|
296 |
def change_ckpt(style):
|
297 |
global pipeline
|
298 |
global MultiResNetModel
|
@@ -334,6 +337,7 @@ def change_ckpt(style):
|
|
334 |
|
335 |
load_ckpt()
|
336 |
|
|
|
337 |
def fix_random_seeds(seed):
|
338 |
random.seed(seed)
|
339 |
np.random.seed(seed)
|
@@ -349,6 +353,7 @@ def process_multi_images(files):
|
|
349 |
imgs.append(img)
|
350 |
return imgs
|
351 |
|
|
|
352 |
def extract_lines(image):
|
353 |
src = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2GRAY)
|
354 |
|
@@ -373,16 +378,17 @@ def extract_lines(image):
|
|
373 |
torch.cuda.empty_cache()
|
374 |
return outimg
|
375 |
|
|
|
376 |
def extract_line_image(query_image_, resolution):
|
377 |
tar_width, tar_height = resolution
|
378 |
query_image = query_image_.resize((tar_width, tar_height))
|
379 |
-
# query_image.save('/mnt/workspace/zhuangjunhao/cobra_code/ColorFlow/examples/line/example3/input.png')
|
380 |
query_image = query_image.convert('L').convert('RGB')
|
381 |
extracted_line = extract_lines(query_image)
|
382 |
extracted_line = extracted_line.convert('L').convert('RGB')
|
383 |
torch.cuda.empty_cache()
|
384 |
return extracted_line, Image.new('RGB', (tar_width, tar_height), 'black')
|
385 |
|
|
|
386 |
def extract_sketch_line_image(query_image_, input_style):
|
387 |
global cur_style
|
388 |
if input_style != cur_style:
|
@@ -418,6 +424,7 @@ def extract_sketch_line_image(query_image_, input_style):
|
|
418 |
|
419 |
return extracted_sketch_line.convert('RGB'), extracted_sketch_line.convert('RGB'), hint_mask, query_image_, extracted_sketch_line_ori.convert('RGB'), resolution
|
420 |
|
|
|
421 |
def colorize_image(extracted_line, reference_images, resolution, seed, num_inference_steps, top_k, hint_mask=None, hint_color=None, query_image_origin=None, extracted_image_ori=None):
|
422 |
if extracted_line is None:
|
423 |
gr.Info("Please preprocess the image first")
|
@@ -440,11 +447,6 @@ def colorize_image(extracted_line, reference_images, resolution, seed, num_infer
|
|
440 |
reference_images = [process_image(ref_image, tar_width, tar_height) for ref_image in reference_images]
|
441 |
query_patches_pil = process_image_Q_varres(query_image_origin, tar_width, tar_height)
|
442 |
reference_patches_pil = []
|
443 |
-
# Save reference_images
|
444 |
-
# save_path = '/mnt/workspace/zhuangjunhao/cobra_code/ColorFlow/examples/line/example3'
|
445 |
-
# os.makedirs(save_path, exist_ok=True)
|
446 |
-
# for idx, ref_image in enumerate(reference_images):
|
447 |
-
# ref_image.save(os.path.join(save_path, f'reference_image_{idx}.png'))
|
448 |
|
449 |
for reference_image in reference_images:
|
450 |
reference_patches_pil += process_image_ref_varres(reference_image, tar_width, tar_height)
|
@@ -695,4 +697,4 @@ with gr.Blocks() as demo:
|
|
695 |
)
|
696 |
|
697 |
|
698 |
-
demo.launch(
|
|
|
1 |
+
import spaces
|
2 |
import contextlib
|
3 |
import gc
|
4 |
import json
|
|
|
180 |
global pipeline
|
181 |
global MultiResNetModel
|
182 |
|
183 |
+
@spaces.GPU
|
184 |
def load_ckpt():
|
185 |
global pipeline
|
186 |
global MultiResNetModel
|
|
|
295 |
|
296 |
global cur_style
|
297 |
cur_style = 'line + shadow'
|
298 |
+
@spaces.GPU
|
299 |
def change_ckpt(style):
|
300 |
global pipeline
|
301 |
global MultiResNetModel
|
|
|
337 |
|
338 |
load_ckpt()
|
339 |
|
340 |
+
@spaces.GPU
|
341 |
def fix_random_seeds(seed):
|
342 |
random.seed(seed)
|
343 |
np.random.seed(seed)
|
|
|
353 |
imgs.append(img)
|
354 |
return imgs
|
355 |
|
356 |
+
@spaces.GPU
|
357 |
def extract_lines(image):
|
358 |
src = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2GRAY)
|
359 |
|
|
|
378 |
torch.cuda.empty_cache()
|
379 |
return outimg
|
380 |
|
381 |
+
@spaces.GPU
|
382 |
def extract_line_image(query_image_, resolution):
|
383 |
tar_width, tar_height = resolution
|
384 |
query_image = query_image_.resize((tar_width, tar_height))
|
|
|
385 |
query_image = query_image.convert('L').convert('RGB')
|
386 |
extracted_line = extract_lines(query_image)
|
387 |
extracted_line = extracted_line.convert('L').convert('RGB')
|
388 |
torch.cuda.empty_cache()
|
389 |
return extracted_line, Image.new('RGB', (tar_width, tar_height), 'black')
|
390 |
|
391 |
+
@spaces.GPU
|
392 |
def extract_sketch_line_image(query_image_, input_style):
|
393 |
global cur_style
|
394 |
if input_style != cur_style:
|
|
|
424 |
|
425 |
return extracted_sketch_line.convert('RGB'), extracted_sketch_line.convert('RGB'), hint_mask, query_image_, extracted_sketch_line_ori.convert('RGB'), resolution
|
426 |
|
427 |
+
@spaces.GPU(duration=120)
|
428 |
def colorize_image(extracted_line, reference_images, resolution, seed, num_inference_steps, top_k, hint_mask=None, hint_color=None, query_image_origin=None, extracted_image_ori=None):
|
429 |
if extracted_line is None:
|
430 |
gr.Info("Please preprocess the image first")
|
|
|
447 |
reference_images = [process_image(ref_image, tar_width, tar_height) for ref_image in reference_images]
|
448 |
query_patches_pil = process_image_Q_varres(query_image_origin, tar_width, tar_height)
|
449 |
reference_patches_pil = []
|
|
|
|
|
|
|
|
|
|
|
450 |
|
451 |
for reference_image in reference_images:
|
452 |
reference_patches_pil += process_image_ref_varres(reference_image, tar_width, tar_height)
|
|
|
697 |
)
|
698 |
|
699 |
|
700 |
+
demo.launch()
|