zhiweili committed

Commit 423ba5e · 1 Parent(s): ccf29a6

add pre upscale

Files changed (2):
  1. app_base.py +22 -9
  2. upscale.py +27 -0
app_base.py CHANGED

@@ -10,6 +10,8 @@ from segment_utils import(
 )
 from enhance_utils import enhance_image
 
+from upscale import upscale_image
+
 DEFAULT_SRC_PROMPT = "a person"
 DEFAULT_EDIT_PROMPT = "a person with perfect face"
 
@@ -31,16 +33,24 @@ def create_demo() -> gr.Blocks:
         start_step: int,
         guidance_scale: float,
         generate_size: int,
-        pre_enhance: bool = True,
-        pre_enhance_scale: int = 2,
+        enhance_scale: int = 2,
+        pre_upscale: bool = True,
+        upscale_prompt: str = "",
+        pre_upscale_steps: int = 10,
     ):
         w2 = 1.0
         run_task_time = 0
         time_cost_str = ''
         run_task_time, time_cost_str = get_time_cost(run_task_time, time_cost_str)
-        if pre_enhance:
-            input_image = enhance_image(input_image, enhance_face=True, scale=pre_enhance_scale)
-            input_image = input_image.resize((generate_size, generate_size))
+        if pre_upscale:
+            pre_upscale_start_size = generate_size // 4
+            input_image = upscale_image(
+                input_image,
+                upscale_prompt,
+                start_size=pre_upscale_start_size,
+                upscale_steps=pre_upscale_steps,
+                seed=seed,
+            )
         run_task_time, time_cost_str = get_time_cost(run_task_time, time_cost_str)
         run_model = base_run
         res_image = run_model(
@@ -56,7 +66,7 @@ def create_demo() -> gr.Blocks:
             guidance_scale,
         )
         run_task_time, time_cost_str = get_time_cost(run_task_time, time_cost_str)
-        enhanced_image = enhance_image(res_image)
+        enhanced_image = enhance_image(res_image, scale=enhance_scale)
         run_task_time, time_cost_str = get_time_cost(run_task_time, time_cost_str)
 
         return enhanced_image, res_image, time_cost_str
@@ -79,6 +89,11 @@ def create_demo() -> gr.Blocks:
                 input_image_prompt = gr.Textbox(lines=1, label="Input Image Prompt", value=DEFAULT_SRC_PROMPT)
                 edit_prompt = gr.Textbox(lines=1, label="Edit Prompt", value=DEFAULT_EDIT_PROMPT)
                 category = gr.Textbox(label="Category", value=DEFAULT_CATEGORY, visible=False)
+                with gr.Accordion("Advanced Options", open=False):
+                    enhance_scale = gr.Number(label="Enhance Scale", value=2)
+                    pre_upscale = gr.Checkbox(label="Pre Upscale", value=True)
+                    upscale_prompt = gr.Textbox(lines=1, label="Upscale Prompt", value="a person with perfect face")
+                    pre_upscale_steps = gr.Number(label="Pre Upscale Steps", value=10)
             with gr.Column():
                 num_steps = gr.Slider(minimum=1, maximum=100, value=50, step=1, label="Num Steps")
                 start_step = gr.Slider(minimum=1, maximum=100, value=30, step=1, label="Start Step")
@@ -87,8 +102,6 @@ def create_demo() -> gr.Blocks:
                 generate_size = gr.Number(label="Generate Size", value=512)
                 mask_expansion = gr.Number(label="Mask Expansion", value=50, visible=True)
                 mask_dilation = gr.Slider(minimum=0, maximum=10, value=2, step=1, label="Mask Dilation")
-                pre_enhance = gr.Checkbox(label="Pre Enhance", value=True)
-                pre_enhance_scale = gr.Slider(minimum=1, maximum=4, value=2, step=1, label="Pre Enhance Scale")
             with gr.Column():
                 seed = gr.Number(label="Seed", value=8)
                 w1 = gr.Number(label="W1", value=1.5)
@@ -112,7 +125,7 @@ def create_demo() -> gr.Blocks:
             outputs=[origin_area_image, croper],
         ).success(
             fn=image_to_image,
-            inputs=[origin_area_image, input_image_prompt, edit_prompt, seed, w1, num_steps, start_step, guidance_scale, generate_size, pre_enhance, pre_enhance_scale],
+            inputs=[origin_area_image, input_image_prompt, edit_prompt, seed, w1, num_steps, start_step, guidance_scale, generate_size, enhance_scale, pre_upscale, upscale_prompt, pre_upscale_steps],
             outputs=[enhanced_image, generated_image, generated_cost],
         ).success(
             fn=restore_result,
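
The pre-upscale path resizes the cropped input down to a quarter of the target resolution and relies on the x4 upscaler in upscale.py to bring it back to generate_size before the edit runs, presumably to start the diffusion edit from a cleaner base. A minimal sketch of the sizing math, assuming the default generate_size of 512 from the UI:

    generate_size = 512                          # default of the Generate Size field
    pre_upscale_start_size = generate_size // 4  # 512 // 4 = 128
    # upscale_image() first resizes the input to 128x128, then the x4 upscaler
    # returns a 128 * 4 = 512 pixel image, matching generate_size.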
upscale.py ADDED

@@ -0,0 +1,27 @@
+import torch
+
+from PIL import Image
+from diffusers import StableDiffusionUpscalePipeline
+
+device = "cuda" if torch.cuda.is_available() else "cpu"
+model_id = "stabilityai/stable-diffusion-x4-upscaler"
+upscale_pipe = StableDiffusionUpscalePipeline.from_pretrained(model_id, torch_dtype=torch.float16)
+upscale_pipe = upscale_pipe.to(device)
+
+def upscale_image(
+    input_image: Image,
+    prompt: str,
+    start_size: int = 128,
+    upscale_steps: int = 30,
+    seed: int = 42,
+):
+    generator = torch.Generator().manual_seed(seed)
+    input_image = input_image.resize((start_size, start_size))
+    upscaled_image = upscale_pipe(
+        prompt=prompt,
+        image=input_image,
+        num_inference_steps=upscale_steps,
+        generator=generator,
+    ).images[0]
+
+    return upscaled_image
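
For reference, a minimal standalone usage sketch of the new helper, assuming a local image file named face.png (the filename is only illustrative) and a machine where the diffusers weights can be loaded:

    from PIL import Image

    from upscale import upscale_image

    # upscale_image resizes the input to start_size x start_size before running
    # the x4 upscaler, so the result is a (start_size * 4) square image.
    image = Image.open("face.png").convert("RGB")  # hypothetical input file
    result = upscale_image(
        image,
        "a person with perfect face",  # prompt guiding the upscaler
        start_size=128,                # 128 -> 512 after x4 upscaling
        upscale_steps=10,              # matches the Pre Upscale Steps default
        seed=8,                        # matches the Seed default in the UI
    )
    result.save("face_upscaled.png")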