# Source: Hugging Face Space (status when captured: Running)
import math
import os  # For potential future environment variable use
import traceback

import gradio as gr
import mediapipe as mp
import numpy as np
import torch
from PIL import Image, ImageDraw
from transformers import YolosImageProcessor, YolosForObjectDetection
# --- Model Initialization ---
# 1. Face Detection Model
#
# Two checkpoints are tried in order: a face-specific one, then a general
# object detector whose 'person' class is used as a proxy for a face.
# The class id used to filter detections is looked up in the loaded
# checkpoint's own ``config.label2id`` mapping instead of being hard-coded,
# because the numeric id of a class depends on the checkpoint's label map.
PRIMARY_DETECTION_MODEL_NAME = "hustvl/yolos-face"
FALLBACK_DETECTION_MODEL_NAME = "hustvl/yolos-tiny"  # Detects 'person'
FACE_LABEL_ID = -1  # -1 means "no detector loaded"; set once a model loads
face_image_processor = None
face_detection_model = None


def _resolve_label_id(label2id, wanted_names, default):
    """Return the class id of the first name in *wanted_names* found in *label2id*.

    Args:
        label2id: Mapping of class name -> integer id (may be ``None`` or empty).
        wanted_names: Iterable of acceptable class names, in priority order.
            Matching ignores case and surrounding whitespace.
        default: Value returned when no wanted name is present.

    Returns:
        The matching id as an ``int``, or *default* when nothing matches.
    """
    if not label2id:
        return default
    ids_by_name = {str(name).strip().lower(): idx for name, idx in label2id.items()}
    for wanted in wanted_names:
        key = wanted.strip().lower()
        if key in ids_by_name:
            return int(ids_by_name[key])
    return default


def _load_detector(model_name, wanted_names, default_label_id):
    """Load processor + model for *model_name* and resolve the target class id.

    Returns a ``(processor, model, label_id)`` tuple. Any exception raised by
    ``from_pretrained`` propagates to the caller, so nothing is half-assigned.
    """
    processor = YolosImageProcessor.from_pretrained(model_name)
    model = YolosForObjectDetection.from_pretrained(model_name)
    label_id = _resolve_label_id(
        getattr(model.config, "label2id", None), wanted_names, default_label_id
    )
    return processor, model, label_id


print("Attempting to load face detection model...")
try:
    print(f"Trying primary model: {PRIMARY_DETECTION_MODEL_NAME}")
    # Default 0 keeps the original assumption that "face" is class 0 when the
    # checkpoint's label map does not name the class explicitly.
    face_image_processor, face_detection_model, FACE_LABEL_ID = _load_detector(
        PRIMARY_DETECTION_MODEL_NAME, ("face",), 0
    )
    print(f"Successfully loaded primary face detection model: {PRIMARY_DETECTION_MODEL_NAME} (label 'face': {FACE_LABEL_ID})")
except Exception as e:
    print(f"Error loading primary model {PRIMARY_DETECTION_MODEL_NAME}: {e}")
    print(f"Attempting to load fallback model: {FALLBACK_DETECTION_MODEL_NAME}")
    try:
        # NOTE(review): the COCO label map used by Hugging Face detection
        # checkpoints is believed to reserve id 0 for "N/A" and use id 1 for
        # "person"; the id is read from the config, and 1 is only the
        # last-resort default. Confirm against the checkpoint's config.json.
        face_image_processor, face_detection_model, FACE_LABEL_ID = _load_detector(
            FALLBACK_DETECTION_MODEL_NAME, ("person",), 1
        )
        print(f"Successfully loaded fallback detection model: {FALLBACK_DETECTION_MODEL_NAME} (using label 'person': {FACE_LABEL_ID})")
    except Exception as e2:
        print(f"Error loading fallback model {FALLBACK_DETECTION_MODEL_NAME}: {e2}")
        print("!!! CRITICAL: Face detection model could not be loaded. The app might not function correctly. !!!")
        # face_image_processor and face_detection_model remain None and
        # FACE_LABEL_ID remains -1 because the loads above assign atomically.
