import torch
import torch.nn.functional as F
from torch import Tensor
import spaces
import numpy as np
from PIL import Image
import json, os, random
import gradio as gr
import torchvision.transforms.functional as TF
from safetensors.torch import load_file  # Load .safetensors checkpoints
from matplotlib import cm
from huggingface_hub import hf_hub_download
from typing import Tuple

from models import get_model


def resize_density_map(x: Tensor, size: Tuple[int, int]) -> Tensor:
    """Resize a density map to `size` while preserving its total count."""
    x_sum = torch.sum(x, dim=(-1, -2), keepdim=True)
    x = F.interpolate(x, size=size, mode="bilinear")
    # Rescale so that the resized map sums to the original count.
    scale_factor = torch.nan_to_num(
        x_sum / torch.sum(x, dim=(-1, -2), keepdim=True),
        nan=0.0, posinf=0.0, neginf=0.0,
    )
    return x * scale_factor


def init_seeds(seed: int) -> None:
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)


mean = (0.485, 0.456, 0.406)
std = (0.229, 0.224, 0.225)
alpha = 0.8

init_seeds(42)

# -----------------------------
# Define the model architecture
# -----------------------------
truncation = 4
reduction = 8
granularity = "fine"
anchor_points_type = "average"  # "average" or "middle"
input_size = 224

# Comment the lines below to test non-CLIP models.
prompt_type = "word"
num_vpt = 32
vpt_drop = 0.0
deep_vpt = True

repo_id = "Yiming-M/CLIP-EBC"
model_configs = {
    "CLIP_EBC_ViT_L_14": {
        "model_name": "clip_vit_l_14",
        "filename": "nwpu_weights/CLIP_EBC_ViT_L_14/model.safetensors",
    },
    "CLIP_EBC_ViT_B_16": {
        "model_name": "clip_vit_b_16",
        "filename": "nwpu_weights/CLIP_EBC_ViT_B_16/model.safetensors",
    },
}

# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device = "cuda"

if truncation is None:  # regression, no truncation.
    bins, anchor_points = None, None
else:
    with open(os.path.join("configs", f"reduction_{reduction}.json"), "r") as f:
        config = json.load(f)[str(truncation)]["nwpu"]
    bins = config["bins"][granularity]
    anchor_points = (
        config["anchor_points"][granularity]["average"]
        if anchor_points_type == "average"
        else config["anchor_points"][granularity]["middle"]
    )
    bins = [(float(b[0]), float(b[1])) for b in bins]
    anchor_points = [float(p) for p in anchor_points]


# Global references to the loaded model instance and the currently loaded choice.
loaded_model = None
loaded_model_choice = None


def load_model(model_choice: str) -> None:
    global loaded_model, loaded_model_choice
    config = model_configs[model_choice]
    model_name = config["model_name"]
    filename = config["filename"]

    # Prepare bins and anchor points if using classification.
    if truncation is None:
        bins_, anchor_points_ = None, None
    else:
        with open(os.path.join("configs", f"reduction_{reduction}.json"), "r") as f:
            config_json = json.load(f)[str(truncation)]["nwpu"]
        bins_ = config_json["bins"][granularity]
        anchor_points_ = (
            config_json["anchor_points"][granularity]["average"]
            if anchor_points_type == "average"
            else config_json["anchor_points"][granularity]["middle"]
        )
        bins_ = [(float(b[0]), float(b[1])) for b in bins_]
        anchor_points_ = [float(p) for p in anchor_points_]

    # Build the model.
    model = get_model(
        backbone=model_name,
        input_size=input_size,
        reduction=reduction,
        bins=bins_,
        anchor_points=anchor_points_,
        prompt_type=prompt_type,
        num_vpt=num_vpt,
        vpt_drop=vpt_drop,
        deep_vpt=deep_vpt,
    )

    # Download and load the pre-trained weights.
    weights_path = hf_hub_download(repo_id, filename)
    state_dict = load_file(weights_path)
    new_state_dict = {k.replace("model.", ""): v for k, v in state_dict.items()}
    model.load_state_dict(new_state_dict)
    model.eval()

    loaded_model = model
    loaded_model_choice = model_choice
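
# --------------------------------------------------------------
# Illustrative sketch (not called by the app): a quick check that
# resize_density_map preserves the total count when upsampling a
# density map. The tensor shapes below are arbitrary example values
# chosen only for demonstration.
# --------------------------------------------------------------
def _sanity_check_resize(height: int = 28, width: int = 28) -> None:
    dummy = torch.rand(1, 1, height, width)           # fake reduced-resolution density map
    resized = resize_density_map(dummy, (224, 224))   # upsample to the input resolution
    # The rescaling step should keep the summed count unchanged up to
    # floating-point error.
    assert torch.allclose(dummy.sum(), resized.sum(), rtol=1e-3)
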
# -----------------------------
# Preprocessing function
# -----------------------------
# Adjust the image transforms to match what your model expects.
def transform(image: Image.Image) -> Tensor:
    assert isinstance(image, Image.Image), "Input must be a PIL Image"
    image_tensor = TF.to_tensor(image)
    image_height, image_width = image_tensor.shape[-2:]
    if image_height < input_size or image_width < input_size:
        # Find the ratio to resize the image while maintaining the aspect ratio.
        ratio = max(input_size / image_height, input_size / image_width)
        new_height = int(image_height * ratio) + 1
        new_width = int(image_width * ratio) + 1
        image_tensor = TF.resize(
            image_tensor,
            (new_height, new_width),
            interpolation=TF.InterpolationMode.BICUBIC,
            antialias=True,
        )
    image_tensor = TF.normalize(image_tensor, mean=mean, std=std)
    return image_tensor.unsqueeze(0)  # Add batch dimension


# -----------------------------
# Inference function
# -----------------------------
@spaces.GPU(duration=120)
def predict(image: Image.Image, model_choice: str = "CLIP_EBC_ViT_B_16"):
    """
    Given an input image, preprocess it, run the model to obtain a density map,
    compute the total crowd count, and prepare the density map for display.
    """
    global loaded_model
    # Reload only when the requested model differs from the one in memory.
    if loaded_model is None or loaded_model_choice != model_choice:
        load_model(model_choice)
    loaded_model.to(device)

    # Preprocess the image.
    input_width, input_height = image.size
    input_tensor = transform(image).to(device)  # shape: (1, 3, H, W)

    with torch.no_grad():
        density_map = loaded_model(input_tensor)  # expected shape: (1, 1, H', W')

    total_count = density_map.sum().item()
    resized_density_map = resize_density_map(
        density_map, (input_height, input_width)
    ).cpu().squeeze().numpy()

    # Normalize the density map for display purposes.
    eps = 1e-8
    density_map_norm = (resized_density_map - resized_density_map.min()) / (
        resized_density_map.max() - resized_density_map.min() + eps
    )

    # Apply a colormap (e.g., "jet") to get an RGBA image.
    colormap = cm.get_cmap("jet")
    # The colormap returns values in [0, 1]. Scale to [0, 255] and convert to uint8.
    density_map_color = (colormap(density_map_norm) * 255).astype(np.uint8)
    density_map_color_img = Image.fromarray(density_map_color).convert("RGBA")

    # Ensure the original image is in RGBA format, then blend in the density map.
    image_rgba = image.convert("RGBA")
    overlayed_image = Image.blend(image_rgba, density_map_color_img, alpha=alpha)

    return image, overlayed_image, f"Predicted Count: {total_count:.2f}"


# -----------------------------
# Build Gradio Interface using Blocks for a two-column layout
# -----------------------------
with gr.Blocks() as demo:
    gr.Markdown("# Crowd Counting by CLIP-EBC (Pre-trained on NWPU-Crowd)")
    gr.Markdown("Upload an image or select an example below to see the predicted crowd density map and total count.")
    with gr.Row():
        with gr.Column():
            model_choice = gr.Dropdown(
                choices=list(model_configs.keys()),
                value="CLIP_EBC_ViT_B_16",
                label="Select Model",
            )
            input_img = gr.Image(label="Input Image", sources=["upload", "clipboard"], type="pil")
            submit_btn = gr.Button("Predict")
        with gr.Column():
            output_img = gr.Image(label="Predicted Density Map", type="pil")
            output_text = gr.Textbox(label="Total Count")

    submit_btn.click(fn=predict, inputs=[input_img, model_choice], outputs=[input_img, output_img, output_text])

    gr.Examples(
        examples=[
            ["example1.jpg"],
            ["example2.jpg"],
            ["example3.jpg"],
            ["example4.jpg"],
            ["example5.jpg"],
        ],
        inputs=input_img,
        label="Try an example",
    )

demo.launch()
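
# --------------------------------------------------------------
# Illustrative sketch (not called by the app): running inference outside
# the Gradio UI, e.g. for quick local testing. "example1.jpg" is one of
# the example images listed above; the output file name is arbitrary.
# Assumes a CUDA device is available, since `device` is hard-coded to "cuda".
# --------------------------------------------------------------
def _run_local_example(path: str = "example1.jpg") -> None:
    img = Image.open(path).convert("RGB")
    original, overlay, count_text = predict(img, "CLIP_EBC_ViT_B_16")
    overlay.save("density_overlay.png")
    print(count_text)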