|
import os
import subprocess
import sys

# Pin Gradio before it is imported further down. pip is run through the
# interpreter executing this script (a bare `pip` on PATH may belong to a
# different Python installation), and without a shell. The install stays
# best-effort: a failure is reported instead of being silently ignored, and
# the app then continues with whatever Gradio is already installed.
_pip_result = subprocess.run(
    [sys.executable, "-m", "pip", "install", "gradio==3.50"],
    check=False,
)
if _pip_result.returncode != 0:
    print(f"Warning: 'pip install gradio==3.50' exited with code {_pip_result.returncode}")
|
from argparse import Namespace
import pprint

import numpy as np
from PIL import Image
import torch
import torchvision.transforms as transforms
import cv2
import dlibs.dlib
from dlibs import dlib  # `import dlibs.dlib` only binds `dlibs`; the code below uses `dlib`
import matplotlib.pyplot as plt
import gradio as gr
from tensorflow.keras.preprocessing.image import img_to_array
from huggingface_hub import hf_hub_download, login

from datasets.augmentations import AgeTransformer
from utils.common import tensor2im
from models.psp import pSp
|
|
|
|
|
# Authenticate with the Hugging Face Hub before downloading the model files
# below (presumably required if those repos are private). The token is taken
# from the TOKENKEY environment variable (e.g. a Space secret). The original
# referenced an undefined name `TOKENKEY`, which raised NameError.
_hf_token = os.environ.get("TOKENKEY")
if _hf_token:
    login(token=_hf_token)
else:
    print("TOKENKEY environment variable not set; skipping Hugging Face login.")
|
|
|
|
|
# Fetch the model files from the Hugging Face Hub; each call returns a local
# file path.
_AGE_REPO_ID = "AshanGimhana/Age_Detection_caffe"
age_prototxt = hf_hub_download(repo_id=_AGE_REPO_ID, filename="age.prototxt")
caffe_model = hf_hub_download(repo_id=_AGE_REPO_ID, filename="dex_imdb_wiki.caffemodel")
sam_ffhq_aging = hf_hub_download(repo_id="AshanGimhana/Face_Agin_model", filename="sam_ffhq_aging.pt")

# Caffe network queried by predict_age().
age_net = cv2.dnn.readNetFromCaffe(age_prototxt, caffe_model)

# dlib face detector and 68-point landmark model (loaded from the current
# working directory), used for face/mouth localisation and face alignment.
detector = dlib.get_frontal_face_detector()
predictor = dlib.shape_predictor("shape_predictor_68_face_landmarks.dat")
|
|
|
|
|
# Inference settings keyed by experiment name; only 'ffhq_aging' is defined.
EXPERIMENT_TYPE = 'ffhq_aging'

# Input pipeline for the aging network: resize to 256x256, convert to a
# tensor, then normalise every channel with mean 0.5 and std 0.5.
_FFHQ_AGING_TRANSFORM = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
])

EXPERIMENT_DATA_ARGS = {
    "ffhq_aging": {
        "model_path": sam_ffhq_aging,
        "transform": _FFHQ_AGING_TRANSFORM,
    },
}

# Settings for the active experiment.
EXPERIMENT_ARGS = EXPERIMENT_DATA_ARGS[EXPERIMENT_TYPE]
|
# Build the aging network from the downloaded checkpoint, on the CPU.
# NOTE(review): torch.load unpickles the file, which can execute arbitrary
# code; model_path must only ever point at a trusted checkpoint.
model_path = EXPERIMENT_ARGS['model_path']
ckpt = torch.load(model_path, map_location='cpu')

# The checkpoint stores its options as a dict: print them, record the
# checkpoint location in them, and pass them to pSp as an attribute namespace.
_checkpoint_opts = ckpt['opts']
pprint.pprint(_checkpoint_opts)
_checkpoint_opts['checkpoint_path'] = model_path
opts = Namespace(**_checkpoint_opts)