# 2. Facial Landmark Model (MediaPipe Face Mesh)
print("Initializing MediaPipe Face Mesh...")
# All four handles start as None; code further down checks them for
# truthiness to decide whether landmark analysis is available.
mp_face_mesh = face_mesh_detector = mp_drawing = drawing_spec = None
try:
    mp_face_mesh = mp.solutions.face_mesh
    face_mesh_detector = mp_face_mesh.FaceMesh(
        static_image_mode=True,
        max_num_faces=1,
        refine_landmarks=True,
        min_detection_confidence=0.5,
    )
    # Drawing helpers used to overlay the detected mesh on the image.
    mp_drawing = mp.solutions.drawing_utils
    drawing_spec = mp_drawing.DrawingSpec(
        color=(0, 255, 0),  # green
        thickness=1,
        circle_radius=1,
    )
    print("MediaPipe Face Mesh initialized successfully.")
except Exception as init_error:
    # Whatever was not assigned before the failure stays None.
    print(f"Error initializing MediaPipe Face Mesh: {init_error}")
# --- Helper Functions ---
def _clamped_crop_box(box, img_width, img_height, padding_ratio):
    """Pad *box* on every side and clip it to the image bounds.

    Args:
        box: ``[xmin, ymin, xmax, ymax]`` in pixel coordinates.
        img_width: Image width in pixels.
        img_height: Image height in pixels.
        padding_ratio: Fraction of the box width/height added on each side.

    Returns:
        ``(xmin, ymin, xmax, ymax)`` clipped to the image, or ``None`` when
        the clipped rectangle has no area.
    """
    x_min, y_min, x_max, y_max = box
    pad_w = (x_max - x_min) * padding_ratio
    pad_h = (y_max - y_min) * padding_ratio
    left = max(0, x_min - pad_w)
    top = max(0, y_min - pad_h)
    right = min(img_width, x_max + pad_w)
    bottom = min(img_height, y_max + pad_h)
    if right <= left or bottom <= top:
        return None
    return left, top, right, bottom


def detect_face_local(image_pil, detection_threshold=0.4, padding_ratio=0.15):
    """Detect the highest-scoring face (or person) and return a padded crop.

    Uses the module-level ``face_image_processor`` / ``face_detection_model``
    and keeps only detections whose label equals ``FACE_LABEL_ID``.

    Args:
        image_pil: PIL image to search.
        detection_threshold: Minimum score passed to the processor's
            post-processing step (default 0.4, the value previously
            hard-coded here).
        padding_ratio: Fraction of the box size added on each side before
            cropping (default 0.15).

    Returns:
        ``(cropped_image, None)`` on success, otherwise ``(None, message)``
        where *message* describes why no crop was produced. Exceptions raised
        during detection are caught, logged with a traceback, and reported
        through the message.
    """
    if not face_image_processor or not face_detection_model or FACE_LABEL_ID == -1:
        return None, "Face detection model not loaded or configured properly."

    print(f"Detecting face with FACE_LABEL_ID: {FACE_LABEL_ID}")
    print(f"Using detection threshold: {detection_threshold}")
    try:
        inputs = face_image_processor(images=image_pil, return_tensors="pt")
        with torch.no_grad():
            outputs = face_detection_model(**inputs)
        # PIL reports (width, height); reversed here to (height, width).
        target_sizes = torch.tensor([image_pil.size[::-1]])
        results = face_image_processor.post_process_object_detection(
            outputs, threshold=detection_threshold, target_sizes=target_sizes
        )[0]

        print(f"Detection results: {len(results['scores'])} detections before filtering by label.")
        candidates = []
        for score, label, box in zip(results["scores"], results["labels"], results["boxes"]):
            current_score = score.item()
            current_label = label.item()
            print(f"  - Detected item: Label {current_label}, Score {current_score:.2f}")
            if current_label == FACE_LABEL_ID:
                candidates.append({"score": current_score, "box": box.tolist()})
        print(f"Found {len(candidates)} items matching FACE_LABEL_ID {FACE_LABEL_ID} with scores: {[item['score'] for item in candidates]}")

        if not candidates:
            return None, f"No face/person detected with sufficient confidence (threshold {detection_threshold}). Ensure face is clear and well-lit."

        # max() keeps the first of equally scored candidates, matching a
        # strict ">" scan over the detections.
        best = max(candidates, key=lambda item: item["score"])
        best_box = best["box"]
        print(f"Selected best box with score: {best['score']:.2f}")

        crop_box = _clamped_crop_box(best_box, image_pil.width, image_pil.height, padding_ratio)
        if crop_box is None:
            print(f"Warning: Invalid crop dimensions after padding. Original box: {best_box}.")
            # Retry without padding, still clipped to the image, so the crop
            # can never cover an area that lies outside the picture.
            crop_box = _clamped_crop_box(best_box, image_pil.width, image_pil.height, 0.0)
            if crop_box is None:
                return None, "Detected box has invalid dimensions."

        return image_pil.crop(crop_box), None
    except Exception as e:
        print(f"Error during local face detection: {e}")
        traceback.print_exc()  # Print full traceback for debugging
        return None, f"Error during face detection: {str(e)}"
