JunhaoZhuang commited on
Commit
1468e82
·
verified ·
1 Parent(s): 7d73ac1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +9 -7
app.py CHANGED
@@ -1,3 +1,4 @@
 
1
  import contextlib
2
  import gc
3
  import json
@@ -179,6 +180,7 @@ image_encoder = CLIPVisionModelWithProjection.from_pretrained(os.path.join(model
179
  global pipeline
180
  global MultiResNetModel
181
 
 
182
  def load_ckpt():
183
  global pipeline
184
  global MultiResNetModel
@@ -293,6 +295,7 @@ def load_ckpt():
293
 
294
  global cur_style
295
  cur_style = 'line + shadow'
 
296
  def change_ckpt(style):
297
  global pipeline
298
  global MultiResNetModel
@@ -334,6 +337,7 @@ def change_ckpt(style):
334
 
335
  load_ckpt()
336
 
 
337
  def fix_random_seeds(seed):
338
  random.seed(seed)
339
  np.random.seed(seed)
@@ -349,6 +353,7 @@ def process_multi_images(files):
349
  imgs.append(img)
350
  return imgs
351
 
 
352
  def extract_lines(image):
353
  src = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2GRAY)
354
 
@@ -373,16 +378,17 @@ def extract_lines(image):
373
  torch.cuda.empty_cache()
374
  return outimg
375
 
 
376
  def extract_line_image(query_image_, resolution):
377
  tar_width, tar_height = resolution
378
  query_image = query_image_.resize((tar_width, tar_height))
379
- # query_image.save('/mnt/workspace/zhuangjunhao/cobra_code/ColorFlow/examples/line/example3/input.png')
380
  query_image = query_image.convert('L').convert('RGB')
381
  extracted_line = extract_lines(query_image)
382
  extracted_line = extracted_line.convert('L').convert('RGB')
383
  torch.cuda.empty_cache()
384
  return extracted_line, Image.new('RGB', (tar_width, tar_height), 'black')
385
 
 
386
  def extract_sketch_line_image(query_image_, input_style):
387
  global cur_style
388
  if input_style != cur_style:
@@ -418,6 +424,7 @@ def extract_sketch_line_image(query_image_, input_style):
418
 
419
  return extracted_sketch_line.convert('RGB'), extracted_sketch_line.convert('RGB'), hint_mask, query_image_, extracted_sketch_line_ori.convert('RGB'), resolution
420
 
 
421
  def colorize_image(extracted_line, reference_images, resolution, seed, num_inference_steps, top_k, hint_mask=None, hint_color=None, query_image_origin=None, extracted_image_ori=None):
422
  if extracted_line is None:
423
  gr.Info("Please preprocess the image first")
@@ -440,11 +447,6 @@ def colorize_image(extracted_line, reference_images, resolution, seed, num_infer
440
  reference_images = [process_image(ref_image, tar_width, tar_height) for ref_image in reference_images]
441
  query_patches_pil = process_image_Q_varres(query_image_origin, tar_width, tar_height)
442
  reference_patches_pil = []
443
- # Save reference_images
444
- # save_path = '/mnt/workspace/zhuangjunhao/cobra_code/ColorFlow/examples/line/example3'
445
- # os.makedirs(save_path, exist_ok=True)
446
- # for idx, ref_image in enumerate(reference_images):
447
- # ref_image.save(os.path.join(save_path, f'reference_image_{idx}.png'))
448
 
449
  for reference_image in reference_images:
450
  reference_patches_pil += process_image_ref_varres(reference_image, tar_width, tar_height)
@@ -695,4 +697,4 @@ with gr.Blocks() as demo:
695
  )
696
 
697
 
698
- demo.launch(server_name="0.0.0.0", server_port=52218)
 
1
+ import spaces
2
  import contextlib
3
  import gc
4
  import json
 
180
  global pipeline
181
  global MultiResNetModel
182
 
183
+ @spaces.GPU
184
  def load_ckpt():
185
  global pipeline
186
  global MultiResNetModel
 
295
 
296
  global cur_style
297
  cur_style = 'line + shadow'
298
+ @spaces.GPU
299
  def change_ckpt(style):
300
  global pipeline
301
  global MultiResNetModel
 
337
 
338
  load_ckpt()
339
 
340
+ @spaces.GPU
341
  def fix_random_seeds(seed):
342
  random.seed(seed)
343
  np.random.seed(seed)
 
353
  imgs.append(img)
354
  return imgs
355
 
356
+ @spaces.GPU
357
  def extract_lines(image):
358
  src = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2GRAY)
359
 
 
378
  torch.cuda.empty_cache()
379
  return outimg
380
 
381
+ @spaces.GPU
382
  def extract_line_image(query_image_, resolution):
383
  tar_width, tar_height = resolution
384
  query_image = query_image_.resize((tar_width, tar_height))
 
385
  query_image = query_image.convert('L').convert('RGB')
386
  extracted_line = extract_lines(query_image)
387
  extracted_line = extracted_line.convert('L').convert('RGB')
388
  torch.cuda.empty_cache()
389
  return extracted_line, Image.new('RGB', (tar_width, tar_height), 'black')
390
 
391
+ @spaces.GPU
392
  def extract_sketch_line_image(query_image_, input_style):
393
  global cur_style
394
  if input_style != cur_style:
 
424
 
425
  return extracted_sketch_line.convert('RGB'), extracted_sketch_line.convert('RGB'), hint_mask, query_image_, extracted_sketch_line_ori.convert('RGB'), resolution
426
 
427
+ @spaces.GPU(duration=120)
428
  def colorize_image(extracted_line, reference_images, resolution, seed, num_inference_steps, top_k, hint_mask=None, hint_color=None, query_image_origin=None, extracted_image_ori=None):
429
  if extracted_line is None:
430
  gr.Info("Please preprocess the image first")
 
447
  reference_images = [process_image(ref_image, tar_width, tar_height) for ref_image in reference_images]
448
  query_patches_pil = process_image_Q_varres(query_image_origin, tar_width, tar_height)
449
  reference_patches_pil = []
 
 
 
 
 
450
 
451
  for reference_image in reference_images:
452
  reference_patches_pil += process_image_ref_varres(reference_image, tar_width, tar_height)
 
697
  )
698
 
699
 
700
+ demo.launch()