|
import os |
|
import subprocess |
|
|
|
import sys

# Pin runtime dependencies at start-up. Running pip as ``sys.executable -m pip``
# installs into the interpreter that is actually executing this script. As with
# the previous os.system() calls, a failed install is reported but does not stop
# start-up.
for _requirement in ("gradio==3.50", "dlib==19.24.2"):
    _pip = subprocess.run([sys.executable, "-m", "pip", "install", _requirement], check=False)
    if _pip.returncode != 0:
        print(f"Warning: 'pip install {_requirement}' exited with code {_pip.returncode}")

import torch

print(f"Is CUDA available: {torch.cuda.is_available()}")
# Querying the device name requires a CUDA device; only do it when one exists so
# this diagnostic output cannot itself crash the script.
if torch.cuda.is_available():
    print(f"CUDA device: {torch.cuda.get_device_name(torch.cuda.current_device())}")
|
|
|
|
|
|
|
|
|
from argparse import Namespace |
|
import pprint |
|
import numpy as np |
|
from PIL import Image |
|
import torch |
|
import torchvision.transforms as transforms |
|
import cv2 |
|
import dlib |
|
import matplotlib.pyplot as plt |
|
import gradio as gr |
|
import tensorflow as tf |
|
from tensorflow.keras.models import load_model |
|
from tensorflow.keras.losses import MeanSquaredError |
|
from tensorflow.keras.preprocessing.image import img_to_array |
|
from huggingface_hub import hf_hub_download, login |
|
from datasets.augmentations import AgeTransformer |
|
from utils.common import tensor2im |
|
from models.psp import pSp |
|
|
|
|
|
# Authenticate with the Hugging Face Hub using the TOKENKEY environment variable.
# The call is skipped when the variable is unset or empty: huggingface_hub's
# login() treats token=None as a request to prompt for a token interactively,
# which a headless server cannot answer.
_hf_token = os.getenv("TOKENKEY")
if _hf_token:
    login(token=_hf_token)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Age regressor used by predict_age(). torch.load unpickles the stored object, so
# this file must come from a trusted source.
# NOTE(review): no map_location is passed, so tensors are restored onto whatever
# device they were saved from — confirm that matches the serving hardware.
age_calc_model = torch.load('Custom_Age_prediction_model.pth')

# Fetch model files from the Hugging Face Hub; each call returns a local path to
# the downloaded file.
# NOTE(review): age_prototxt and caffe_model are not referenced anywhere else in
# this file — possibly leftovers from an earlier Caffe-based age detector.
age_prototxt = hf_hub_download(repo_id="AshanGimhana/Age_Detection_caffe", filename="age.prototxt")

caffe_model = hf_hub_download(repo_id="AshanGimhana/Age_Detection_caffe", filename="dex_imdb_wiki.caffemodel")

# Checkpoint of the face-aging network; used as model_path in EXPERIMENT_DATA_ARGS.
sam_ffhq_aging = hf_hub_download(repo_id="AshanGimhana/Face_Agin_model", filename="sam_ffhq_aging.pt")

# dlib face detector and landmark predictor (a 68-point model, per its filename),
# the latter read from a file in the working directory. Both are shared by the
# face/mouth helpers and by run_alignment().
detector = dlib.get_frontal_face_detector()

predictor = dlib.shape_predictor("shape_predictor_68_face_landmarks.dat")
|
|
|
|
|
# A single experiment is configured. Its input transform resizes to 256x256,
# converts to a tensor and normalises each channel as (x - 0.5) / 0.5.
EXPERIMENT_TYPE = 'ffhq_aging'
EXPERIMENT_DATA_ARGS = {
    "ffhq_aging": dict(
        model_path=sam_ffhq_aging,
        transform=transforms.Compose(
            [
                transforms.Resize((256, 256)),
                transforms.ToTensor(),
                transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
            ]
        ),
    ),
}
EXPERIMENT_ARGS = EXPERIMENT_DATA_ARGS[EXPERIMENT_TYPE]
model_path = EXPERIMENT_ARGS['model_path']

# The checkpoint carries its own training options; print them, point them back at
# the checkpoint file, and build the network from the resulting Namespace.
ckpt = torch.load(model_path, map_location='cpu')
opts = ckpt['opts']
pprint.pprint(opts)
opts['checkpoint_path'] = model_path
opts = Namespace(**opts)

# Inference-only network on the GPU.
net = pSp(opts)
net.eval()
net.cuda()

