# Spaces:
# Running
# Running
import os
import sys
import tempfile
import time

# torchvision >= 0.17 removed `torchvision.transforms.functional_tensor`, but
# basicsr (imported transitively by gfpgan) still imports it when it loads.
# Alias it to `functional` BEFORE gfpgan is imported: registering the alias
# after `import gfpgan` is too late to prevent that ImportError.
from torchvision.transforms import functional

sys.modules["torchvision.transforms.functional_tensor"] = functional

import cv2  # noqa: E402
import gfpgan  # noqa: E402  (must come after the shim above)
import gradio as gr  # noqa: E402
import insightface  # noqa: E402
import onnxruntime  # noqa: E402
from PIL import Image  # noqa: E402
class Predictor:
    """Face-swap pipeline: insightface detection + inswapper swap + GFPGAN enhancement.

    Model weights are downloaded into ``models/`` on first start and loaded
    once in ``setup`` so repeated ``predict`` calls reuse them.
    """

    MODEL_DIR = "models"
    # Weight file name -> download URL, fetched by setup() when missing.
    MODEL_URLS = {
        "GFPGANv1.4.pth": "https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth",
        "inswapper_128.onnx": "https://huggingface.co/ashleykleynhans/inswapper/resolve/main/inswapper_128.onnx",
    }

    def __init__(self):
        self.setup()

    @staticmethod
    def _download(url, dest):
        """Download *url* to *dest*, writing to a temp name and renaming on success.

        The rename means an interrupted download never leaves a truncated file
        at *dest* that a later existence check would mistake for a good one.
        Raises on network/HTTP errors instead of silently continuing.
        """
        import urllib.request

        partial = dest + ".part"
        try:
            urllib.request.urlretrieve(url, partial)
            os.replace(partial, dest)
        finally:
            # Only still present if the download or rename failed.
            if os.path.exists(partial):
                os.remove(partial)

    def setup(self):
        """Download any missing weights, then load all models into memory.

        Loading once here makes running multiple predictions efficient.
        """
        os.makedirs(self.MODEL_DIR, exist_ok=True)
        # Use explicit paths rather than os.chdir(): the working directory is
        # process-wide state that other code (e.g. Gradio) may rely on.
        for filename, url in self.MODEL_URLS.items():
            path = os.path.join(self.MODEL_DIR, filename)
            if not os.path.exists(path):
                self._download(url, path)

        self.face_swapper = insightface.model_zoo.get_model(
            os.path.join(self.MODEL_DIR, "inswapper_128.onnx"),
            providers=onnxruntime.get_available_providers(),
        )
        self.face_enhancer = gfpgan.GFPGANer(
            model_path=os.path.join(self.MODEL_DIR, "GFPGANv1.4.pth"),
            upscale=1,
        )
        self.face_analyser = insightface.app.FaceAnalysis(name="buffalo_l")
        self.face_analyser.prepare(ctx_id=0, det_size=(640, 640))

    def get_face_image(self, img_data, face):
        """Crop the region of *img_data* covered by ``face.bbox`` (x1, y1, x2, y2).

        Coordinates are clamped to the image bounds: a detector box that
        extends past the border would otherwise produce negative indices,
        which numpy treats as offsets from the end (wrong or empty crop).
        """
        height, width = img_data.shape[:2]
        x1, y1, x2, y2 = (int(coord) for coord in face.bbox)
        x1, x2 = max(0, x1), min(width, x2)
        y1, y2 = max(0, y1), min(height, y2)
        return img_data[y1:y2, x1:x2]

    @staticmethod
    def _largest_face(faces):
        """Return the face whose bounding box has the largest area."""
        return max(
            faces,
            key=lambda f: (f.bbox[2] - f.bbox[0]) * (f.bbox[3] - f.bbox[1]),
        )

    def predict(self, input_image_path, swap_image_path):
        """Put the largest face from *swap_image_path* onto *input_image_path*.

        Args:
            input_image_path: Path of the target image whose face is replaced.
            swap_image_path: Path of the image that supplies the new face.

        Returns:
            Path to a JPEG with the swapped and GFPGAN-enhanced result, or
            ``None`` if an input is missing/unreadable, no face is found, or
            any step fails (the reason is printed).
        """
        import traceback

        # Gradio passes None when an image slot is empty.
        if not input_image_path or not swap_image_path:
            print("❌ Both a target image and a swap image are required.")
            return None
        try:
            frame = cv2.imread(input_image_path)
            if frame is None:
                print("❌ Target image could not be read.")
                return None
            analysed = self.face_analyser.get(frame)
            if not analysed:
                print("❌ No face found in target image.")
                return None
            face = self._largest_face(analysed)

            swap_frame = cv2.imread(swap_image_path)
            if swap_frame is None:
                print("❌ Swap image could not be read.")
                return None
            swap_analysed = self.face_analyser.get(swap_frame)
            if not swap_analysed:
                print("❌ No face found in swap image.")
                return None
            swap_face = self._largest_face(swap_analysed)

            # Replace the target face with the source face, pasted back into the frame.
            result = self.face_swapper.get(frame, face, swap_face, paste_back=True)
            # Enhance with GFPGAN; the third returned value is the full image.
            _, _, enhanced = self.face_enhancer.enhance(result, paste_back=True)
            if enhanced is not None:  # defensive: keep the un-enhanced swap otherwise
                result = enhanced

            out_path = os.path.join(tempfile.mkdtemp(), f"{int(time.time())}.jpg")
            if not cv2.imwrite(out_path, result):
                print("❌ Result image could not be written.")
                return None
            return out_path
        except Exception:
            # UI boundary: log the full traceback (not just the message) and
            # let the UI show an empty result.
            traceback.print_exc()
            return None
