import os
import subprocess
import sys


def _pip_install(*packages):
    """Install pinned packages into the running interpreter's environment.

    ``sys.executable -m pip`` guarantees the packages land in the environment
    this script runs in; ``check=True`` makes a failed install stop start-up
    with a clear error instead of a confusing ImportError further down.
    """
    subprocess.run([sys.executable, "-m", "pip", "install", *packages], check=True)


# These pins must be installed before gradio / dlib are imported below.
_pip_install("gradio==3.50")
_pip_install("dlib==19.24.2")

#############################################
import torch

print(f"Is CUDA available: {torch.cuda.is_available()}")
# get_device_name() raises when no GPU is present, so only query it if one is.
if torch.cuda.is_available():
    print(f"CUDA device: {torch.cuda.get_device_name(torch.cuda.current_device())}")
###################################################

import functools
import random
import tempfile
from argparse import Namespace
import pprint
import numpy as np
from PIL import Image
import torch
import torchvision.transforms as transforms
import cv2
import dlib
import matplotlib.pyplot as plt
import gradio as gr  # Importing Gradio as gr
import tensorflow as tf
from tensorflow.keras.models import load_model
from tensorflow.keras.losses import MeanSquaredError
from tensorflow.keras.preprocessing.image import img_to_array
from huggingface_hub import hf_hub_download, login
from datasets.augmentations import AgeTransformer
from utils.common import tensor2im
from models.psp import pSp

# Huggingface login -- only when a token is configured; login(token=None)
# would otherwise try to prompt for one interactively.
_hf_token = os.getenv("TOKENKEY")
if _hf_token:
    login(token=_hf_token)

# Download models from Huggingface
age_prototxt = hf_hub_download(repo_id="AshanGimhana/Age_Detection_caffe", filename="age.prototxt")
caffe_model = hf_hub_download(repo_id="AshanGimhana/Age_Detection_caffe", filename="dex_imdb_wiki.caffemodel")
sam_ffhq_aging = hf_hub_download(repo_id="AshanGimhana/Face_Agin_model", filename="sam_ffhq_aging.pt")

# Age prediction model setup (used by predict_age below)
age_net = cv2.dnn.readNetFromCaffe(age_prototxt, caffe_model)

# Face detection and landmarks predictor setup
detector = dlib.get_frontal_face_detector()
predictor = dlib.shape_predictor("shape_predictor_68_face_landmarks.dat")

# Load the pretrained aging model
EXPERIMENT_TYPE = 'ffhq_aging'
EXPERIMENT_DATA_ARGS = {
    "ffhq_aging": {
        "model_path": sam_ffhq_aging,
        "transform": transforms.Compose([
            transforms.Resize((256, 256)),
            transforms.ToTensor(),
            transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
        ]),
    }
}
EXPERIMENT_ARGS = EXPERIMENT_DATA_ARGS[EXPERIMENT_TYPE]
model_path = EXPERIMENT_ARGS['model_path']
ckpt = torch.load(model_path, map_location='cpu')
opts = ckpt['opts']
pprint.pprint(opts)
opts['checkpoint_path'] = model_path
opts = Namespace(**opts)
net = pSp(opts)
net.eval()
net.cuda()
print('Model successfully loaded!')


def check_image_quality(image, *, sharpness_threshold=70, exposure_fraction=0.5, tail_bins=5):
    """Return True if *image* is exposed and sharp enough to process.

    Args:
        image: PIL image as supplied by the Gradio input.
        sharpness_threshold: minimum variance of the Laplacian; anything
            below is treated as too blurry.
        exposure_fraction: the image is rejected when more than this fraction
            of its pixels fall in the darkest (or brightest) ``tail_bins``
            grey levels.
        tail_bins: number of grey levels (>= 1) at each end of the 0-255
            range counted as "dark" / "bright".

    Returns:
        False if the image is under-exposed, over-exposed or too blurry
        (or empty), True otherwise.
    """
    gray_image = np.array(image.convert("L"))

    # Histogram over the fixed 0-255 range, so the tails really are the
    # darkest / brightest grey levels rather than the image's own extremes.
    hist, _ = np.histogram(gray_image, bins=256, range=(0, 256))
    total = hist.sum()
    if total == 0:
        return False  # empty image: nothing to judge
    low_exposure = hist[:tail_bins].sum() > exposure_fraction * total
    high_exposure = hist[-tail_bins:].sum() > exposure_fraction * total

    # Sharpness as the variance of the Laplacian response.
    sharpness = cv2.Laplacian(np.array(image), cv2.CV_64F).var()
    low_sharpness = sharpness < sharpness_threshold

    return not (low_exposure or high_exposure or low_sharpness)


def _to_gray(image):
    """Convert a PIL image or an RGB(A)/greyscale array to a uint8 grey array."""
    arr = np.asarray(image)
    if arr.ndim == 2:
        return np.ascontiguousarray(arr)
    # Inputs here are RGB (PIL images / tensor2im output), not BGR; any alpha
    # channel is dropped.
    return cv2.cvtColor(np.ascontiguousarray(arr[..., :3]), cv2.COLOR_RGB2GRAY)