net = pSp(opts)
net.eval()
net.cpu()

print('Model successfully loaded!')
|
|
|
def get_face_region(image):
    """Return the first face rectangle found by ``detector`` in *image*, or None.

    *image* may be a PIL image or a NumPy array. Colour input is treated as
    RGB (arrays converted from PIL are RGB, and ``predict_age`` makes the same
    assumption), so RGB luminance weights are used. The original used
    ``COLOR_BGR2GRAY``, which swaps the red and blue weights. Single-channel
    input is passed through unchanged instead of crashing ``cvtColor``.
    """
    pixels = np.array(image)
    gray = pixels if pixels.ndim == 2 else cv2.cvtColor(pixels, cv2.COLOR_RGB2GRAY)
    faces = detector(gray)
    return faces[0] if len(faces) > 0 else None
|
|
|
def get_mouth_region(image):
    """Return the mouth landmarks (indices 48-67) of the first detected face.

    *image* may be a PIL image or a NumPy array. Colour input is treated as
    RGB, consistent with ``predict_age``; single-channel input is used as-is.

    Returns:
        A ``(20, 2)`` int32 array of ``(x, y)`` pixel coordinates from
        ``predictor``, or ``None`` when ``detector`` finds no face.
    """
    pixels = np.array(image)
    gray = pixels if pixels.ndim == 2 else cv2.cvtColor(pixels, cv2.COLOR_RGB2GRAY)
    faces = detector(gray)
    if len(faces) == 0:
        return None

    # Only the first face is used (the original loop returned on its first iteration).
    landmarks = predictor(gray, faces[0])
    mouth_points = [(landmarks.part(i).x, landmarks.part(i).y) for i in range(48, 68)]
    return np.array(mouth_points, np.int32)
|
|
|
def predict_age(image):
    """Estimate an age for *image* with the Caffe network ``age_net``.

    The image is treated as RGB and converted to BGR. It is packed into a
    224x224 blob with per-channel means (104, 177, 123) subtracted and no
    further channel swap, then run through the network. The first output row
    is weighted by the ages 0..100 (the expected age, assuming the outputs are
    probabilities), and the sum is truncated to an ``int``.
    """
    bgr_pixels = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
    network_input = cv2.dnn.blobFromImage(
        bgr_pixels,
        scalefactor=1.0,
        size=(224, 224),
        mean=(104.0, 177.0, 123.0),
        swapRB=False,
    )
    age_net.setInput(network_input)
    scores = age_net.forward()[0]

    age_axis = np.arange(0, 101)
    weighted_age = np.dot(scores, age_axis).flatten()[0]
    return int(weighted_age)
|
|
|
def color_correct(source, target):
    """Match the per-channel mean and standard deviation of *source* to *target*.

    Statistics are taken over the first two axes (rows and columns), so for
    ``(H, W, C)`` images they are per channel. The result is clipped to
    ``[0, 255]`` and returned as ``uint8``.

    A source channel with zero spread (e.g. a flat colour patch) cannot be
    normalised. Dividing by its zero std produced NaN and therefore garbage
    pixel values, so such channels are mapped to the target mean instead.
    """
    source = np.asarray(source, dtype=np.float64)
    target = np.asarray(target, dtype=np.float64)

    mean_src = source.mean(axis=(0, 1))
    std_src = source.std(axis=(0, 1))
    mean_tgt = target.mean(axis=(0, 1))
    std_tgt = target.std(axis=(0, 1))

    # For a constant channel source == mean, so dividing by 1 yields 0, which
    # maps to mean_tgt below.
    safe_std_src = np.where(std_src > 0, std_src, 1.0)
    src_normalized = (source - mean_src) / safe_std_src
    src_corrected = src_normalized * std_tgt + mean_tgt

    return np.clip(src_corrected, 0, 255).astype(np.uint8)