def get_landmarks_and_draw(image_pil):
    """Run MediaPipe Face Mesh on *image_pil* and draw the mesh on a copy.

    Returns:
        A ``(landmarks, error_message, image)`` triple:

        * detector not initialised -> ``(None, message, image_pil)``
        * no face mesh found       -> ``(None, message, copy_of_image)``
        * success                  -> ``(landmarks, None, annotated_image)``
          where *landmarks* is the first entry of ``multi_face_landmarks``.
    """
    mesh_ready = face_mesh_detector and mp_drawing and drawing_spec
    if not mesh_ready:
        return None, "MediaPipe Face Mesh not initialized for landmarks.", image_pil

    # The detector is fed an RGB array regardless of the input image mode.
    rgb_pixels = np.array(image_pil.convert('RGB'))
    mesh_output = face_mesh_detector.process(rgb_pixels)

    overlay_pil = image_pil.copy()
    if not mesh_output.multi_face_landmarks:
        return None, "Could not detect facial landmarks.", overlay_pil

    face_landmarks = mesh_output.multi_face_landmarks[0]
    canvas = np.array(overlay_pil)
    # Draws in place on the NumPy canvas: tessellation lines plus points,
    # both with the shared drawing spec.
    mp_drawing.draw_landmarks(
        image=canvas,
        landmark_list=face_landmarks,
        connections=mp_face_mesh.FACEMESH_TESSELATION,
        landmark_drawing_spec=drawing_spec,
        connection_drawing_spec=drawing_spec,
    )
    return face_landmarks, None, Image.fromarray(canvas)
def _distance_2d_normalized(p1, p2):
    """Return the straight-line distance between *p1* and *p2* in the x/y plane.

    Both arguments only need ``x`` and ``y`` attributes; the result is in the
    same units as those attributes.
    """
    delta_x = p1.x - p2.x
    delta_y = p1.y - p2.y
    return math.sqrt(delta_x ** 2 + delta_y ** 2)