# Functions for face and mouth region
def get_face_region(image):
    """Return the first dlib face rectangle detected in *image*, or None."""
    faces = detector(_to_gray(image))
    return faces[0] if len(faces) > 0 else None


def get_mouth_region(image):
    """Return the mouth landmarks (points 48-67) of the first detected face.

    Returns:
        An int32 array of shape (20, 2) holding (x, y) pixel coordinates, or
        None when no face is detected.
    """
    gray = _to_gray(image)
    faces = detector(gray)
    if len(faces) == 0:
        return None
    landmarks = predictor(gray, faces[0])
    mouth_points = [(landmarks.part(i).x, landmarks.part(i).y) for i in range(48, 68)]
    return np.array(mouth_points, np.int32)


# Function to predict age
def predict_age(image, *, face_margin=0.4):
    """Estimate the age of the person in *image* with the Caffe ``age_net``.

    The first detected face (plus a relative margin) is cropped out; if no
    face is found the whole image is used.

    Args:
        image: PIL image.
        face_margin: fraction of the face width/height added on each side of
            the detected face box before cropping.

    Returns:
        The predicted age as an int (expected value over the network's
        per-age output scores).
    """
    rgb = np.array(image.convert("RGB"))

    face = get_face_region(rgb)
    if face is not None:
        img_h, img_w = rgb.shape[:2]
        mx = int(face_margin * face.width())
        my = int(face_margin * face.height())
        x0, y0 = max(face.left() - mx, 0), max(face.top() - my, 0)
        x1, y1 = min(face.right() + mx, img_w), min(face.bottom() + my, img_h)
        crop = rgb[y0:y1, x0:x1]
        if crop.size:
            rgb = crop

    # NOTE(review): assumes the DEX network takes a 224x224 BGR image without
    # mean subtraction and outputs one score per age (index == age in years);
    # confirm against age.prototxt.
    bgr = cv2.cvtColor(rgb, cv2.COLOR_RGB2BGR)
    blob = cv2.dnn.blobFromImage(bgr, scalefactor=1.0, size=(224, 224))
    age_net.setInput(blob)
    probs = age_net.forward().reshape(-1)
    age = float(np.dot(probs, np.arange(probs.size)))
    return int(round(age))


# Function for color correction
def color_correct(source, target, eps=1e-6):
    """Match the per-channel mean and std of *source* to those of *target*.

    Args:
        source: HxWxC image whose colours are adjusted.
        target: image whose per-channel statistics are copied.
        eps: lower bound on the source std, so constant channels do not
            divide by zero.

    Returns:
        The corrected source as uint8, clipped to [0, 255].
    """
    source = source.astype(np.float64)
    target = target.astype(np.float64)
    mean_src = source.mean(axis=(0, 1))
    std_src = source.std(axis=(0, 1))
    mean_tgt = target.mean(axis=(0, 1))
    std_tgt = target.std(axis=(0, 1))
    src_normalized = (source - mean_src) / np.maximum(std_src, eps)
    src_corrected = (src_normalized * std_tgt) + mean_tgt
    return np.clip(src_corrected, 0, 255).astype(np.uint8)


def _clip_rect(rect, shape):
    """Clip an (x, y, w, h) rectangle to an image of the given shape."""
    x, y, w, h = rect
    x0, y0 = max(x, 0), max(y, 0)
    x1, y1 = min(x + w, shape[1]), min(y + h, shape[0])
    return x0, y0, x1 - x0, y1 - y0


