zhiweili
committed on
Commit
·
423ba5e
1
Parent(s):
ccf29a6
add pre upscale
Browse files- app_base.py +22 -9
- upscale.py +27 -0
app_base.py
CHANGED
@@ -10,6 +10,8 @@ from segment_utils import(
|
|
10 |
)
|
11 |
from enhance_utils import enhance_image
|
12 |
|
|
|
|
|
13 |
DEFAULT_SRC_PROMPT = "a person"
|
14 |
DEFAULT_EDIT_PROMPT = "a person with perfect face"
|
15 |
|
@@ -31,16 +33,24 @@ def create_demo() -> gr.Blocks:
|
|
31 |
start_step: int,
|
32 |
guidance_scale: float,
|
33 |
generate_size: int,
|
34 |
-
|
35 |
-
|
|
|
|
|
36 |
):
|
37 |
w2 = 1.0
|
38 |
run_task_time = 0
|
39 |
time_cost_str = ''
|
40 |
run_task_time, time_cost_str = get_time_cost(run_task_time, time_cost_str)
|
41 |
-
if
|
42 |
-
|
43 |
-
input_image =
|
|
|
|
|
|
|
|
|
|
|
|
|
44 |
run_task_time, time_cost_str = get_time_cost(run_task_time, time_cost_str)
|
45 |
run_model = base_run
|
46 |
res_image = run_model(
|
@@ -56,7 +66,7 @@ def create_demo() -> gr.Blocks:
|
|
56 |
guidance_scale,
|
57 |
)
|
58 |
run_task_time, time_cost_str = get_time_cost(run_task_time, time_cost_str)
|
59 |
-
enhanced_image = enhance_image(res_image)
|
60 |
run_task_time, time_cost_str = get_time_cost(run_task_time, time_cost_str)
|
61 |
|
62 |
return enhanced_image, res_image, time_cost_str
|
@@ -79,6 +89,11 @@ def create_demo() -> gr.Blocks:
|
|
79 |
input_image_prompt = gr.Textbox(lines=1, label="Input Image Prompt", value=DEFAULT_SRC_PROMPT)
|
80 |
edit_prompt = gr.Textbox(lines=1, label="Edit Prompt", value=DEFAULT_EDIT_PROMPT)
|
81 |
category = gr.Textbox(label="Category", value=DEFAULT_CATEGORY, visible=False)
|
|
|
|
|
|
|
|
|
|
|
82 |
with gr.Column():
|
83 |
num_steps = gr.Slider(minimum=1, maximum=100, value=50, step=1, label="Num Steps")
|
84 |
start_step = gr.Slider(minimum=1, maximum=100, value=30, step=1, label="Start Step")
|
@@ -87,8 +102,6 @@ def create_demo() -> gr.Blocks:
|
|
87 |
generate_size = gr.Number(label="Generate Size", value=512)
|
88 |
mask_expansion = gr.Number(label="Mask Expansion", value=50, visible=True)
|
89 |
mask_dilation = gr.Slider(minimum=0, maximum=10, value=2, step=1, label="Mask Dilation")
|
90 |
-
pre_enhance = gr.Checkbox(label="Pre Enhance", value=True)
|
91 |
-
pre_enhance_scale = gr.Slider(minimum=1, maximum=4, value=2, step=1, label="Pre Enhance Scale")
|
92 |
with gr.Column():
|
93 |
seed = gr.Number(label="Seed", value=8)
|
94 |
w1 = gr.Number(label="W1", value=1.5)
|
@@ -112,7 +125,7 @@ def create_demo() -> gr.Blocks:
|
|
112 |
outputs=[origin_area_image, croper],
|
113 |
).success(
|
114 |
fn=image_to_image,
|
115 |
-
inputs=[origin_area_image, input_image_prompt, edit_prompt,seed,w1, num_steps, start_step, guidance_scale, generate_size,
|
116 |
outputs=[enhanced_image, generated_image, generated_cost],
|
117 |
).success(
|
118 |
fn=restore_result,
|
|
|
10 |
)
|
11 |
from enhance_utils import enhance_image
|
12 |
|
13 |
+
from upscale import upscale_image
|
14 |
+
|
15 |
DEFAULT_SRC_PROMPT = "a person"
|
16 |
DEFAULT_EDIT_PROMPT = "a person with perfect face"
|
17 |
|
|
|
33 |
start_step: int,
|
34 |
guidance_scale: float,
|
35 |
generate_size: int,
|
36 |
+
enhance_scale: int = 2,
|
37 |
+
pre_upscale: bool = True,
|
38 |
+
upscale_prompt: str = "a person with perfect face",
|
39 |
+
pre_upscale_steps: int = 10,
|
40 |
):
|
41 |
w2 = 1.0
|
42 |
run_task_time = 0
|
43 |
time_cost_str = ''
|
44 |
run_task_time, time_cost_str = get_time_cost(run_task_time, time_cost_str)
|
45 |
+
if pre_upscale:
|
46 |
+
pre_upscale_start_size = generate_size // 4
|
47 |
+
input_image = upscale_image(
|
48 |
+
input_image,
|
49 |
+
upscale_prompt,
|
50 |
+
start_size=pre_upscale_start_size,
|
51 |
+
upscale_steps=pre_upscale_steps,
|
52 |
+
seed=seed,
|
53 |
+
)
|
54 |
run_task_time, time_cost_str = get_time_cost(run_task_time, time_cost_str)
|
55 |
run_model = base_run
|
56 |
res_image = run_model(
|
|
|
66 |
guidance_scale,
|
67 |
)
|
68 |
run_task_time, time_cost_str = get_time_cost(run_task_time, time_cost_str)
|
69 |
+
enhanced_image = enhance_image(res_image, scale = enhance_scale)
|
70 |
run_task_time, time_cost_str = get_time_cost(run_task_time, time_cost_str)
|
71 |
|
72 |
return enhanced_image, res_image, time_cost_str
|
|
|
89 |
input_image_prompt = gr.Textbox(lines=1, label="Input Image Prompt", value=DEFAULT_SRC_PROMPT)
|
90 |
edit_prompt = gr.Textbox(lines=1, label="Edit Prompt", value=DEFAULT_EDIT_PROMPT)
|
91 |
category = gr.Textbox(label="Category", value=DEFAULT_CATEGORY, visible=False)
|
92 |
+
with gr.Accordion("Advanced Options", open=False):
|
93 |
+
enhance_scale = gr.Number(label="Enhance Scale", value=2)
|
94 |
+
pre_upscale = gr.Checkbox(label="Pre Upscale", value=True)
|
95 |
+
upscale_prompt = gr.Textbox(lines=1, label="Upscale Prompt", value="a person with perfect face")
|
96 |
+
pre_upscale_steps = gr.Number(label="Pre Upscale Steps", value=10)
|
97 |
with gr.Column():
|
98 |
num_steps = gr.Slider(minimum=1, maximum=100, value=50, step=1, label="Num Steps")
|
99 |
start_step = gr.Slider(minimum=1, maximum=100, value=30, step=1, label="Start Step")
|
|
|
102 |
generate_size = gr.Number(label="Generate Size", value=512)
|
103 |
mask_expansion = gr.Number(label="Mask Expansion", value=50, visible=True)
|
104 |
mask_dilation = gr.Slider(minimum=0, maximum=10, value=2, step=1, label="Mask Dilation")
|
|
|
|
|
105 |
with gr.Column():
|
106 |
seed = gr.Number(label="Seed", value=8)
|
107 |
w1 = gr.Number(label="W1", value=1.5)
|
|
|
125 |
outputs=[origin_area_image, croper],
|
126 |
).success(
|
127 |
fn=image_to_image,
|
128 |
+
inputs=[origin_area_image, input_image_prompt, edit_prompt,seed,w1, num_steps, start_step, guidance_scale, generate_size, enhance_scale, pre_upscale, upscale_prompt, pre_upscale_steps],
|
129 |
outputs=[enhanced_image, generated_image, generated_cost],
|
130 |
).success(
|
131 |
fn=restore_result,
|
upscale.py
ADDED
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import torch
from PIL import Image
from diffusers import StableDiffusionUpscalePipeline

# Load the x4 upscaler once at import time so every call to upscale_image()
# reuses the same pipeline instead of re-initializing the model.
device = "cuda" if torch.cuda.is_available() else "cpu"
model_id = "stabilityai/stable-diffusion-x4-upscaler"
# NOTE(review): float16 weights typically do not run on CPU for this pipeline —
# confirm this module is only imported in a CUDA environment, or select the
# dtype based on `device`.
upscale_pipe = StableDiffusionUpscalePipeline.from_pretrained(model_id, torch_dtype=torch.float16)
upscale_pipe = upscale_pipe.to(device)


def upscale_image(
    input_image: Image.Image,
    prompt: str,
    start_size: int = 128,
    upscale_steps: int = 30,
    seed: int = 42,
) -> Image.Image:
    """Shrink *input_image* to a ``start_size`` square, then 4x-upscale it.

    Args:
        input_image: Source image. It is resized to
            ``(start_size, start_size)`` first, so the aspect ratio is NOT
            preserved — TODO confirm callers always pass square crops.
        prompt: Text prompt guiding the diffusion upscaler.
        start_size: Edge length the image is reduced to before upscaling;
            the x4 pipeline therefore returns a ``4 * start_size`` square.
        upscale_steps: Number of diffusion inference steps.
        seed: Seed for the (CPU) torch.Generator, for reproducible output.

    Returns:
        The upscaled PIL image.
    """
    generator = torch.Generator().manual_seed(seed)
    # Down-size first so the x4 pipeline yields a predictable output size.
    input_image = input_image.resize((start_size, start_size))
    upscaled_image = upscale_pipe(
        prompt=prompt,
        image=input_image,
        num_inference_steps=upscale_steps,
        generator=generator,
    ).images[0]
    return upscaled_image