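# Source: ModelSwap-1 / app.py
# Last commit: "Update app.py" by ssboost (ff749ef, verified); file size 6.55 kB.
# (Viewer links "raw" and "history blame" were part of the page this file was copied from.)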
# Standard library
import os
import sys
import tempfile
import time

# Third-party
import insightface
import onnxruntime
import cv2
import gradio as gr
from PIL import Image
from torchvision.transforms import functional

# Compatibility shim: expose `torchvision.transforms.functional` under the
# legacy module name `torchvision.transforms.functional_tensor`.
# NOTE(review): newer torchvision releases no longer ship that module, and
# gfpgan's dependency chain (presumably basicsr) imports it at import time,
# so the alias has to be registered BEFORE `import gfpgan` to have any effect.
sys.modules["torchvision.transforms.functional_tensor"] = functional

import gfpgan  # noqa: E402  -- intentionally imported after the shim above
class Predictor:
    """Face-swap pipeline: detect faces, swap the largest one, enhance the result.

    Constructing an instance calls :meth:`setup`, which downloads any missing
    model weights into ``models/`` and loads the swapper, the enhancer and the
    face analyser.
    """

    # Directory holding the weights, and local file name -> download URL.
    _MODEL_DIR = 'models'
    _MODEL_URLS = {
        'GFPGANv1.4.pth':
            'https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth',
        'inswapper_128.onnx':
            'https://huggingface.co/ashleykleynhans/inswapper/resolve/main/inswapper_128.onnx',
    }

    def __init__(self):
        self.setup()

    def setup(self):
        """Fetch missing weights and load the models into memory.

        Loading once here keeps repeated predictions cheap. Raises if a
        download fails, instead of continuing with a missing or partial file.
        """
        os.makedirs(self._MODEL_DIR, exist_ok=True)
        # Download straight to the target path; no os.chdir(), so the
        # process-wide working directory is never disturbed.
        for file_name, url in self._MODEL_URLS.items():
            self._download(url, os.path.join(self._MODEL_DIR, file_name))

        self.face_swapper = insightface.model_zoo.get_model(
            os.path.join(self._MODEL_DIR, 'inswapper_128.onnx'),
            providers=onnxruntime.get_available_providers(),
        )
        self.face_enhancer = gfpgan.GFPGANer(
            model_path=os.path.join(self._MODEL_DIR, 'GFPGANv1.4.pth'),
            upscale=1,
        )
        self.face_analyser = insightface.app.FaceAnalysis(name='buffalo_l')
        self.face_analyser.prepare(ctx_id=0, det_size=(640, 640))

    @staticmethod
    def _download(url, dest_path):
        """Download *url* to *dest_path* with ``wget`` unless the file exists.

        The data is written to ``<dest_path>.part`` and renamed only after
        ``wget`` exits successfully, so an interrupted download is never
        mistaken for a complete model file on the next start.

        Raises:
            subprocess.CalledProcessError: ``wget`` exited with an error.
            FileNotFoundError: ``wget`` is not installed.
        """
        import subprocess  # local import keeps this block self-contained

        if os.path.exists(dest_path):
            return
        partial_path = dest_path + '.part'
        try:
            # Argument list (no shell) and check=True so failures surface.
            subprocess.run(['wget', '-O', partial_path, url], check=True)
            os.replace(partial_path, dest_path)
        finally:
            # wget -O creates the file even when the transfer fails.
            if os.path.exists(partial_path):
                os.remove(partial_path)

    @staticmethod
    def _bbox_area(face):
        """Return the area of ``face.bbox`` given as ``(x1, y1, x2, y2)``."""
        x1, y1, x2, y2 = face.bbox[:4]
        return (x2 - x1) * (y2 - y1)

    def _largest_face(self, frame):
        """Return the detected face with the largest box, or ``None`` if none."""
        faces = self.face_analyser.get(frame)
        if not faces:
            return None
        return max(faces, key=self._bbox_area)

    def get_face_image(self, img_data, face):
        """Return the crop of *img_data* covered by ``face.bbox``.

        Coordinates are truncated to ints and clamped to the image bounds.
        Without clamping, a box that starts left of / above the image yields
        negative indices, which NumPy interprets as counting from the end and
        therefore returns an empty or wrong region.

        Args:
            img_data: array indexed as ``[row, column, ...]``.
            face: object with a ``bbox`` of ``(x1, y1, x2, y2)``.
        """
        height, width = img_data.shape[:2]
        x1, y1, x2, y2 = (int(coord) for coord in face.bbox)
        x1, x2 = (min(max(0, x), width) for x in (x1, x2))
        y1, y2 = (min(max(0, y), height) for y in (y1, y2))
        return img_data[y1:y2, x1:x2]

    def predict(self, input_image_path, swap_image_path):
        """Swap the largest face of the swap image onto the target image.

        Args:
            input_image_path: path of the image whose face is replaced.
            swap_image_path: path of the image providing the new face.

        Returns:
            Path of the resulting JPEG inside a fresh temporary directory, or
            ``None`` when an image cannot be read, no face is found, the
            result cannot be written, or any step raises (the reason is
            printed).
        """
        # Broad catch on purpose: this is the UI callback boundary, and a
        # failure should show up as "no result" rather than a crash.
        try:
            frame = cv2.imread(input_image_path)
            if frame is None:
                print("โŒ Target image could not be read.")
                return None
            face = self._largest_face(frame)
            if face is None:
                print("โŒ No face found in target image.")
                return None

            swap_frame = cv2.imread(swap_image_path)
            if swap_frame is None:
                print("โŒ Swap image could not be read.")
                return None
            swap_face = self._largest_face(swap_frame)
            if swap_face is None:
                print("โŒ No face found in swap image.")
                return None

            # Face swap, pasted back into the full target frame.
            result = self.face_swapper.get(frame, face, swap_face, paste_back=True)
            # Face enhancement; only the third return value (full image) is used.
            _, _, result = self.face_enhancer.enhance(result, paste_back=True)

            out_path = os.path.join(tempfile.mkdtemp(), f"{int(time.time())}.jpg")
            # cv2.imwrite signals failure through its return value.
            if not cv2.imwrite(out_path, result):
                print("Result image could not be written.")
                return None
            return out_path
        except Exception as e:
            print(f"{e}")
            return None
# Create the single shared Predictor instance; its constructor runs setup(),
# so this line executes at import time. predictor.predict is the callback
# wired to the swap button below.
predictor = Predictor()
# Custom CSS handed to gr.Blocks: gives the buttons with element ids
# "swap-button" and "download-button" an orange (#FB923C) background and
# white text. The string is passed through as-is at runtime.
css = """
/* "Swap Faces" ๋ฒ„ํŠผ ์Šคํƒ€์ผ */
button#swap-button {
background-color: #FB923C !important; /* ์ฃผํ™ฉ์ƒ‰ ๋ฐฐ๊ฒฝ */
color: white !important; /* ํฐ์ƒ‰ ๊ธ€์”จ */
}
/* "์ด๋ฏธ์ง€ ๋‹ค์šด๋กœ๋“œ (JPG)" ๋ฒ„ํŠผ ์Šคํƒ€์ผ */
button#download-button {
background-color: #FB923C !important; /* ์ฃผํ™ฉ์ƒ‰ ๋ฐฐ๊ฒฝ */
color: white !important; /* ํฐ์ƒ‰ ๊ธ€์”จ */
}
/* ํ•„์š”์— ๋”ฐ๋ผ ์ถ”๊ฐ€์ ์ธ ์Šคํƒ€์ผ์„ ์—ฌ๊ธฐ์— ์ž‘์„ฑํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค */
"""
# Orange primary palette (shades c50..c950) used by the app theme.
_primary_palette = gr.themes.Color(
    c50="#FFF7ED", c100="#FFEDD5", c200="#FED7AA", c300="#FDBA74",
    c400="#FB923C", c500="#F97316", c600="#EA580C", c700="#C2410C",
    c800="#9A3412", c900="#7C2D12", c950="#431407",
)

# Soft theme with the orange primary hue, zinc secondary/neutral hues and
# the Pretendard font (falling back to sans-serif).
demo_theme = gr.themes.Soft(
    primary_hue=_primary_palette,
    secondary_hue="zinc",
    neutral_hue="zinc",
    font=("Pretendard", "sans-serif"),
)
# JPG export used by the download button.
def save_as_jpg(file_path):
    """Re-encode the image at *file_path* as a JPEG temporary file.

    Args:
        file_path: path of the image to convert, or ``None``.

    Returns:
        Path of the new ``.jpg`` file, or ``None`` when *file_path* is
        ``None`` or the conversion fails (the error is printed, not raised).
    """
    if file_path is None:
        return None

    # Modes Pillow's JPEG writer accepts directly. Anything else (e.g. RGBA
    # or palette images) is converted to RGB first instead of failing.
    jpeg_modes = ("1", "L", "RGB", "RGBX", "CMYK", "YCbCr")
    tmp_path = None
    # Broad catch on purpose: this is a UI callback, failure means "no file".
    try:
        # The context manager closes the source file handle when done.
        with Image.open(file_path) as img:
            jpeg_ready = img if img.mode in jpeg_modes else img.convert("RGB")
            with tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) as tmp:
                tmp_path = tmp.name
                jpeg_ready.save(tmp, format="JPEG")
        return tmp_path
    except Exception as error:
        print(f"Error saving as JPG: {error}")
        # Do not leave a half-written temp file behind.
        if tmp_path is not None:
            try:
                os.remove(tmp_path)
            except OSError:
                pass
        return None
# Reset handler: clears the inputs and the output.
def clear_all():
    """Return three ``None`` values, one for each component the reset button clears."""
    return [None] * 3
# Build the Gradio interface.
# NOTE(review): the nesting below was reconstructed from a copy whose
# indentation had been lost -- confirm that the buttons and the file widget
# really belong inside the two columns.
with gr.Blocks(theme=demo_theme, css=css) as demo:
    with gr.Row():
        # Left section: inputs
        with gr.Column(scale=1):
            target_image = gr.Image(
                type="filepath",
                label="์–ผ๊ตด์„ ๋ณ€๊ฒฝํ•  ์ด๋ฏธ์ง€"
            )
            swap_image = gr.Image(
                type="filepath",
                label="๊ต์ฒดํ•  ์–ผ๊ตด"
            )
            swap_button = gr.Button("์–ผ๊ตด ๊ต์ฒด", elem_id="swap-button")
            clear_button = gr.Button("๋ฆฌ์…‹ํ•˜๊ธฐ")
        # Right section: outputs
        with gr.Column(scale=1):
            result_image = gr.Image(
                type="filepath",
                label="๊ฒฐ๊ณผ ์ด๋ฏธ์ง€"
            )
            download_jpg_button = gr.Button("JPG๋กœ ๋ณ€ํ™˜ํ•˜๊ธฐ", elem_id="download-button")
            download_file = gr.File(label="JPG ์ด๋ฏธ์ง€ ๋‹ค์šด๋ฐ›๊ธฐ")
    # Swap button: run predictor.predict on the two input images and show
    # the returned file path in the result image.
    swap_button.click(
        fn=predictor.predict,
        inputs=[target_image, swap_image],
        outputs=result_image
    )
    # Reset button: clear both input images and the result image.
    clear_button.click(
        fn=clear_all,
        inputs=None,
        outputs=[target_image, swap_image, result_image]
    )
    # JPG button: convert the current result and offer it as a file.
    download_jpg_button.click(
        fn=save_as_jpg,
        inputs=result_image,
        outputs=download_file
    )
# Start the Gradio app.
demo.launch()