import gradio as gr
import torch
from transformers import AutoProcessor, AutoModel
from PIL import Image, ImageDraw, ImageFont
import numpy as np
import random
import os
import wget
import traceback

# --- Configuration & Model Loading ---

# Device selection with CPU fallback
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {DEVICE}")

# --- CLIP Setup ---
CLIP_MODEL_ID = "openai/clip-vit-base-patch32"
clip_processor = None
clip_model = None


def load_clip_model():
    global clip_processor, clip_model
    if clip_processor is None:
        try:
            print(f"Loading CLIP processor: {CLIP_MODEL_ID}...")
            clip_processor = AutoProcessor.from_pretrained(CLIP_MODEL_ID)
            print("CLIP processor loaded.")
        except Exception as e:
            print(f"Error loading CLIP processor: {e}")
            traceback.print_exc()
            return False
    if clip_model is None:
        try:
            print(f"Loading CLIP model: {CLIP_MODEL_ID}...")
            clip_model = AutoModel.from_pretrained(CLIP_MODEL_ID).to(DEVICE)
            print(f"CLIP model loaded to {DEVICE}.")
        except Exception as e:
            print(f"Error loading CLIP model: {e}")
            traceback.print_exc()
            return False
    return True


# --- FastSAM Setup ---
FASTSAM_CHECKPOINT = "FastSAM-s.pt"
FASTSAM_CHECKPOINT_URL = f"https://huggingface.co/CASIA-IVA-Lab/FastSAM-s/resolve/main/{FASTSAM_CHECKPOINT}"

fastsam_model = None
fastsam_lib_imported = False
FastSAM = None        # Placeholder, assigned in check_and_import_fastsam()
FastSAMPrompt = None  # Placeholder, assigned in check_and_import_fastsam()
def check_and_import_fastsam():
    global fastsam_lib_imported, FastSAM, FastSAMPrompt  # modify the module-level globals
    if not fastsam_lib_imported:
        try:
            from fastsam import FastSAM as FastSAM_lib, FastSAMPrompt as FastSAMPrompt_lib
            FastSAM = FastSAM_lib
            FastSAMPrompt = FastSAMPrompt_lib
            fastsam_lib_imported = True
            print("fastsam library imported successfully.")
        except ImportError as e:
            print("Error: 'fastsam' library not found. "
                  "Please install it: pip install git+https://github.com/CASIA-IVA-Lab/FastSAM.git")
            print(f"ImportError: {e}")
            fastsam_lib_imported = False
        except Exception as e:
            print(f"Unexpected error during fastsam import: {e}")
            traceback.print_exc()
            fastsam_lib_imported = False
    return fastsam_lib_imported


def download_fastsam_weights(retries=3):
    if not os.path.exists(FASTSAM_CHECKPOINT):
        print(f"Downloading FastSAM weights: {FASTSAM_CHECKPOINT} from {FASTSAM_CHECKPOINT_URL}...")
        for attempt in range(retries):
            try:
                # Ensure the directory exists if FASTSAM_CHECKPOINT includes a path
                os.makedirs(os.path.dirname(FASTSAM_CHECKPOINT) or '.', exist_ok=True)
                wget.download(FASTSAM_CHECKPOINT_URL, FASTSAM_CHECKPOINT)
                print("FastSAM weights downloaded.")
                return True
            except Exception as e:
                print(f"Attempt {attempt + 1}/{retries} failed to download FastSAM weights: {e}")
                # Clean up any partial download before retrying
                if os.path.exists(FASTSAM_CHECKPOINT):
                    try:
                        os.remove(FASTSAM_CHECKPOINT)
                    except OSError:
                        pass
                if attempt + 1 == retries:
                    print("Failed to download weights after all attempts.")
                    return False
        return False  # Defensive; the loop above always returns
    else:
        print("FastSAM weights already exist.")
        return True


def load_fastsam_model():
    global fastsam_model
    if fastsam_model is None:
        if not check_and_import_fastsam():
            print("Cannot load FastSAM model due to library import failure.")
            return False
        if download_fastsam_weights():
            # The class may still be unavailable if the import failed but the weights file exists
            if FastSAM is None:
                print("FastSAM class not available, check import status.")
                return False
            try:
                print(f"Loading FastSAM model: {FASTSAM_CHECKPOINT}...")
                fastsam_model = FastSAM(FASTSAM_CHECKPOINT)
                # Note: check the FastSAM docs on whether an explicit .to(DEVICE) is required
                # or the device is handled internally; here the device is passed at inference time.
                # fastsam_model.model.to(DEVICE)  # example if needed; adjust to FastSAM's structure
                print("FastSAM model loaded.")
                return True
            except Exception as e:
                print(f"Error loading FastSAM model weights or initializing: {e}")
                traceback.print_exc()
                return False
        else:
            print("FastSAM weights not found or download failed.")
            return False
    return True  # Model already loaded


# --- Processing Functions ---

def run_clip_zero_shot(image: Image.Image, text_labels: str):
    # Validate / normalize the input image (Gradio may hand us a numpy array)
    if not isinstance(image, Image.Image):
        print(f"CLIP input is not a PIL Image, type: {type(image)}")
        if isinstance(image, np.ndarray):
            try:
                image = Image.fromarray(image)
                print("Converted numpy input to PIL Image for CLIP.")
            except Exception as e:
                print(f"Failed to convert numpy array to PIL Image: {e}")
                return "Error: Invalid image input format.", None
        else:
            return "Error: Please provide a valid image.", None

    if clip_model is None or clip_processor is None:
        if not load_clip_model():
            return "Error: CLIP Model could not be loaded.", None

    if not text_labels:
        # No labels: return an empty result and the original image
        return {}, image

    labels = [label.strip() for label in text_labels.split(',') if label.strip()]
    if not labels:
        return {}, image

    print(f"Running CLIP zero-shot classification with labels: {labels}")
    try:
        # Ensure the image is RGB
        if image.mode != "RGB":
            print(f"Converting image from {image.mode} to RGB for CLIP.")
            image = image.convert("RGB")

        inputs = clip_processor(text=labels, images=image, return_tensors="pt", padding=True).to(DEVICE)
        with torch.no_grad():
            outputs = clip_model(**inputs)

        # Softmax over the labels to get per-label probabilities
        logits_per_image = outputs.logits_per_image  # shape: (batch, n_labels)
        probs = logits_per_image.softmax(dim=1)

        confidences = {labels[i]: float(probs[0, i].item()) for i in range(len(labels))}
        print(f"CLIP Confidences: {confidences}")
        # Return the confidences and the original (possibly converted) image
        return confidences, image

    except Exception as e:
        print(f"Error during CLIP processing: {e}")
        traceback.print_exc()
        return f"Error during CLIP processing: {e}", None


def run_fastsam_segmentation(image_pil: Image.Image, conf_threshold: float = 0.4, iou_threshold: float = 0.9):
    # Validate / normalize the input image
    if not isinstance(image_pil, Image.Image):
        print(f"FastSAM input is not a PIL Image, type: {type(image_pil)}")
        if isinstance(image_pil, np.ndarray):
            try:
                image_pil = Image.fromarray(image_pil)
                print("Converted numpy input to PIL Image for FastSAM.")
            except Exception as e:
                print(f"Failed to convert numpy array to PIL Image: {e}")
                return None, "Error: Invalid image input format."
        else:
            return None, "Error: Please provide a valid image."

    # Ensure the model and library are available
    if not load_fastsam_model() or not fastsam_lib_imported or FastSAMPrompt is None:
        return None, "Error: FastSAM not loaded or library unavailable."

    print(f"Running FastSAM 'segment everything' with conf={conf_threshold}, iou={iou_threshold}...")
    output_image = None
    status_message = "Processing..."

    try:
        # Ensure the image is RGB
        if image_pil.mode != "RGB":
            print(f"Converting image from {image_pil.mode} to RGB for FastSAM.")
            image_pil_rgb = image_pil.convert("RGB")
        else:
            image_pil_rgb = image_pil

        # FastSAM expects a NumPy array (RGB)
        image_np_rgb = np.array(image_pil_rgb)
        print(f"Input image shape for FastSAM: {image_np_rgb.shape}")

        # Run FastSAM over the whole image
        everything_results = fastsam_model(
            image_np_rgb,
            device=DEVICE,
            retina_masks=True,
            imgsz=640,  # or another size FastSAM supports
            conf=conf_threshold,
            iou=iou_threshold,
            verbose=True,  # keep verbose for debugging
        )

        # Validate the raw results before building the prompt processor
        if everything_results is None or not isinstance(everything_results, list) or len(everything_results) == 0:
            print("FastSAM model returned None or empty results.")
            return image_pil, "FastSAM did not return valid results."
        # Inspect the result structure (the format can vary between FastSAM versions)
        print(f"Type of everything_results: {type(everything_results)}")
        print(f"Length of everything_results: {len(everything_results)}")
        if len(everything_results) > 0:
            print(f"Type of first element: {type(everything_results[0])}")
            if hasattr(everything_results[0], 'masks') and everything_results[0].masks is not None:
                print(f"Masks found in results object, shape: {everything_results[0].masks.data.shape}")
            else:
                print("First result element does not have 'masks' attribute or it's None.")

        # Process results with FastSAMPrompt
        if FastSAMPrompt is None:
            print("FastSAMPrompt class is not available.")
            return image_pil, "Error: FastSAMPrompt class not loaded."

        prompt_process = FastSAMPrompt(image_np_rgb, everything_results, device=DEVICE)
        ann = prompt_process.everything_prompt()  # all annotations

        # Extract masks; adjust based on the actual FastSAM output structure.
        # 'ann' is assumed to be a list whose first element is a dict containing 'masks'.
        masks = None
        if isinstance(ann, list) and len(ann) > 0 and isinstance(ann[0], dict) and 'masks' in ann[0]:
            mask_tensor = ann[0]['masks']
            if mask_tensor is not None and mask_tensor.numel() > 0:
                masks = mask_tensor.cpu().numpy()
                print(f"Found {len(masks)} masks with shape: {masks.shape}")
            else:
                print("Annotation 'masks' tensor is None or empty.")
        else:
            print(f"No masks found or annotation format unexpected. ann type: {type(ann)}")
            if isinstance(ann, list) and len(ann) > 0:
                print(f"First element of ann: {ann[0]}")

        # Prepare the output image (start from the original)
        output_image = image_pil.copy()

        # Draw masks, if any were found
        if masks is not None and len(masks) > 0:
            overlay = Image.new('RGBA', output_image.size, (0, 0, 0, 0))
            draw = ImageDraw.Draw(overlay)
            for i, mask in enumerate(masks):
                # Binarize the mask before converting to an 8-bit image
                binary_mask = (mask > 0)
                mask_uint8 = binary_mask.astype(np.uint8) * 255
                if mask_uint8.max() == 0:
                    continue  # skip empty masks
                color = (random.randint(50, 255), random.randint(50, 255), random.randint(50, 255), 180)  # RGBA
                try:
                    mask_image = Image.fromarray(mask_uint8, mode='L')  # grayscale mask
                    draw.bitmap((0, 0), mask_image, fill=color)
                except Exception as draw_err:
                    print(f"Error drawing mask {i}: {draw_err}")
                    traceback.print_exc()
                    continue  # skip this mask

            # Composite the overlay onto the image
            try:
                output_image_rgba = output_image.convert('RGBA')
                output_image_composited = Image.alpha_composite(output_image_rgba, overlay)
                output_image = output_image_composited.convert('RGB')  # back to RGB for Gradio
                status_message = f"Segmentation complete. Found {len(masks)} masks."
                print("Mask drawing and compositing finished.")
            except Exception as comp_err:
                print(f"Error during alpha compositing: {comp_err}")
                traceback.print_exc()
                output_image = image_pil  # fall back to the original image
                status_message = "Error during mask visualization."
        else:
            print("No masks detected or processed for 'segment everything' mode.")
            status_message = "No segments found or processed."
            output_image = image_pil  # return the original image if no masks were found

        # Save the result for debugging before returning
        if output_image:
            try:
                debug_path = "debug_fastsam_everything_output.png"
                output_image.save(debug_path)
                print(f"Saved debug output to {debug_path}")
            except Exception as save_err:
                print(f"Failed to save debug image: {save_err}")

        return output_image, status_message

    except Exception as e:
        print(f"Error during FastSAM 'everything' processing: {e}")
        traceback.print_exc()
        # Return the original image and an error message on failure
        return image_pil, f"Error during processing: {e}"
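
# --- Hedged refactor sketch (not part of the original script) ---
# run_fastsam_segmentation above and run_text_prompted_segmentation below repeat
# the same overlay-drawing loop. This hypothetical helper shows one way to factor
# that logic out; it is not called anywhere in this file, and its name and
# signature are assumptions.
def draw_mask_overlay(base_image: Image.Image, masks, alpha: int = 180) -> Image.Image:
    """Composite randomly coloured binary masks onto a copy of base_image."""
    overlay = Image.new('RGBA', base_image.size, (0, 0, 0, 0))
    draw = ImageDraw.Draw(overlay)
    for mask in masks:
        mask_uint8 = (np.asarray(mask) > 0).astype(np.uint8) * 255
        if mask_uint8.max() == 0:
            continue  # skip empty masks
        color = (random.randint(50, 255), random.randint(50, 255), random.randint(50, 255), alpha)
        draw.bitmap((0, 0), Image.fromarray(mask_uint8, mode='L'), fill=color)
    return Image.alpha_composite(base_image.convert('RGBA'), overlay).convert('RGB')
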
def run_text_prompted_segmentation(image_pil: Image.Image, text_prompts: str,
                                   conf_threshold: float = 0.4, iou_threshold: float = 0.9):
    # Validate / normalize the input image
    if not isinstance(image_pil, Image.Image):
        print(f"FastSAM Text input is not a PIL Image, type: {type(image_pil)}")
        if isinstance(image_pil, np.ndarray):
            try:
                image_pil = Image.fromarray(image_pil)
                print("Converted numpy input to PIL Image for FastSAM Text.")
            except Exception as e:
                print(f"Failed to convert numpy array to PIL Image: {e}")
                return None, "Error: Invalid image input format."
        else:
            return None, "Error: Please provide a valid image."

    # Ensure the model and library are available (return the original image on load failure)
    if not load_fastsam_model() or not fastsam_lib_imported or FastSAMPrompt is None:
        return image_pil, "Error: FastSAM Model not loaded or library unavailable."

    if not text_prompts:
        return image_pil, "Please enter text prompts (e.g., 'person, dog')."

    prompts = [p.strip() for p in text_prompts.split(',') if p.strip()]
    if not prompts:
        return image_pil, "No valid text prompts entered."

    print(f"Running FastSAM text-prompted segmentation for: {prompts} with conf={conf_threshold}, iou={iou_threshold}")
    output_image = None
    status_message = "Processing..."

    try:
        # Ensure the image is RGB
        if image_pil.mode != "RGB":
            print(f"Converting image from {image_pil.mode} to RGB for FastSAM.")
            image_pil_rgb = image_pil.convert("RGB")
        else:
            image_pil_rgb = image_pil

        image_np_rgb = np.array(image_pil_rgb)
        print(f"Input image shape for FastSAM Text: {image_np_rgb.shape}")

        # Run FastSAM once to get all potential segments
        everything_results = fastsam_model(
            image_np_rgb,
            device=DEVICE,
            retina_masks=True,
            imgsz=640,  # use consistent args
            conf=conf_threshold,
            iou=iou_threshold,
            verbose=True,
        )

        # Check the base results
        if everything_results is None or not isinstance(everything_results, list) or len(everything_results) == 0:
            print("FastSAM model returned None or empty results for text prompt base.")
            return image_pil, "FastSAM did not return base results."

        # Initialize FastSAMPrompt
        if FastSAMPrompt is None:
            print("FastSAMPrompt class is not available.")
            return image_pil, "Error: FastSAMPrompt class not loaded."

        prompt_process = FastSAMPrompt(image_np_rgb, everything_results, device=DEVICE)

        all_matching_masks = []     # masks collected across all prompts
        found_prompts_details = []  # status details like "prompt (count)"

        # Process each text prompt
        for text in prompts:
            print(f"  Processing prompt: '{text}'")
            # Get the annotation for this specific text prompt
            ann = prompt_process.text_prompt(text=text)

            # Extract masks for this prompt; adjust if the annotation format differs
            current_masks = None
            num_found = 0
            if isinstance(ann, list) and len(ann) > 0 and isinstance(ann[0], dict) and 'masks' in ann[0]:
                mask_tensor = ann[0]['masks']
                if mask_tensor is not None and mask_tensor.numel() > 0:
                    current_masks = mask_tensor.cpu().numpy()
                    num_found = len(current_masks)
                    print(f"  Found {num_found} mask(s) for '{text}'. Shape: {current_masks.shape}")
                    all_matching_masks.extend(current_masks)
                else:
                    print(f"  Annotation 'masks' tensor is None or empty for '{text}'.")
            else:
                print(f"  No masks found or annotation format unexpected for '{text}'. ann type: {type(ann)}")
                if isinstance(ann, list) and len(ann) > 0:
                    print(f"  First element of ann for '{text}': {ann[0]}")
            found_prompts_details.append(f"{text} ({num_found})")  # record count for status

        # Prepare the output image and status
        output_image = image_pil.copy()
        status_message = (f"Results: {', '.join(found_prompts_details)}"
                          if found_prompts_details else "No matches found for any prompt.")

        # Draw all collected masks, if any were found
        if all_matching_masks:
            print(f"Total masks collected across all prompts: {len(all_matching_masks)}")
            # (Masks could be stacked with np.stack, but they are drawn one by one here.)
            overlay = Image.new('RGBA', output_image.size, (0, 0, 0, 0))
            draw = ImageDraw.Draw(overlay)
            for i, mask in enumerate(all_matching_masks):
                binary_mask = (mask > 0)
                mask_uint8 = binary_mask.astype(np.uint8) * 255
                if mask_uint8.max() == 0:
                    continue  # skip empty masks
                # Random color per mask (could also be one color per prompt)
                color = (random.randint(50, 255), random.randint(50, 255), random.randint(50, 255), 180)
                try:
                    mask_image = Image.fromarray(mask_uint8, mode='L')
                    draw.bitmap((0, 0), mask_image, fill=color)
                except Exception as draw_err:
                    print(f"Error drawing collected mask {i}: {draw_err}")
                    traceback.print_exc()
                    continue

            # Composite the overlay
            try:
                output_image_rgba = output_image.convert('RGBA')
                output_image_composited = Image.alpha_composite(output_image_rgba, overlay)
                output_image = output_image_composited.convert('RGB')
                print("Text prompt mask drawing and compositing finished.")
            except Exception as comp_err:
                print(f"Error during alpha compositing for text prompts: {comp_err}")
                traceback.print_exc()
                output_image = image_pil  # fall back to the original image
                status_message += " (Error during visualization)"
        else:
            print("No matching masks found for any text prompt.")
            # status_message already reflects the per-prompt counts

        # Save the result for debugging
        if output_image:
            try:
                debug_path = "debug_fastsam_text_output.png"
                output_image.save(debug_path)
                print(f"Saved debug output to {debug_path}")
            except Exception as save_err:
                print(f"Failed to save debug image: {save_err}")

        return output_image, status_message

    except Exception as e:
        print(f"Error during FastSAM text-prompted processing: {e}")
        traceback.print_exc()
        return image_pil, f"Error during processing: {e}"
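
# --- Optional smoke test (sketch; not part of the original script) ---
# A minimal way to exercise the processing functions above without launching the
# Gradio UI. The image path and prompts are assumptions; point them at any local
# image and call _smoke_test() manually, e.g. from a REPL.
def _smoke_test(image_path: str = "examples/dogs.jpg"):
    img = Image.open(image_path)
    clip_confidences, _ = run_clip_zero_shot(img, "dog, cat, car")
    print("CLIP:", clip_confidences)
    _, everything_status = run_fastsam_segmentation(img, conf_threshold=0.4, iou_threshold=0.9)
    print("Segment everything:", everything_status)
    _, text_status = run_text_prompted_segmentation(img, "dog", conf_threshold=0.4, iou_threshold=0.9)
    print("Text-prompted:", text_status)
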
# --- Model Preloading ---
print("Attempting to preload models...")
load_clip_model()     # Preload CLIP
load_fastsam_model()  # Preload FastSAM
print("Preloading finished (check logs above for errors).")

# --- Gradio Interface Definition ---
# (The outputs of each click handler below must match the corresponding function's
#  return values: each segmentation handler returns an image plus a status string.)
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# CLIP & FastSAM Demo")
    gr.Markdown("Explore Zero-Shot Classification, 'Segment Everything', and Text-Prompted Segmentation.")

    with gr.Tabs():
        # --- CLIP Tab ---
        with gr.TabItem("CLIP Zero-Shot Classification"):
            # ... (CLIP UI definition - unchanged, omitted here) ...
            clip_button.click(
                run_clip_zero_shot,
                inputs=[clip_input_image, clip_text_labels],
                # Outputs match the function: Label (dict/str), Image (PIL/None)
                outputs=[clip_output_label, clip_output_image_display],
            )
            # ... (CLIP Examples - unchanged, omitted here) ...

        # --- FastSAM Everything Tab ---
        with gr.TabItem("FastSAM Segment Everything"):
            gr.Markdown("Upload an image to segment all objects/regions.")
            with gr.Row():
                with gr.Column(scale=1):
                    fastsam_input_image_all = gr.Image(type="pil", label="Input Image")
                    with gr.Row():
                        fastsam_conf_all = gr.Slider(minimum=0.1, maximum=1.0, value=0.4, step=0.05, label="Confidence Threshold")
                        fastsam_iou_all = gr.Slider(minimum=0.1, maximum=1.0, value=0.9, step=0.05, label="IoU Threshold")
                    fastsam_button_all = gr.Button("Run FastSAM Segmentation", variant="primary")
                with gr.Column(scale=1):
                    # Output image plus a textbox for status messages / errors
                    fastsam_output_image_all = gr.Image(type="pil", label="Segmented Image")
                    fastsam_status_all = gr.Textbox(label="Status", interactive=False)

            fastsam_button_all.click(
                run_fastsam_segmentation,
                inputs=[fastsam_input_image_all, fastsam_conf_all, fastsam_iou_all],
                # Outputs: Image (PIL/None), Status (str)
                outputs=[fastsam_output_image_all, fastsam_status_all],
            )

            # Examples use the same function and the same (image, status) outputs
            gr.Examples(
                examples=[
                    ["examples/dogs.jpg", 0.4, 0.9],
                    ["examples/fruits.jpg", 0.5, 0.8],
                    ["examples/lion.jpg", 0.45, 0.9],
                ],
                inputs=[fastsam_input_image_all, fastsam_conf_all, fastsam_iou_all],
                outputs=[fastsam_output_image_all, fastsam_status_all],
                fn=run_fastsam_segmentation,
                cache_examples=False,  # keep False for debugging
            )

        # --- Text-Prompted Segmentation Tab ---
        with gr.TabItem("Text-Prompted Segmentation"):
            gr.Markdown("Upload an image and provide comma-separated prompts (e.g., 'person, dog').")
            with gr.Row():
                with gr.Column(scale=1):
                    prompt_input_image = gr.Image(type="pil", label="Input Image")
                    prompt_text_input = gr.Textbox(label="Comma-Separated Text Prompts", placeholder="e.g., glasses, watch")
                    with gr.Row():
                        prompt_conf = gr.Slider(minimum=0.1, maximum=1.0, value=0.4, step=0.05, label="Confidence Threshold")
                        prompt_iou = gr.Slider(minimum=0.1, maximum=1.0, value=0.9, step=0.05, label="IoU Threshold")
                    prompt_button = gr.Button("Segment by Text", variant="primary")
                with gr.Column(scale=1):
                    # Output image plus a status textbox
                    prompt_output_image = gr.Image(type="pil", label="Text-Prompted Segmentation")
                    prompt_status_message = gr.Textbox(label="Status", interactive=False)
            prompt_button.click(
                run_text_prompted_segmentation,
                inputs=[prompt_input_image, prompt_text_input, prompt_conf, prompt_iou],
                # Outputs: Image (PIL/None), Status (str) - matches the function
                outputs=[prompt_output_image, prompt_status_message],
            )

            gr.Examples(
                examples=[
                    ["examples/dog_bike.jpg", "person, bicycle", 0.4, 0.9],
                    ["examples/astronaut.jpg", "person, helmet", 0.35, 0.9],
                    ["examples/dogs.jpg", "dog", 0.4, 0.9],
                    ["examples/fruits.jpg", "banana, apple", 0.5, 0.8],
                    ["examples/teacher.jpg", "person, glasses", 0.4, 0.9],
                ],
                inputs=[prompt_input_image, prompt_text_input, prompt_conf, prompt_iou],
                outputs=[prompt_output_image, prompt_status_message],
                fn=run_text_prompted_segmentation,
                cache_examples=False,  # keep False for debugging
            )

# --- Example File Download ---
# (Requires the 'wget' package: pip install wget)
if not os.path.exists("examples"):
    os.makedirs("examples")
    print("Created 'examples' directory.")

example_files = {
    "astronaut.jpg": "https://upload.wikimedia.org/wikipedia/commons/thumb/d/d1/Astronaut_-_St._Jean_Bay.jpg/640px-Astronaut_-_St._Jean_Bay.jpg",
    "dog_bike.jpg": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/gradio/outputs_multimodal.jpg",
    "clip_logo.png": "https://raw.githubusercontent.com/openai/CLIP/main/CLIP.png",
    "dogs.jpg": "https://raw.githubusercontent.com/ultralytics/assets/main/im/image8.jpg",
    "fruits.jpg": "https://raw.githubusercontent.com/ultralytics/assets/main/im/image9.jpg",
    "lion.jpg": "https://huggingface.co/spaces/gradio/image-segmentation/resolve/main/images/lion.jpg",
    "teacher.jpg": "https://images.pexels.com/photos/848117/pexels-photo-848117.jpeg?auto=compress&cs=tinysrgb&w=600",
}


def download_example_file(filename, url, retries=3):
    filepath = os.path.join("examples", filename)
    if not os.path.exists(filepath):
        print(f"Attempting to download {filename}...")
        for attempt in range(retries):
            try:
                wget.download(url, filepath)
                print(f"Downloaded {filename} successfully.")
                return  # success
            except Exception as e:
                print(f"Download attempt {attempt + 1}/{retries} for {filename} failed: {e}")
                # Clean up any partial download before retrying
                if os.path.exists(filepath):
                    try:
                        os.remove(filepath)
                    except OSError:
                        pass
                if attempt + 1 == retries:
                    print(f"Failed to download {filename} after {retries} attempts.")
    else:
        print(f"Example file {filename} already exists.")


# Trigger downloads
for filename, url in example_files.items():
    download_example_file(filename, url)
print("Example file check/download complete.")

# --- Launch App ---
if __name__ == "__main__":
    print("Launching Gradio Demo...")
    demo.launch(debug=True)  # keep debug=True while iterating
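
# --- Environment notes (sketch; not part of the original script) ---
# The imports above imply roughly the following dependencies; package names are
# taken from the code, and version pins are deliberately omitted:
#
#   pip install gradio torch transformers pillow numpy wget
#   pip install git+https://github.com/CASIA-IVA-Lab/FastSAM.git
#
# The FastSAM-s.pt checkpoint and the example images are downloaded by the
# script itself on first run.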