File size: 5,199 Bytes
6baa93c ff163f1 6baa93c 3ab17e6 6baa93c deca47d ff163f1 6baa93c deca47d ff163f1 deca47d 6baa93c ca58311 6baa93c 1f492ec 6baa93c deca47d 6baa93c e8cf780 ff163f1 6baa93c e8cf780 6baa93c ec80eac 6baa93c ec80eac 6baa93c ec80eac 6baa93c ff163f1 6baa93c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 |
import spaces
import gradio as gr
import time
import torch
from PIL import Image
from segment_utils import(
segment_image,
restore_result,
)
from enhance_utils import enhance_image
# Default text prompts pre-filled in the UI.
DEFAULT_SRC_PROMPT = "a person"  # describes the (cropped) input image
DEFAULT_EDIT_PROMPT = "a person with perfect face"  # describes the desired edit
# Region category handed to segment_image / restore_result (hidden UI field).
DEFAULT_CATEGORY = "face"

# Prefer the GPU when one is visible to torch.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
def create_demo() -> gr.Blocks:
    """Build and return the Gradio Blocks UI for the image-editing demo.

    Clicking "Edit Image" runs a three-stage pipeline. Each stage only
    runs if the previous one succeeded.

    1. ``segment_image(input_image, category, generate_size,
       mask_expansion, mask_dilation)`` produces the "Origin Area Image"
       and an opaque cropper state.
    2. ``image_to_image`` edits that area with ``inversion_run_base.run``
       and post-processes the result with ``enhance_image``.
    3. ``restore_result(cropper, category, enhanced_image)`` produces the
       restored image and a downloadable file.

    Returns:
        The (unlaunched) ``gr.Blocks`` demo.
    """
    # Imported lazily so building/importing this module does not load the model.
    from inversion_run_base import run as base_run

    def get_time_cost(run_task_time, time_cost_str):
        """Record a timing checkpoint and append the elapsed time to the log.

        Args:
            run_task_time: Millisecond timestamp of the previous checkpoint,
                or ``None`` on the first call.
            time_cost_str: The timing log accumulated so far.

        Returns:
            ``(now_ms, log)``. On the first call ``log`` is ``'start'``.
            Afterwards ``'--><elapsed ms>'`` is appended for each checkpoint.
        """
        # monotonic() is unaffected by wall-clock adjustments, so deltas
        # can never go negative.
        now_time = int(time.monotonic() * 1000)
        if run_task_time is None:
            time_cost_str = 'start'
        else:
            if time_cost_str:
                time_cost_str += '-->'
            time_cost_str += f'{now_time - run_task_time}'
        return now_time, time_cost_str

    @spaces.GPU(duration=15)
    def image_to_image(
        input_image: Image.Image,
        input_image_prompt: str,
        edit_prompt: str,
        seed: int,
        w1: float,
        num_steps: int,
        start_step: int,
        guidance_scale: float,
        generate_size: int,
        pre_enhance: bool = True,
        pre_enhance_scale: int = 2,
    ):
        """Edit ``input_image`` toward ``edit_prompt`` and enhance the output.

        Returns:
            ``(enhanced_image, res_image, time_cost_str)``. ``res_image`` is
            the raw ``base_run`` output. ``enhanced_image`` is that output
            passed through ``enhance_image``. ``time_cost_str`` is the
            per-stage timing log in milliseconds.
        """
        # Gradio Number/Slider values may arrive as floats. PIL's resize
        # requires integer pixel sizes, and seed/step counts are integral.
        generate_size = int(generate_size)
        seed = int(seed)
        num_steps = int(num_steps)
        start_step = int(start_step)
        w2 = 1.0  # fixed second weight forwarded to base_run

        run_task_time, time_cost_str = get_time_cost(None, '')
        if pre_enhance:
            input_image = enhance_image(input_image, enhance_face=True, scale=pre_enhance_scale)
        input_image = input_image.resize((generate_size, generate_size))
        run_task_time, time_cost_str = get_time_cost(run_task_time, time_cost_str)

        res_image = base_run(
            input_image,
            input_image_prompt,
            edit_prompt,
            generate_size,
            seed,
            w1,
            w2,
            num_steps,
            start_step,
            guidance_scale,
        )
        run_task_time, time_cost_str = get_time_cost(run_task_time, time_cost_str)

        enhanced_image = enhance_image(res_image)
        run_task_time, time_cost_str = get_time_cost(run_task_time, time_cost_str)
        return enhanced_image, res_image, time_cost_str

    with gr.Blocks() as demo:
        # Opaque state produced by segment_image and consumed by restore_result.
        cropper = gr.State()
        with gr.Row():
            with gr.Column():
                input_image_prompt = gr.Textbox(lines=1, label="Input Image Prompt", value=DEFAULT_SRC_PROMPT)
                edit_prompt = gr.Textbox(lines=1, label="Edit Prompt", value=DEFAULT_EDIT_PROMPT)
                category = gr.Textbox(label="Category", value=DEFAULT_CATEGORY, visible=False)
            with gr.Column():
                num_steps = gr.Slider(minimum=1, maximum=100, value=50, step=1, label="Num Steps")
                start_step = gr.Slider(minimum=1, maximum=100, value=30, step=1, label="Start Step")
                with gr.Accordion("Advanced Options", open=False):
                    guidance_scale = gr.Slider(minimum=0, maximum=20, value=0, step=0.5, label="Guidance Scale")
                    # precision=0 makes Gradio deliver an int (used as a pixel size).
                    generate_size = gr.Number(label="Generate Size", value=512, precision=0)
                    mask_expansion = gr.Number(label="Mask Expansion", value=50, visible=True)
                    mask_dilation = gr.Slider(minimum=0, maximum=10, value=2, step=1, label="Mask Dilation")
                    pre_enhance = gr.Checkbox(label="Pre Enhance", value=False)
                    pre_enhance_scale = gr.Number(label="Pre Enhance Scale", value=1)
            with gr.Column():
                seed = gr.Number(label="Seed", value=8, precision=0)
                w1 = gr.Number(label="W1", value=1.5)
                g_btn = gr.Button("Edit Image")
        with gr.Row():
            with gr.Column():
                input_image = gr.Image(label="Input Image", type="pil")
            with gr.Column():
                restored_image = gr.Image(label="Restored Image", format="png", type="pil", interactive=False)
                download_path = gr.File(label="Download the output image", interactive=False)
            with gr.Column():
                origin_area_image = gr.Image(label="Origin Area Image", format="png", type="pil", interactive=False)
                enhanced_image = gr.Image(label="Enhanced Image", format="png", type="pil", interactive=False)
                generated_cost = gr.Textbox(label="Time cost by step (ms):", visible=True, interactive=False)
                generated_image = gr.Image(label="Generated Image", format="png", type="pil", interactive=False)

        # Each .success() stage runs only if the previous stage did not raise.
        g_btn.click(
            fn=segment_image,
            inputs=[input_image, category, generate_size, mask_expansion, mask_dilation],
            outputs=[origin_area_image, cropper],
        ).success(
            fn=image_to_image,
            inputs=[
                origin_area_image,
                input_image_prompt,
                edit_prompt,
                seed,
                w1,
                num_steps,
                start_step,
                guidance_scale,
                generate_size,
                pre_enhance,
                pre_enhance_scale,
            ],
            outputs=[enhanced_image, generated_image, generated_cost],
        ).success(
            fn=restore_result,
            inputs=[cropper, category, enhanced_image],
            outputs=[restored_image, download_path],
        )

    return demo