print('Model successfully loaded!')
|
|
|
def check_image_quality(image):
    """Decide whether *image* is usable for the aging pipeline.

    The image is rejected when any of these holds:

    * under-exposed: more than half of its pixels fall in the five darkest
      8-bit grey levels (0-4);
    * over-exposed: more than half fall in the five brightest levels (251-255);
    * blurry: the variance of its Laplacian is below 70.

    Args:
        image: a PIL image (anything providing ``convert("L")`` and convertible
            with ``np.array``).

    Returns:
        True if the image passes all checks, False otherwise.
    """
    gray_image = np.array(image.convert("L"))

    # One bin per 8-bit grey level; minlength keeps the array 256 long even when
    # the brightest levels are absent, so the [:5] / [-5:] slices always refer to
    # the absolute darkest / brightest levels.
    hist = np.bincount(gray_image.ravel(), minlength=256)
    half_of_pixels = 0.5 * hist.sum()
    low_exposure = hist[:5].sum() > half_of_pixels
    high_exposure = hist[-5:].sum() > half_of_pixels

    # Variance of the Laplacian is a common focus measure: low means few edges.
    sharpness = cv2.Laplacian(np.array(image), cv2.CV_64F).var()
    low_sharpness = sharpness < 70

    return not (low_exposure or high_exposure or low_sharpness)
|
|
|
|
|
def get_face_region(image):
    """Return the first face rectangle found by the module-level dlib ``detector``.

    Args:
        image: a PIL image or NumPy array in RGB(A) channel order (the order PIL
            produces), or an already greyscale 2-D array.

    Returns:
        The first detection from ``detector``, or None when no face is found.
    """
    frame = np.array(image)
    # PIL-derived arrays are RGB, not OpenCV's BGR; RGB2GRAY weights the channels
    # correctly (it also accepts a 4-channel RGBA array).
    gray = frame if frame.ndim == 2 else cv2.cvtColor(frame, cv2.COLOR_RGB2GRAY)
    faces = detector(gray)
    return faces[0] if len(faces) > 0 else None
|
|
|
def get_mouth_region(image):
    """Return the mouth landmarks (indices 48-67) of the first detected face.

    Args:
        image: a PIL image or NumPy array in RGB(A) channel order (the order PIL
            produces), or an already greyscale 2-D array.

    Returns:
        A (20, 2) int32 array of (x, y) pixel coordinates for landmarks 48-67 of
        the first face found by ``detector``, or None when no face is found.
    """
    frame = np.array(image)
    # PIL-derived arrays are RGB, not OpenCV's BGR.
    gray = frame if frame.ndim == 2 else cv2.cvtColor(frame, cv2.COLOR_RGB2GRAY)
    faces = detector(gray)
    if len(faces) == 0:
        return None

    landmarks = predictor(gray, faces[0])
    mouth_points = [(landmarks.part(i).x, landmarks.part(i).y) for i in range(48, 68)]
    return np.array(mouth_points, np.int32)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _to_gray_array(image):
    """Return *image* as a 2-D greyscale array.

    Accepts a file path (read with ``cv2.imread``), a PIL-style image providing
    ``convert("L")``, or a NumPy array that is either already greyscale or in
    RGB(A) channel order.

    Raises:
        ValueError: if *image* is a path OpenCV cannot read.
    """
    if isinstance(image, str):
        gray = cv2.imread(image, cv2.IMREAD_GRAYSCALE)
        if gray is None:
            raise ValueError(f"Could not read image file: {image!r}")
        return gray
    if hasattr(image, "convert"):
        return np.array(image.convert("L"))
    array = np.asarray(image)
    return array if array.ndim == 2 else cv2.cvtColor(array, cv2.COLOR_RGB2GRAY)


def predict_age(image):
    """Estimate the age of the face in *image* with ``age_calc_model``.

    The picture is converted to greyscale, resized to 64x64, scaled by 1/255 and
    fed to the model as a float32 tensor of shape (1, 1, 64, 64).

    Args:
        image: a PIL image (what ``process_image`` passes), a NumPy array, or a
            path to an image file.

    Returns:
        The model's single output value truncated to ``int``.

    Raises:
        ValueError: if *image* is a path OpenCV cannot read.
    """
    age_calc_model.eval()

    gray = cv2.resize(_to_gray_array(image), (64, 64))
    # Add batch and channel axes: (64, 64) -> (1, 1, 64, 64).
    batch = torch.tensor(gray / 255.0, dtype=torch.float32)[None, None]

    # Put the input on the same device as the model's weights (the previous code
    # referenced an undefined ``device`` global).
    first_param = next(age_calc_model.parameters(), None)
    if first_param is not None:
        batch = batch.to(first_param.device)

    with torch.no_grad():
        predicted_age = age_calc_model(batch)

    return int(predicted_age.item())