# Function to replace teeth
def replace_teeth(temp_image, aged_image):
    """Blend the mouth of *temp_image* onto the mouth of *aged_image*.

    Args:
        temp_image: RGB template image (PIL or array) providing the teeth.
        aged_image: RGB image (PIL or array) that receives them.

    Returns:
        A uint8 RGB array. This is *aged_image* unchanged if either mouth
        cannot be located or the blend cannot be placed.
    """
    temp_image = np.array(temp_image)
    aged_image = np.array(aged_image)

    temp_mouth = get_mouth_region(temp_image)
    aged_mouth = get_mouth_region(aged_image)
    if temp_mouth is None or aged_mouth is None:
        return aged_image

    # The 20 lip landmarks (outer + inner contour) do not form a convex
    # polygon in point order; fill their convex hull so the mask is well
    # defined.
    temp_mask = np.zeros_like(temp_image)
    cv2.fillConvexPoly(temp_mask, cv2.convexHull(temp_mouth), (255, 255, 255))
    temp_mouth_region = cv2.bitwise_and(temp_image, temp_mask)

    # Landmarks can lie outside the frame; clip so slicing never wraps.
    tx, ty, tw, th = _clip_rect(cv2.boundingRect(temp_mouth), temp_image.shape)
    ax, ay, aw, ah = _clip_rect(cv2.boundingRect(aged_mouth), aged_image.shape)
    if min(tw, th, aw, ah) <= 0:
        return aged_image

    temp_mouth_crop = temp_mouth_region[ty:ty + th, tx:tx + tw]
    temp_mask_crop = temp_mask[ty:ty + th, tx:tx + tw]
    temp_mouth_crop_resized = cv2.resize(temp_mouth_crop, (aw, ah))
    temp_mask_crop_resized = cv2.resize(temp_mask_crop, (aw, ah))
    aged_mouth_crop = aged_image[ay:ay + ah, ax:ax + aw]

    temp_mouth_crop_resized = color_correct(temp_mouth_crop_resized, aged_mouth_crop)

    center = (ax + aw // 2, ay + ah // 2)
    try:
        return cv2.seamlessClone(
            temp_mouth_crop_resized, aged_image, temp_mask_crop_resized, center, cv2.NORMAL_CLONE
        )
    except cv2.error:
        # OpenCV can reject some placements (e.g. a patch touching the image
        # border); fall back to the un-edited aged image.
        return aged_image


# Function to run alignment
def run_alignment(image):
    """Align the face in *image* with ``align_face``, via a private temp file.

    A unique temporary file per call avoids concurrent requests overwriting
    each other's input; the image is converted to RGB because JPEG cannot
    store an alpha channel.
    """
    from scripts.align_all_parallel import align_face

    with tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) as tmp:
        temp_image_path = tmp.name
    try:
        image.convert("RGB").save(temp_image_path)
        return align_face(filepath=temp_image_path, predictor=predictor)
    finally:
        os.remove(temp_image_path)


# Function to apply aging
def apply_aging(image, target_age):
    """Run the pSp/SAM network on *image* conditioned on *target_age*.

    Returns:
        The aged image as a numpy array (converted via ``tensor2im``).
    """
    img_transforms = EXPERIMENT_DATA_ARGS[EXPERIMENT_TYPE]['transform']
    input_image = img_transforms(image)
    age_transformer = AgeTransformer(target_age=target_age)
    with torch.no_grad():
        # Add a batch dimension of 1 for the network.
        input_image_age = age_transformer(input_image.cpu()).to('cuda').unsqueeze(0)
        result_tensor = net(input_image_age.float(), randomize_noise=False, resize=False)[0]
    return np.array(tensor2im(result_tensor))


@functools.lru_cache(maxsize=1)
def _load_teeth_templates():
    """Load (once) the good / bad teeth template images as RGB PIL images."""
    def _load(path):
        # convert() returns a fully loaded copy, so the file can be closed.
        with Image.open(path) as im:
            return im.convert("RGB")

    good = [_load(f"good_teeth/G{i}.JPG") for i in range(1, 4)]
    bad = [_load(f"bad_teeth/B{i}.jpeg") for i in range(1, 5)]
    return good, bad


# Function to process the image
def process_image(uploaded_image):
    """Predict age, age the face by 5 years, and produce good/bad-teeth variants.

    Returns:
        (aged_image_good_teeth, aged_image_bad_teeth, predicted_age, target_age)
    """
    temp_images_good, temp_images_bad = _load_teeth_templates()

    # Predicting the age
    predicted_age = predict_age(uploaded_image)
    target_age = predicted_age + 5

    # Aligning the face in the uploaded image, then applying the aging effect
    aligned_image = run_alignment(uploaded_image)
    aged_image = apply_aging(aligned_image, target_age=target_age)

    # Randomly selecting teeth images
    good_teeth_image = random.choice(temp_images_good)
    bad_teeth_image = random.choice(temp_images_bad)

    # Replacing teeth in aged image
    aged_image_good_teeth = replace_teeth(good_teeth_image, aged_image)
    aged_image_bad_teeth = replace_teeth(bad_teeth_image, aged_image)

    return aged_image_good_teeth, aged_image_bad_teeth, predicted_age, target_age


# Gradio Interface
def show_results(uploaded_image):
    """Gradio callback: validate the upload, then run the aging pipeline.

    Returns:
        (good-teeth image, bad-teeth image, status text); the images are None
        when the upload is missing, rejected, or contains no detectable face.
    """
    if uploaded_image is None:
        return None, None, "Please upload an image."

    # Perform quality check
    if not check_image_quality(uploaded_image):
        return None, None, "Not_Allowed"

    # Alignment and landmarking need a face; fail early with a clear message.
    if get_face_region(uploaded_image) is None:
        return None, None, "No face detected in the uploaded image."

    aged_image_good_teeth, aged_image_bad_teeth, predicted_age, target_age = process_image(uploaded_image)
    return aged_image_good_teeth, aged_image_bad_teeth, f"Predicted Age: {predicted_age}, Target Age: {target_age}"


iface = gr.Interface(
    fn=show_results,
    inputs=gr.Image(type="pil"),
    outputs=[gr.Image(type="pil"), gr.Image(type="pil"), gr.Textbox()],
    title="Aging Effect with Teeth Replacement",
    description="Upload an image to apply an aging effect. The application will generate two results: one with good teeth and one with bad teeth.",
)
iface.launch()