import os import subprocess os.system("pip install gradio==3.50") os.system("pip install dlib==19.24.2") os.system("pip install scikit-learn") os.system("pip install scikit-image") os.system("pip install tensorflow==2.11.0") ############################################# import torch print(f"Is CUDA available: {torch.cuda.is_available()}") # True print(f"CUDA device: {torch.cuda.get_device_name(torch.cuda.current_device())}") ################################################### import tensorflow as tf gpus = tf.config.list_physical_devices('GPU') print("Available GPUs TF:", gpus) if gpus: try: # Allow TensorFlow to allocate memory as needed for gpu in gpus: tf.config.experimental.set_memory_growth(gpu, True) except RuntimeError as e: print(e) else: print("No GPUs available.") ################################################### from skimage import exposure from skimage.filters import laplace from argparse import Namespace import pprint import numpy as np from PIL import Image import torch import torchvision.transforms as transforms import cv2 import dlib import matplotlib.pyplot as plt import gradio as gr # Importing Gradio as gr import tensorflow as tf from tensorflow.keras.models import load_model from tensorflow.keras.losses import MeanSquaredError from tensorflow.keras.preprocessing.image import img_to_array from huggingface_hub import hf_hub_download, login from datasets.augmentations import AgeTransformer from utils.common import tensor2im from models.psp import pSp # Huggingface login login(token=os.getenv("TOKENKEY")) ######################################################################## ############## tensorflow model for age calculation ####################### # If 'mse' is a custom function needed, #custom_objects = {'mse': MeanSquaredError()} new_age_model = load_model("age_prediction_modelV2.h5") ######################################################################## ######################################################################## ############## pytorch model for age calculation ####################### #age_calc_model = torch.load('Custom_Age_prediction_model.pth') ######################################################################## # Download models from Huggingface #age_prototxt = hf_hub_download(repo_id="AshanGimhana/Age_Detection_caffe", filename="age.prototxt") #caffe_model = hf_hub_download(repo_id="AshanGimhana/Age_Detection_caffe", filename="dex_imdb_wiki.caffemodel") ######################################################################## ############## caffe model for age calculation ####################### #age_prototxt = hf_hub_download(repo_id="AshanGimhana/Age_Detection_caffe", filename="age.prototxt") #caffe_model = hf_hub_download(repo_id="AshanGimhana/Age_Detection_caffe", filename="dex_imdb_wiki.caffemodel") # Age prediction model setup #age_net = cv2.dnn.readNetFromCaffe(age_prototxt, caffe_model) ######################################################################## ######################################################################################################## # Aging model # sam_ffhq_aging = hf_hub_download(repo_id="AshanGimhana/Face_Agin_model", filename="sam_ffhq_aging.pt") ######################################################################################################## # Face detection and landmarks predictor setup detector = dlib.get_frontal_face_detector() predictor = dlib.shape_predictor("shape_predictor_68_face_landmarks.dat") # Load the pretrained aging model EXPERIMENT_TYPE = 'ffhq_aging' EXPERIMENT_DATA_ARGS = { "ffhq_aging": { "model_path": sam_ffhq_aging, "transform": transforms.Compose([ transforms.Resize((256, 256)), transforms.ToTensor(), transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]) ]) } } EXPERIMENT_ARGS = EXPERIMENT_DATA_ARGS[EXPERIMENT_TYPE] model_path = EXPERIMENT_ARGS['model_path'] ckpt = torch.load(model_path, map_location='cpu') opts = ckpt['opts'] pprint.pprint(opts) opts['checkpoint_path'] = model_path opts = Namespace(**opts) net = pSp(opts) net.eval() net.cuda() print('Model successfully loaded!') ####### Image quality checking func ###################### #def check_image_quality(image): ### Convert the image to grayscale ##gray_image = np.array(image.convert("L")) ### Check for under/over-exposure using histogram #hist = exposure.histogram(gray_image) #low_exposure = hist[0][:5].sum() > 0.5 * hist[0].sum() # Significant pixels in dark range #high_exposure = hist[0][-5:].sum() > 0.5 * hist[0].sum() # Significant pixels in bright range ### Check sharpness using Laplacian variance #sharpness = cv2.Laplacian(np.array(image), cv2.CV_64F).var() #low_sharpness = sharpness < 70 # Threshold for sharpness ### Check overall quality #if low_exposure or high_exposure or low_sharpness: #return False # Image quality is insufficient #return True # Image quality is sufficient ############################################################ # Functions for face and mouth region def get_face_region(image): gray = cv2.cvtColor(np.array(image), cv2.COLOR_BGR2GRAY) faces = detector(gray) if len(faces) > 0: return faces[0] return None def get_mouth_region(image): gray = cv2.cvtColor(np.array(image), cv2.COLOR_BGR2GRAY) faces = detector(gray) for face in faces: landmarks = predictor(gray, face) mouth_points = [(landmarks.part(i).x, landmarks.part(i).y) for i in range(48, 68)] return np.array(mouth_points, np.int32) return None # Function to predict age # old tensorflow function for age predict def predict_age(image): image = np.array(image.resize((64, 64))) image = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY) image = image / 255.0 image = np.expand_dims(image, axis=0) ##### Predict age val = new_age_model.predict(np.array(image)) age = val[0][0] return int(age) #def predict_age(image): #age_calc_model.eval() ##### Load and preprocess the image #image = cv2.imread(image, cv2.IMREAD_GRAYSCALE) # Load as grayscale #image = cv2.resize(image, (64, 64)) # Resize to 64x64 #image = image / 255.0 # Normalize pixel values to [0, 1] #image = np.expand_dims(image, axis=0) # Add batch dimension #image = np.expand_dims(image, axis=0) # Add channel dimension #image = torch.tensor(image, dtype=torch.float32).to(device) # Convert to tensor #image_tensor = torch.tensor(image, dtype=torch.float32) #### Predict age #with torch.no_grad(): #predicted_age = age_calc_model(image_tensor) #return int(predicted_age.item()) # Function for color correction def color_correct(source, target): mean_src = np.mean(source, axis=(0, 1)) std_src = np.std(source, axis=(0, 1)) mean_tgt = np.mean(target, axis=(0, 1)) std_tgt = np.std(target, axis=(0, 1)) src_normalized = (source - mean_src) / std_src src_corrected = (src_normalized * std_tgt) + mean_tgt return np.clip(src_corrected, 0, 255).astype(np.uint8) # Function to replace teeth def replace_teeth(temp_image, aged_image): temp_image = np.array(temp_image) aged_image = np.array(aged_image) temp_mouth = get_mouth_region(temp_image) aged_mouth = get_mouth_region(aged_image) if temp_mouth is None or aged_mouth is None: return aged_image temp_mask = np.zeros_like(temp_image) cv2.fillConvexPoly(temp_mask, temp_mouth, (255, 255, 255)) temp_mouth_region = cv2.bitwise_and(temp_image, temp_mask) temp_mouth_bbox = cv2.boundingRect(temp_mouth) aged_mouth_bbox = cv2.boundingRect(aged_mouth) temp_mouth_crop = temp_mouth_region[temp_mouth_bbox[1]:temp_mouth_bbox[1] + temp_mouth_bbox[3], temp_mouth_bbox[0]:temp_mouth_bbox[0] + temp_mouth_bbox[2]] temp_mask_crop = temp_mask[temp_mouth_bbox[1]:temp_mouth_bbox[1] + temp_mouth_bbox[3], temp_mouth_bbox[0]:temp_mouth_bbox[0] + temp_mouth_bbox[2]] temp_mouth_crop_resized = cv2.resize(temp_mouth_crop, (aged_mouth_bbox[2], aged_mouth_bbox[3])) temp_mask_crop_resized = cv2.resize(temp_mask_crop, (aged_mouth_bbox[2], aged_mouth_bbox[3])) aged_mouth_crop = aged_image[aged_mouth_bbox[1]:aged_mouth_bbox[1] + aged_mouth_bbox[3], aged_mouth_bbox[0]:aged_mouth_bbox[0] + aged_mouth_bbox[2]] temp_mouth_crop_resized = color_correct(temp_mouth_crop_resized, aged_mouth_crop) center = (aged_mouth_bbox[0] + aged_mouth_bbox[2] // 2, aged_mouth_bbox[1] + aged_mouth_bbox[3] // 2) seamless_teeth = cv2.seamlessClone(temp_mouth_crop_resized, aged_image, temp_mask_crop_resized, center, cv2.NORMAL_CLONE) return seamless_teeth # Function to run alignment def run_alignment(image): from scripts.align_all_parallel import align_face temp_image_path = "/tmp/temp_image.jpg" image.save(temp_image_path) aligned_image = align_face(filepath=temp_image_path, predictor=predictor) return aligned_image # Function to apply aging def apply_aging(image, target_age): img_transforms = EXPERIMENT_DATA_ARGS[EXPERIMENT_TYPE]['transform'] input_image = img_transforms(image) age_transformers = [AgeTransformer(target_age=target_age)] results = [] for age_transformer in age_transformers: with torch.no_grad(): input_image_age = [age_transformer(input_image.cpu()).to('cuda')] input_image_age = torch.stack(input_image_age) result_tensor = net(input_image_age.float(), randomize_noise=False, resize=False)[0] result_image = tensor2im(result_tensor) results.append(np.array(result_image)) final_result = results[0] return final_result # Function to process the image def process_image(uploaded_image): # Loading images for good and bad teeth temp_images_good = [Image.open(f"good_teeth/G{i}.JPG") for i in range(1, 4)] temp_images_bad = [Image.open(f"bad_teeth/B{i}.jpeg") for i in range(1, 5)] # Predicting the age predicted_age = predict_age(uploaded_image) if predicted_age >= 48: target_age =35+1 else: target_age = predicted_age + 2 # Aligning the face in the uploaded image aligned_image = run_alignment(uploaded_image) # Applying aging effect aged_image = apply_aging(aligned_image, target_age=target_age) # Randomly selecting teeth images good_teeth_image = temp_images_good[np.random.randint(0, len(temp_images_good))] bad_teeth_image = temp_images_bad[np.random.randint(0, len(temp_images_bad))] # Replacing teeth in aged image aged_image_good_teeth = replace_teeth(good_teeth_image, aged_image) aged_image_bad_teeth = replace_teeth(bad_teeth_image, aged_image) return aged_image_good_teeth, aged_image_bad_teeth, predicted_age, target_age # Gradio Interface def show_results(uploaded_image): ### Perform quality check #if not check_image_quality(uploaded_image): #return None, None, "Not_Allowed" # If quality is acceptable, continue with processing aged_image_good_teeth, aged_image_bad_teeth, predicted_age, target_age = process_image(uploaded_image) return aged_image_good_teeth, aged_image_bad_teeth, f"Predicted Age: {predicted_age}, Target Age: {target_age}" iface = gr.Interface( fn=show_results, inputs=gr.Image(type="pil"), outputs=[gr.Image(type="pil"), gr.Image(type="pil"), gr.Textbox()], title="Aging Effect with Teeth Replacement", description="Upload an image to apply an aging effect. The application will generate two results: one with good teeth and one with bad teeth." ) iface.launch()