|
|
|
from mtcnn.mtcnn import MTCNN |
|
from utils import * |
|
|
|
|
|
# Single MTCNN face detector, built once at import time and reused by every
# onClick call to reject pose images that contain no face or are headshots.
face_detector = MTCNN()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Custom CSS passed to gr.Blocks(css=css): widens the main Gradio container
# to 85% of its parent's width, overriding the theme's default.
css = """

.gradio-container {width: 85% !important}

"""
|
|
|
|
|
def _client_ip(request):
    """Return the caller's IP for log lines, or None when it is unavailable."""
    # request may be None when onClick is called outside a Gradio event.
    client = getattr(request, "client", None)
    return getattr(client, "host", None)


def onClick(cloth_image, pose_image, category,
            caption, request: gr.Request):
    """Run one virtual try-on job and stream its progress to the UI.

    This is a generator wired to ``run_button.click``. Every ``yield`` emits
    an ``(image, message)`` pair that Gradio writes to the ``res_image`` and
    ``info_text`` components. Gradio only displays *yielded* values, not a
    generator's return value. Every exit path therefore yields its message
    and then returns bare.

    Steps:
      1. Check that both images were supplied.
      2. Run MTCNN on the user photo. Reject it if no face is found, or if
         the first detected face box is too large relative to the image
         (headshot).
      3. Upload both images, publish the swap task, then poll its status up
         to ``max_try`` times, ``wait_s`` seconds apart.

    Args:
        cloth_image: garment image from the "cloth" ``gr.Image``
            (``type="numpy"``), or None if empty.
        pose_image: user photo from the "user photo" ``gr.Image``
            (``type="numpy"``), or None if empty.
        category: dropdown choice, forwarded to ``publicFastSwap``.
        caption: free-text cloth caption, forwarded to ``publicFastSwap``.
        request: injected by Gradio because of the ``gr.Request``
            annotation; used here only to log the client IP.

    Yields:
        ``(image_or_None, message)`` tuples. On success the image is
        ``taskStatus['output']['job_results']['output1']`` from
        ``getTaskRes``. It is presumably a URL or file path, since the
        output component uses ``type="filepath"``; confirm against the API.
    """
    if pose_image is None:
        yield None, "no user image found !"
        return
    if cloth_image is None:
        yield None, "no cloth image found !"
        return

    # Previously referenced without being defined (NameError on the no-face path).
    client_ip = _client_ip(request)

    # Terminal failure states vs. states that mean "keep polling".
    failed_statuses = ('FAILED', 'CANCELLED', 'TIMED_OUT')
    running_statuses = ('IN_QUEUE', 'IN_PROGRESS')

    try:
        # NOTE(review): gr.Image(type="numpy") delivers RGB and this reverses
        # the channel axis before detection; confirm that is intended.
        faces = face_detector.detect_faces(pose_image[:, :, ::-1])
        if len(faces) == 0:
            print(client_ip, 'faces num is 0! ', flush=True)
            yield None, "Fatal Error !!! No face detected in pose image !!! "
            return

        # Only the first detected face is checked. A box wider or taller
        # than 1/3.3 of the image is treated as a headshot and rejected.
        _, _, w, h = faces[0]["box"]
        H, W = pose_image.shape[:2]
        max_face_ratio = 1 / 3.3
        if w / W > max_face_ratio or h / H > max_face_ratio:
            yield None, "Fatal Error !!! Headshot is not allowed in pose image!!!"
            return

        uploads = upload_imgs(ApiUrl, UploadToken, cloth_image, pose_image)
        if uploads is None:
            yield None, "fail to upload"
            return

        infId = publicFastSwap(ApiUrl, OpenId, ApiKey, uploads, category, caption)
        if not infId:
            yield None, "fail to public you task"
            return

        max_try = 30
        wait_s = 3
        yield None, "start to process, please wait..."
        for i in range(max_try):
            time.sleep(wait_s)
            taskStatus = getTaskRes(ApiUrl, infId)
            if taskStatus is None:
                # Query failed this round; try again on the next tick.
                continue

            status = taskStatus['status']
            if status in failed_statuses:
                yield None, f"task failed, query {i}, status {status}"
                return
            if status in running_statuses:
                yield None, f"task is on processing, query {i}, status {status}, please do not exit !!!"
            elif status == 'COMPLETED':
                out = taskStatus['output']['job_results']['output1']
                yield out, "task is COMPLETED"
                return
            # Any other status is ignored and polling continues.

        yield None, "fail to query task.."

    except Exception as e:
        # Top-level boundary for a UI handler: log the full traceback and
        # *yield* the error. A generator's return value never reaches the UI.
        import traceback
        print(client_ip, 'exception:', repr(e), flush=True)
        traceback.print_exc()
        yield None, "fail to create task"
|
|
|
|
|
# Page layout: one row with three columns.
#   1. cloth image
#   2. user photo
#   3. options, status text, run button and result
# The Run button streams onClick's (image, message) yields into res_image
# and info_text.
with gr.Blocks(css=css) as demo:
    with gr.Row():
        with gr.Column():
            cloth_image = gr.Image(value=None, type="numpy", label="cloth")

        with gr.Column():
            pose_image = gr.Image(value=None, type="numpy", label="user photo")

        with gr.Column():
            # Previously unlabeled, so Gradio showed a generic default label.
            category = gr.Dropdown(
                value="upper_cloth",
                choices=["upper_cloth", "lower_cloth", "full_body", "dresses"],
                interactive=True,
                label="cloth category",
            )
            caption = gr.Textbox(value="", interactive=True, label='cloth caption')
            info_text = gr.Textbox(value="", interactive=False, label='runtime information')
            run_button = gr.Button(value="Run")
            res_image = gr.Image(label="result image", value=None, type="filepath")

    # Input order must match onClick's positional parameters (the
    # gr.Request is injected by Gradio). Output order must match the
    # (image, message) tuples it yields.
    run_button.click(
        fn=onClick,
        inputs=[cloth_image, pose_image, category, caption],
        outputs=[res_image, info_text],
    )
|
|
|
if __name__ == "__main__":
    # Cap the request queue at 50 pending events, then serve the app on
    # every network interface. Blocks.queue() returns the Blocks itself,
    # so the two calls chain.
    demo.queue(max_size=50).launch(server_name='0.0.0.0')
|
|