def estimate_face_shape_from_landmarks_v2(landmarks, img_width, img_height):
    """Classify a face shape from a handful of face-mesh landmark distances.

    Args:
        landmarks: Object exposing an indexable ``landmark`` sequence whose
            items have ``x`` and ``y`` attributes. Falsy input yields
            ``("Unknown", {})``.
        img_width: Accepted for interface compatibility; not used.
        img_height: Accepted for interface compatibility; not used.

    Returns:
        ``(shape, measurements)`` where *shape* is one of the labels assigned
        below (or ``"Unknown (div zero)"`` when the cheek width is zero) and
        *measurements* maps ``*_norm`` names to the raw coordinate spans.

    NOTE(review): the spans are taken directly from landmark coordinates. If
    those are normalised per axis (x by width, y by height), the
    height/width index depends on the image aspect ratio; the thresholds
    below are heuristic and presumably tuned with that in mind - confirm
    before rescaling by ``img_width`` / ``img_height``.
    """
    if not landmarks:
        return "Unknown", {}

    points = landmarks.landmark

    def horizontal_span(left_idx, right_idx):
        # Absolute x-distance between two landmark indices.
        return abs(points[left_idx].x - points[right_idx].x)

    # Vertical extent: landmark 10 (top) to landmark 152 (chin bottom).
    face_height = abs(points[10].y - points[152].y)
    face_width_cheeks = horizontal_span(234, 454)
    forehead_width = horizontal_span(70, 300)
    jaw_width_gonial = horizontal_span(172, 397)
    chin_width = horizontal_span(143, 372)

    measurements = {
        "face_height_norm": face_height,
        "face_width_cheeks_norm": face_width_cheeks,
        "forehead_width_norm": forehead_width,
        "jaw_width_gonial_norm": jaw_width_gonial,
        "chin_width_norm": chin_width,
    }

    # Cheek width is the denominator of every ratio below.
    if face_width_cheeks == 0:
        return "Unknown (div zero)", measurements

    facial_index = face_height / face_width_cheeks
    forehead_to_cheek_ratio = forehead_width / face_width_cheeks
    jaw_to_cheek_ratio = jaw_width_gonial / face_width_cheeks

    # Forehead, cheeks and jaw all of comparable width.
    widths_similar = (
        forehead_to_cheek_ratio > 0.85
        and jaw_to_cheek_ratio > 0.85
        and abs(forehead_width - jaw_width_gonial) < forehead_width * 0.20
    )

    if facial_index > 1.05:  # Longer than wide
        if widths_similar:
            shape = "Long/Oblong"
        elif forehead_width > jaw_width_gonial and chin_width < jaw_width_gonial * 0.85:
            shape = "Heart/Inverted Triangle"
        else:
            shape = "Long"
    elif facial_index < 0.95:  # Wider than long
        # A strong jaw relative to the cheeks separates Square from Round.
        if widths_similar and jaw_width_gonial > face_width_cheeks * 0.88:
            shape = "Square"
        else:
            shape = "Round"
    else:  # facial_index between 0.95 and 1.05 (balanced height/width)
        if (
            face_width_cheeks > forehead_width
            and face_width_cheeks > jaw_width_gonial
            and chin_width < jaw_width_gonial * 0.85
        ):
            shape = "Diamond"
        elif (
            forehead_width > jaw_width_gonial
            and face_width_cheeks > jaw_width_gonial
            and chin_width < jaw_width_gonial * 0.8
        ):
            if 0.80 < forehead_to_cheek_ratio < 1.0 and jaw_to_cheek_ratio < forehead_to_cheek_ratio * 0.95:
                shape = "Oval"
            else:
                shape = "Heart"
        elif (
            abs(forehead_width - jaw_width_gonial) < forehead_width * 0.15
            and abs(face_width_cheeks - forehead_width) < forehead_width * 0.15
        ):
            shape = "Square"
        else:
            shape = "Oval"  # General fallback for balanced faces

    # Every branch above assigns a shape, so no further default is needed.
    return shape, measurements
def get_side_profile_assessment(side_image_pil):
    """Run landmark detection on an optional side-profile image.

    Args:
        side_image_pil: A PIL image, a NumPy image array, or ``None``.

    Returns:
        ``(status, landmarks)``. *status* is ``"Not provided"`` when there is
        no usable input, ``"Could not analyze (...)"`` when landmark detection
        fails, and ``"Analyzed (basic landmark detection)"`` otherwise.
        *landmarks* is ``None`` unless detection succeeded.
    """
    # Identity check first: a non-empty NumPy array has no single truth
    # value, so a plain ``not side_image_pil`` raises ValueError for it.
    if side_image_pil is None:
        return "Not provided", None

    if isinstance(side_image_pil, np.ndarray):
        if side_image_pil.size == 0:
            return "Not provided", None
        side_image_pil = Image.fromarray(side_image_pil)
    elif not side_image_pil:
        # Any other falsy placeholder is treated as "nothing uploaded".
        return "Not provided", None

    side_image_pil = side_image_pil.convert("RGB")
    landmarks, error_msg_lm, _ = get_landmarks_and_draw(side_image_pil)
    if error_msg_lm or not landmarks:
        return f"Could not analyze ({error_msg_lm or 'no landmarks'})", None

    # Placeholder assessment: for now this only confirms that landmarks were
    # found. A real profile metric (e.g. chin prominence) would need a
    # consistent side view and calibration.
    return "Analyzed (basic landmark detection)", landmarks