# Create the shared Predictor at import time: its constructor downloads any
# missing weights and loads every model before the UI is built.
predictor = Predictor()

# Accent colour used both by the custom buttons and as the theme's primary
# c400 shade, defined once so the two cannot drift apart.
ACCENT_COLOR = "#FB923C"

# Custom CSS for the buttons selected by their elem_id ("swap-button",
# "download-button"). Braces are doubled because this is an f-string.
css = f"""
/* "Swap Faces" and "Download image (JPG)" buttons: orange background, white text */
button#swap-button,
button#download-button {{
    background-color: {ACCENT_COLOR} !important;
    color: white !important;
}}
/* Add further custom styles here as needed */
"""

# Soft theme with an orange primary palette and zinc secondary/neutral hues.
demo_theme = gr.themes.Soft(
    primary_hue=gr.themes.Color(
        c50="#FFF7ED",
        c100="#FFEDD5",
        c200="#FED7AA",
        c300="#FDBA74",
        c400=ACCENT_COLOR,
        c500="#F97316",
        c600="#EA580C",
        c700="#C2410C",
        c800="#9A3412",
        c900="#7C2D12",
        c950="#431407",
    ),
    secondary_hue="zinc",
    neutral_hue="zinc",
    font=("Pretendard", "sans-serif"),
)
# JPG export: re-encode the current result image as a downloadable JPEG.
def save_as_jpg(file_path):
    """Save the image at *file_path* as a new temporary JPEG file.

    Args:
        file_path: Path of the image to convert. ``None`` (no result yet)
            is accepted and yields ``None``.

    Returns:
        Path of the newly written ``.jpg`` file, or ``None`` when there is no
        input or the conversion fails (the error is printed).
    """
    if file_path is None:
        return None
    try:
        # The context manager closes the source file handle after conversion.
        with Image.open(file_path) as img:
            out = img
            # JPEG cannot encode alpha or palette images (e.g. RGBA / P PNGs);
            # saving those directly raises, so flatten them to RGB first.
            if out.mode not in ("RGB", "L", "CMYK"):
                out = out.convert("RGB")
            with tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) as tmp:
                out.save(tmp, format="JPEG")
                return tmp.name
    except Exception as error:
        print(f"Error saving as JPG: {error}")
        return None
# Reset handler: empties the target image, swap image and result image.
def clear_all():
    """Return a fresh list of three ``None`` values, one per cleared component."""
    return [None] * 3
# Gradio Interface ๊ตฌ์ฑ | |
with gr.Blocks(theme=demo_theme, css=css) as demo: | |
with gr.Row(): | |
# ์ผ์ชฝ ์น์ : ์ ๋ ฅ | |
with gr.Column(scale=1): | |
target_image = gr.Image( | |
type="filepath", | |
label="์ผ๊ตด์ ๋ณ๊ฒฝํ ์ด๋ฏธ์ง" | |
) | |
swap_image = gr.Image( | |
type="filepath", | |
label="๊ต์ฒดํ ์ผ๊ตด" | |
) | |
swap_button = gr.Button("์ผ๊ตด ๊ต์ฒด", elem_id="swap-button") | |
clear_button = gr.Button("๋ฆฌ์ ํ๊ธฐ") | |
# ์ค๋ฅธ์ชฝ ์น์ : ์ถ๋ ฅ | |
with gr.Column(scale=1): | |
result_image = gr.Image( | |
type="filepath", | |
label="๊ฒฐ๊ณผ ์ด๋ฏธ์ง" | |
) | |
download_jpg_button = gr.Button("JPG๋ก ๋ณํํ๊ธฐ", elem_id="download-button") | |
download_file = gr.File(label="JPG ์ด๋ฏธ์ง ๋ค์ด๋ฐ๊ธฐ") | |
# ๋ฒํผ ํด๋ฆญ ์ ์์ธก ํจ์ ํธ์ถ | |
swap_button.click( | |
fn=predictor.predict, | |
inputs=[target_image, swap_image], | |
outputs=result_image | |
) | |
# ๋ฆฌ์ ํ๊ธฐ ๋ฒํผ ํด๋ฆญ ์ ์ ๋ ฅ ๋ฐ ์ถ๋ ฅ ์ด๋ฏธ์ง ์ด๊ธฐํ | |
clear_button.click( | |
fn=clear_all, | |
inputs=None, | |
outputs=[target_image, swap_image, result_image] | |
) | |
# JPG ๋ค์ด๋ก๋ ๋ฒํผ ํด๋ฆญ ์ ํ์ผ ์์ฑ | |
download_jpg_button.click( | |
fn=save_as_jpg, | |
inputs=result_image, | |
outputs=download_file | |
) | |
# Gradio Interface ์คํ | |
demo.launch() |