|
|
|
|
|
def color_correct(source, target):
    """Shift *source*'s per-channel mean and spread to match *target*'s.

    Each channel of *source* is standardised with its own mean and standard
    deviation (taken over axes 0 and 1), then rescaled with *target*'s statistics.

    Args:
        source: image array to recolour (H x W x C, or H x W).
        target: image array whose channel statistics are copied; its spatial
            size may differ from *source*'s.

    Returns:
        A uint8 array shaped like *source*, clipped to [0, 255]. A channel of
        *source* with zero spread (e.g. a flat or black crop) maps entirely onto
        the target mean instead of producing NaNs.
    """
    source = np.asarray(source, dtype=np.float64)
    target = np.asarray(target, dtype=np.float64)

    mean_src = source.mean(axis=(0, 1))
    std_src = source.std(axis=(0, 1))
    mean_tgt = target.mean(axis=(0, 1))
    std_tgt = target.std(axis=(0, 1))

    # Avoid 0/0: for a flat channel (source - mean) is already 0 everywhere, so
    # dividing by 1 leaves it at 0 and the result is just the target mean.
    std_src = np.where(std_src == 0, 1.0, std_src)

    corrected = (source - mean_src) / std_src * std_tgt + mean_tgt
    return np.clip(corrected, 0, 255).astype(np.uint8)
|
|
|
|
|
def _clip_rect(rect, shape):
    """Clip an (x, y, w, h) rectangle to the bounds of an array with *shape*."""
    x, y, w, h = rect
    height, width = shape[:2]
    x0, y0 = max(x, 0), max(y, 0)
    x1, y1 = min(x + w, width), min(y + h, height)
    return x0, y0, max(x1 - x0, 0), max(y1 - y0, 0)


