JunhaoZhuang commited on
Commit
6ca968b
·
verified ·
1 Parent(s): 0defeef

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +4 -5
app.py CHANGED
@@ -382,9 +382,6 @@ def extract_line_image(query_image_, resolution):
382
 
383
  @spaces.GPU
384
  def extract_sketch_line_image(query_image_, input_style):
385
- global cur_style
386
- if input_style != cur_style:
387
- change_ckpt(input_style)
388
 
389
  resolution = get_rate(query_image_)
390
  extracted_line, hint_mask = extract_line_image(query_image_, resolution)
@@ -417,7 +414,7 @@ def extract_sketch_line_image(query_image_, input_style):
417
  return extracted_sketch_line.convert('RGB'), extracted_sketch_line.convert('RGB'), hint_mask, query_image_, extracted_sketch_line_ori.convert('RGB'), resolution
418
 
419
  @spaces.GPU(duration=120)
420
- def colorize_image(extracted_line, reference_images, resolution, seed, num_inference_steps, top_k, hint_mask=None, hint_color=None, query_image_origin=None, extracted_image_ori=None):
421
  if extracted_line is None:
422
  gr.Info("Please preprocess the image first")
423
  raise ValueError("Please preprocess the image first")
@@ -427,6 +424,8 @@ def colorize_image(extracted_line, reference_images, resolution, seed, num_infer
427
  global pipeline
428
  global MultiResNetModel
429
  global cur_style
 
 
430
 
431
  tar_width, tar_height = resolution
432
 
@@ -678,7 +677,7 @@ with gr.Blocks() as demo:
678
  )
679
  colorize_button.click(
680
  colorize_image,
681
- inputs=[extracted_image, reference_images, resolution, seed, num_inference_steps, top_k, hint_mask, hint_color, query_image_origin, extracted_image_ori],
682
  outputs=output_gallery
683
  )
684
  with gr.Column():
 
382
 
383
  @spaces.GPU
384
  def extract_sketch_line_image(query_image_, input_style):
 
 
 
385
 
386
  resolution = get_rate(query_image_)
387
  extracted_line, hint_mask = extract_line_image(query_image_, resolution)
 
414
  return extracted_sketch_line.convert('RGB'), extracted_sketch_line.convert('RGB'), hint_mask, query_image_, extracted_sketch_line_ori.convert('RGB'), resolution
415
 
416
  @spaces.GPU(duration=120)
417
+ def colorize_image(input_style, extracted_line, reference_images, resolution, seed, num_inference_steps, top_k, hint_mask=None, hint_color=None, query_image_origin=None, extracted_image_ori=None):
418
  if extracted_line is None:
419
  gr.Info("Please preprocess the image first")
420
  raise ValueError("Please preprocess the image first")
 
424
  global pipeline
425
  global MultiResNetModel
426
  global cur_style
427
+ if input_style != cur_style:
428
+ change_ckpt(input_style)
429
 
430
  tar_width, tar_height = resolution
431
 
 
677
  )
678
  colorize_button.click(
679
  colorize_image,
680
+ inputs=[model_name, extracted_image, reference_images, resolution, seed, num_inference_steps, top_k, hint_mask, hint_color, query_image_origin, extracted_image_ori],
681
  outputs=output_gallery
682
  )
683
  with gr.Column():