def get_hairstyle_suggestions_v2(face_shape, side_profile_info=""):
    """Build the Markdown suggestion text for a face-shape label.

    Args:
        face_shape: Shape label; labels not in the table get generic advice.
        side_profile_info: Status string from the side-profile step. A status
            containing ``"Analyzed"`` adds a fixed note; any other non-empty
            status that does not contain ``"Not provided"`` is echoed back.

    Returns:
        Markdown with a haircut section, a beard section and an optional
        side-profile note.
    """
    # Advice shared by several labels is defined once: (hair tips, beard tips).
    long_advice = (
        ["Add width: Curls, waves, shoulder-length with layers. Bangs (blunt/side-swept). Avoid height."],
        ["Fuller on cheeks: full beard, mutton chops. Avoid long, pointy beards."],
    )
    round_advice = (
        ["Add height and length: pompadour, quiff, faux hawk, side part. Layers. Avoid blunt bobs at chin or very short, round cuts."],
        ["Add length to chin: goatee, soul patch, beard shorter on sides & longer at chin (ducktail)."],
    )
    advice_by_shape = {
        "Oval": (
            ["Most styles work. Consider layers, textured crops, or side parts."],
            ["Versatile. Classic full beard, short boxed, or stubble."],
        ),
        "Oval (Default)": (
            ["Versatile. Try layers or a textured crop. Side parts can be flattering."],
            ["Well-groomed stubble or a short boxed beard."],
        ),
        "Long/Oblong": long_advice,
        "Long": long_advice,
        "Long (Default)": long_advice,
        "Heart": (
            ["Add jawline volume: chin-length bobs, layered shoulder cuts. Side-swept bangs/textured fringe for forehead."],
            ["Fuller beards to add jaw width: Garibaldi, full beard carefully shaped."],
        ),
        "Heart/Inverted Triangle": (
            ["Add jawline volume: chin-length bobs, layered shoulder cuts. Side-swept bangs for forehead."],
            ["Fuller beards to add jaw width: Garibaldi, full beard shaped."],
        ),
        "Square": (
            ["Softer styles: waves, curls, layers. Textured cuts, off-center parts. Avoid sharp, geometric cuts if aiming to soften."],
            ["Circle beard, rounded full beard. Stubble can highlight jaw if desired."],
        ),
        "Round": round_advice,
        "Round (Default)": round_advice,
        "Diamond": (
            ["Soften forehead & jaw: chin bobs, shoulder length with layers, textured fringe. Side-swept bangs."],
            ["Fuller at chin, possibly some width at jaw but not cheeks: Balbo, shorter full beard."],
        ),
        "Unknown": (
            ["Upload a clearer image for analysis."],
            ["Upload a clearer image for analysis."],
        ),
        "Unknown (div zero)": (
            ["Measurement error. Try different image."],
            ["Measurement error. Try different image."],
        ),
    }
    generic_advice = (
        ["General advice: consult a professional stylist."],
        ["Experiment with styles that you feel confident in."],
    )

    hair_tips, beard_tips = advice_by_shape.get(face_shape, generic_advice)
    hair_block = "\n".join(f"- {tip}" for tip in hair_tips)
    beard_block = "\n".join(f"- {tip}" for tip in beard_tips)

    if "Analyzed" in side_profile_info:
        profile_note = "\n\n*Side profile analyzed. Future versions could use this for more tailored advice (e.g., jawline definition).*"
    elif side_profile_info and "Not provided" not in side_profile_info:
        # A side image was supplied but could not be analysed: echo the status.
        profile_note = f"\n\n*Side profile: {side_profile_info}*"
    else:
        profile_note = ""

    return (
        f"**Haircut Suggestions for {face_shape} Face:**\n{hair_block}"
        f"\n\n**Beard Style Suggestions for {face_shape} Face:**\n{beard_block}"
        f"{profile_note}"
    )