|
|
|
def _crop_rect(image, rect):
    """Return the ``(x, y, w, h)`` rectangle *rect* of *image* as a slice view."""
    x, y, w, h = rect
    return image[y:y + h, x:x + w]


def _clip_points(points, shape):
    """Clamp an ``(N, 2)`` array of ``(x, y)`` points into an image of *shape* ``(H, W, ...)``."""
    height, width = shape[:2]
    clipped = points.copy()
    clipped[:, 0] = np.clip(clipped[:, 0], 0, width - 1)
    clipped[:, 1] = np.clip(clipped[:, 1], 0, height - 1)
    return clipped


def replace_teeth(temp_image, aged_image):
    """Blend the mouth of *temp_image* into the mouth of *aged_image*.

    The mouth area of the template is masked and cropped, then resized to the
    aged face's mouth box and colour-matched to it with ``color_correct``.
    Finally it is Poisson-blended into the aged image with
    ``cv2.seamlessClone``.

    Args:
        temp_image: template face (PIL image or array) providing the teeth.
        aged_image: aged face (PIL image or array) to modify.

    Returns:
        The blended image as a NumPy array. If either mouth cannot be located,
        or the blend cannot be placed, the aged image is returned unchanged as
        a NumPy array.
    """
    temp_image = np.array(temp_image)
    aged_image = np.array(aged_image)

    temp_mouth = get_mouth_region(temp_image)
    aged_mouth = get_mouth_region(aged_image)
    if temp_mouth is None or aged_mouth is None:
        return aged_image

    # Landmarks can lie outside the frame. Negative coordinates would make the
    # slices below wrap around, so clamp them to the image first.
    temp_mouth = _clip_points(temp_mouth, temp_image.shape)
    aged_mouth = _clip_points(aged_mouth, aged_image.shape)

    # Points 48-67 are the outer lip contour followed by the inner one. Joined
    # as a single polygon they self-intersect, which fillConvexPoly does not
    # handle; their convex hull is a proper mouth outline.
    temp_mask = np.zeros_like(temp_image)
    cv2.fillConvexPoly(temp_mask, cv2.convexHull(temp_mouth), (255, 255, 255))
    temp_mouth_region = cv2.bitwise_and(temp_image, temp_mask)

    temp_rect = cv2.boundingRect(temp_mouth)
    aged_rect = cv2.boundingRect(aged_mouth)
    aged_size = (aged_rect[2], aged_rect[3])  # cv2.resize takes (width, height)

    temp_mouth_crop = cv2.resize(_crop_rect(temp_mouth_region, temp_rect), aged_size)
    temp_mask_crop = cv2.resize(_crop_rect(temp_mask, temp_rect), aged_size)
    temp_mouth_crop = color_correct(temp_mouth_crop, _crop_rect(aged_image, aged_rect))

    center = (aged_rect[0] + aged_rect[2] // 2, aged_rect[1] + aged_rect[3] // 2)
    try:
        return cv2.seamlessClone(temp_mouth_crop, aged_image, temp_mask_crop, center, cv2.NORMAL_CLONE)
    except cv2.error as err:
        # seamlessClone rejects placements that do not fit the destination.
        # Degrade the same way as the "no mouth found" case instead of failing
        # the whole request.
        print(f"replace_teeth: seamlessClone failed ({err}); returning the aged image unchanged")
        return aged_image
|
|
|
def run_alignment(image):
    """Align the face in the PIL *image* with ``align_face`` and return the result.

    ``align_face`` takes a file path, so the image is written to a JPEG inside
    a private temporary directory that is removed afterwards. The original
    used the fixed path ``/tmp/temp_image.jpg``. Overlapping requests could
    overwrite each other's input there, and the path does not exist on every
    platform. The image is converted to RGB first because JPEG cannot store
    an alpha channel (saving an RGBA upload failed).

    Returns:
        Whatever ``align_face`` returns (presumably the aligned PIL image).
    """
    import tempfile

    from scripts.align_all_parallel import align_face

    with tempfile.TemporaryDirectory() as tmp_dir:
        temp_image_path = os.path.join(tmp_dir, "temp_image.jpg")
        image.convert("RGB").save(temp_image_path)
        # NOTE(review): assumes align_face returns an image already loaded into
        # memory (not lazily backed by the temp file, which is deleted on exit).
        aligned_image = align_face(filepath=temp_image_path, predictor=predictor)

    return aligned_image
|
|
|
def apply_aging(image, target_age):
    """Re-age *image* towards *target_age* with the loaded network ``net``.

    The image goes through the active experiment's input transform, then
    through ``AgeTransformer(target_age=...)``. It is batched as a single
    float tensor on the CPU and passed to ``net`` without gradient tracking.
    The first output is converted back with ``tensor2im``.

    Returns:
        The generated image as a NumPy array.
    """
    to_network_input = EXPERIMENT_DATA_ARGS[EXPERIMENT_TYPE]['transform']
    face_tensor = to_network_input(image)
    age_conditioner = AgeTransformer(target_age=target_age)

    with torch.no_grad():
        conditioned = age_conditioner(face_tensor.cpu()).to('cpu')
        batch = torch.stack([conditioned]).to("cpu").float()
        output_tensor = net(batch, randomize_noise=False, resize=False)[0]
        output_image = tensor2im(output_tensor)

    return np.array(output_image)
|
|
|
def _open_rgb(path):
    """Load the image at *path* fully into memory as RGB and close the file."""
    with Image.open(path) as img:
        return img.convert("RGB")


def process_image(uploaded_image):
    """Gradio handler: age the uploaded face and produce two teeth variants.

    Steps: estimate the age, align the face, age it by +5 years, then blend
    one randomly chosen "good teeth" template and one randomly chosen "bad
    teeth" template into the aged face.

    Only the two chosen templates are opened, each fully loaded and closed.
    The original opened all eight on every request and kept their lazily-read
    file handles alive. Templates are converted to RGB so grayscale or
    alpha-channel files do not break the teeth replacement.

    Returns:
        ``(aged_with_good_teeth, aged_with_bad_teeth)``, or ``(None, None)``
        when no image was provided.
    """
    if uploaded_image is None:
        # Gradio passes None when Submit is pressed without an upload.
        return None, None

    good_teeth_paths = [f"good_teeth/G{i}.JPG" for i in range(1, 5)]
    bad_teeth_paths = [f"bad_teeth/B{i}.jpeg" for i in range(1, 5)]

    predicted_age = predict_age(uploaded_image)
    target_age = predicted_age + 5

    aligned_image = run_alignment(uploaded_image)
    aged_image = apply_aging(aligned_image, target_age=target_age)

    good_teeth_image = _open_rgb(good_teeth_paths[np.random.randint(0, len(good_teeth_paths))])
    bad_teeth_image = _open_rgb(bad_teeth_paths[np.random.randint(0, len(bad_teeth_paths))])

    aged_image_good_teeth = replace_teeth(good_teeth_image, aged_image)
    aged_image_bad_teeth = replace_teeth(bad_teeth_image, aged_image)

    return aged_image_good_teeth, aged_image_bad_teeth
|
|
|
# Web UI: one image in, two images out (same order as process_image returns).
# The outputs are labelled so users can tell the two promised results apart.
iface = gr.Interface(
    fn=process_image,
    inputs=gr.Image(type="pil"),
    outputs=[
        gr.Image(type="pil", label="Aged (good teeth)"),
        gr.Image(type="pil", label="Aged (bad teeth)"),
    ],
    title="Aging Effect with Teeth Replacement",
    description="Upload an image to apply an aging effect. The application will generate two results: one with good teeth and one with bad teeth.",
)

# Start the blocking server only when this file is run as a script.
if __name__ == "__main__":
    iface.launch(debug=True)
|
|