def replace_teeth(temp_image, aged_image):
    """Blend the mouth of a teeth template into an aged face.

    Mouth landmarks are located in both images with ``get_mouth_region``. The
    template mouth (masked to the convex hull of its landmarks) is cropped and
    resized to the aged face's mouth box. It is then colour-matched to that box
    with ``color_correct`` and blended in with ``cv2.seamlessClone``
    (NORMAL_CLONE).

    Args:
        temp_image: template face (PIL image or array, RGB order).
        aged_image: face to edit (PIL image or array, RGB order).

    Returns:
        The blended image as a NumPy array. *aged_image* is returned unchanged
        (as an array) when either mouth is not found, a mouth box lies entirely
        outside its image, or OpenCV refuses the clone.
    """
    temp_image = np.array(temp_image)
    aged_image = np.array(aged_image)

    temp_mouth = get_mouth_region(temp_image)
    aged_mouth = get_mouth_region(aged_image)
    if temp_mouth is None or aged_mouth is None:
        return aged_image

    # fillConvexPoly needs a convex polygon. The mouth landmarks (presumably the
    # outer then the inner lip contour) are not one in list order, so fill their
    # convex hull instead.
    temp_mask = np.zeros_like(temp_image)
    cv2.fillConvexPoly(temp_mask, cv2.convexHull(temp_mouth), (255, 255, 255))
    temp_mouth_region = cv2.bitwise_and(temp_image, temp_mask)

    # Landmarks can fall outside the frame; clip both boxes so slicing is valid
    # and bail out on an empty box (cv2.resize cannot produce a 0-sized image).
    tx, ty, tw, th = _clip_rect(cv2.boundingRect(temp_mouth), temp_image.shape)
    ax, ay, aw, ah = _clip_rect(cv2.boundingRect(aged_mouth), aged_image.shape)
    if min(tw, th, aw, ah) == 0:
        return aged_image

    temp_mouth_crop = temp_mouth_region[ty:ty + th, tx:tx + tw]
    temp_mask_crop = temp_mask[ty:ty + th, tx:tx + tw]
    aged_mouth_crop = aged_image[ay:ay + ah, ax:ax + aw]

    # Bring the template mouth and its mask to the aged mouth's size, then match colours.
    temp_mouth_crop_resized = cv2.resize(temp_mouth_crop, (aw, ah))
    temp_mask_crop_resized = cv2.resize(temp_mask_crop, (aw, ah))
    temp_mouth_crop_resized = color_correct(temp_mouth_crop_resized, aged_mouth_crop)

    center = (ax + aw // 2, ay + ah // 2)
    try:
        return cv2.seamlessClone(
            temp_mouth_crop_resized, aged_image, temp_mask_crop_resized, center, cv2.NORMAL_CLONE
        )
    except cv2.error:
        # OpenCV raises e.g. when the cloned region would not fit inside
        # aged_image; fall back to the unedited face, as when no mouth is found.
        return aged_image
|
|
|
|
|
def run_alignment(image):
    """Align the face in *image* using ``align_face`` and return its result.

    ``align_face`` takes a file path, so the image is written as a JPEG into a
    private temporary directory. Concurrent requests therefore never share a
    file, and the directory is removed once alignment finishes.

    Args:
        image: a PIL image.

    Returns:
        Whatever ``align_face`` returns for the saved image.
    """
    import tempfile

    from scripts.align_all_parallel import align_face

    with tempfile.TemporaryDirectory() as temp_dir:
        temp_image_path = os.path.join(temp_dir, "temp_image.jpg")
        # JPEG cannot store alpha or palette modes; normalise to RGB before saving.
        image.convert("RGB").save(temp_image_path)
        return align_face(filepath=temp_image_path, predictor=predictor)
|
|
|
|
|
def apply_aging(image, target_age):
    """Run the aging network ``net`` on *image*, conditioned on *target_age*.

    The experiment's transform is applied and an ``AgeTransformer`` attaches the
    target age to the input. The result is moved to CUDA as a batch of one and
    passed through ``net``; the first output is converted with ``tensor2im``.

    Args:
        image: aligned face image accepted by the experiment transform.
        target_age: age handed to ``AgeTransformer``.

    Returns:
        The aged face as a NumPy array.
    """
    transform = EXPERIMENT_DATA_ARGS[EXPERIMENT_TYPE]['transform']
    source = transform(image)
    aging = AgeTransformer(target_age=target_age)

    with torch.no_grad():
        batch = torch.stack([aging(source.cpu()).to('cuda')])
        output = net(batch.float(), randomize_noise=False, resize=False)[0]
        return np.array(tensor2im(output))
|
|
|
|
|
def _load_rgb(path):
    """Fully decode the image at *path* as RGB and release its file handle."""
    with Image.open(path) as img:
        return img.convert("RGB")


def process_image(uploaded_image):
    """Age *uploaded_image* and produce a good-teeth and a bad-teeth variant.

    The pipeline runs in this order:

    1. Predict the current age and target five years older.
    2. Align the face and run the aging network.
    3. Blend in one randomly chosen good-teeth template and one randomly chosen
       bad-teeth template.

    Args:
        uploaded_image: PIL image from the Gradio input.

    Returns:
        (aged_image_good_teeth, aged_image_bad_teeth, predicted_age, target_age)
    """
    good_teeth_paths = [f"good_teeth/G{i}.JPG" for i in range(1, 4)]
    bad_teeth_paths = [f"bad_teeth/B{i}.jpeg" for i in range(1, 5)]

    # Later steps save a JPEG and apply 3-channel transforms; use plain RGB.
    uploaded_image = uploaded_image.convert("RGB")

    predicted_age = predict_age(uploaded_image)
    target_age = predicted_age + 5

    aligned_image = run_alignment(uploaded_image)
    aged_image = apply_aging(aligned_image, target_age=target_age)

    # Open only the two templates actually used, closing each file once decoded.
    good_teeth_image = _load_rgb(good_teeth_paths[np.random.randint(0, len(good_teeth_paths))])
    bad_teeth_image = _load_rgb(bad_teeth_paths[np.random.randint(0, len(bad_teeth_paths))])

    aged_image_good_teeth = replace_teeth(good_teeth_image, aged_image)
    aged_image_bad_teeth = replace_teeth(bad_teeth_image, aged_image)

    return aged_image_good_teeth, aged_image_bad_teeth, predicted_age, target_age
|
|
|
|
|
def show_results(uploaded_image):
    """Gradio callback: validate the upload, then return both aged variants.

    Args:
        uploaded_image: PIL image from the Gradio input, or None when the user
            submitted without uploading anything.

    Returns:
        (good_teeth_image, bad_teeth_image, status_text). Both images are None
        when there is no upload or it fails ``check_image_quality``.
    """
    if uploaded_image is None:
        return None, None, "No image uploaded"
    if not check_image_quality(uploaded_image):
        return None, None, "Not_Allowed"

    good_teeth, bad_teeth, predicted_age, target_age = process_image(uploaded_image)
    return good_teeth, bad_teeth, f"Predicted Age: {predicted_age}, Target Age: {target_age}"
|
|
|
# Web UI: one uploaded image in; two aged images (good / bad teeth) and a status
# line out, all produced by show_results().
iface = gr.Interface(
    fn=show_results,
    inputs=gr.Image(type="pil"),
    outputs=[
        gr.Image(type="pil"),
        gr.Image(type="pil"),
        gr.Textbox(),
    ],
    description="Upload an image to apply an aging effect. The application will generate two results: one with good teeth and one with bad teeth.",
    title="Aging Effect with Teeth Replacement",
)
iface.launch()