def analyze_face_and_suggest_v2(front_image_input, side_image_input_optional):
    """Gradio callback: detect the face, estimate its shape, suggest styles.

    Args:
        front_image_input: Front photo as a NumPy array, or ``None``.
        side_image_input_optional: Optional side photo, or ``None``.

    Returns:
        ``(image, analysis_markdown, suggestions_markdown)``. *image* is
        ``None`` whenever processing stops before landmark detection; the
        second element then carries the reason and the third is empty.
    """
    if front_image_input is None:
        return None, "Please upload a front-facing photo.", ""

    # Report every detector that is unavailable before doing any work.
    unavailable = [
        message
        for handle, message in (
            (face_detection_model, "Face detector not loaded."),
            (face_mesh_detector, "Landmark detector not loaded."),
        )
        if not handle
    ]
    if unavailable:
        return None, " ".join(unavailable) + " Please check Space logs.", ""

    front_rgb = Image.fromarray(front_image_input).convert("RGB")

    face_crop, detect_error = detect_face_local(front_rgb)
    if detect_error:
        return None, detect_error, ""  # No measurements if face detection fails
    if face_crop is None:
        return None, "Could not detect a face.", ""

    landmarks, landmark_error, annotated_face = get_landmarks_and_draw(face_crop)
    if landmark_error:
        return (
            annotated_face,
            f"Face detected. Error getting landmarks: {landmark_error}",
            "Cannot suggest hairstyles without landmark analysis.",
        )

    crop_width, crop_height = face_crop.size
    shape_name, metrics = estimate_face_shape_from_landmarks_v2(landmarks, crop_width, crop_height)

    # One bullet per measurement, label padded to 25 columns.
    metric_lines = []
    for key, value in metrics.items():
        label = key.replace('_norm', ' (norm. ratio)')
        metric_lines.append(f"- {label:<25}: {value:.3f}")
    report = (
        f"Estimated Face Shape: **{shape_name}**\n\nNormalized Measurements:\n"
        + "\n".join(metric_lines)
    )

    side_status = "Not provided"
    if side_image_input_optional is not None:
        # The array is handed over as received.
        side_status, _ = get_side_profile_assessment(side_image_input_optional)
    report += f"\n\nSide Profile: {side_status}"

    advice = get_hairstyle_suggestions_v2(shape_name, side_status)
    return annotated_face, report, advice
# --- Gradio Interface ---
# Layout: inputs and the submit button in the left column, results in the
# (wider) right column, followed by a notes section.
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# ✂️ AI Hairstyle & Beard Suggester 🧔")
    gr.Markdown(
        "Upload a clear, front-facing photo. Optionally, upload a side profile."
        "\n*Disclaimer: This app uses local AI models for face detection and landmark-based shape estimation. Suggestions are general and based on heuristics.*"
    )

    with gr.Row():
        with gr.Column(scale=1):
            front_image_input = gr.Image(
                label="Front Face Photo (Required)",
                type="numpy",
                sources=["upload", "webcam"],
            )
            side_image_input = gr.Image(
                label="Side Profile Photo (Optional)",
                type="numpy",
                sources=["upload", "webcam"],
            )
            submit_btn = gr.Button("Get Suggestions", variant="primary")
        with gr.Column(scale=2):
            output_image_landmarks = gr.Image(label="Detected Face with Landmarks")
            output_analysis_info = gr.Markdown(label="Face Analysis & Measurements")
            output_suggestions = gr.Markdown(label="Suggestions")

    # The callback returns (image, analysis text, suggestions text), matching
    # the three output components in order.
    submit_btn.click(
        fn=analyze_face_and_suggest_v2,
        inputs=[front_image_input, side_image_input],
        outputs=[output_image_landmarks, output_analysis_info, output_suggestions],
    )

    gr.Markdown(
        "--- \n ### Notes: "
        "\n - **Face Shape Estimation:** Based on ratios of distances between facial landmarks (MediaPipe). The categories (Oval, Round, etc.) and classification rules are experimental. "
        "\n - **Landmark Visualization:** Green mesh shows detected facial landmarks. "
        "\n - **Model Loading:** Tries `hustvl/yolos-face` first, then `hustvl/yolos-tiny` (person detection) as fallback. Check Space logs for details."
    )
if __name__ == "__main__":
    # The UI is only served when both the face detector and the MediaPipe
    # landmark pipeline are usable.
    detector_ready = bool(face_detection_model and face_image_processor and FACE_LABEL_ID != -1)
    landmarks_ready = bool(face_mesh_detector and mp_drawing and drawing_spec)

    if detector_ready and landmarks_ready:
        print("Launching Gradio App...")
        demo.launch()
    else:
        print("Gradio app not launched due to critical model loading errors. Please check the logs.")
        if not detector_ready:
            print("-> Face detection model failed to load.")
        if not landmarks_ready:
            print("-> MediaPipe landmark model failed to initialize.")