# Spaces: Running on Zero
import json
import os
import random
from typing import Tuple

import torch
import torch.nn.functional as F
from torch import Tensor
import spaces
import numpy as np
from PIL import Image
import gradio as gr
import torchvision.transforms.functional as TF
from safetensors.torch import load_file  # Import the load_file function from safetensors
from matplotlib import cm, colormaps
from huggingface_hub import hf_hub_download

from models import get_model
def resize_density_map(x: Tensor, size: Tuple[int, int]) -> Tensor:
    """Resize a density map with bilinear interpolation while preserving its total mass.

    Bilinear interpolation changes the sum of a map roughly in proportion to the change
    in area. After resizing, every slice over the last two dimensions is therefore
    rescaled so that its sum equals the sum it had before resizing.

    Args:
        x: A 4-D tensor whose last two dimensions are spatial, e.g. (B, C, H, W), as
            required by ``F.interpolate`` in bilinear mode.
        size: Target spatial size ``(height, width)``.

    Returns:
        Tensor of shape ``(*x.shape[:-2], *size)``. A slice whose resized sum is zero
        (e.g. an all-zero input) gets a scale factor of 0 instead of NaN/inf.
    """
    original_sum = torch.sum(x, dim=(-1, -2), keepdim=True)
    resized = F.interpolate(x, size=size, mode="bilinear")
    resized_sum = torch.sum(resized, dim=(-1, -2), keepdim=True)
    # original / resized (not the reverse) makes each output slice sum to its original
    # total. keepdim=True keeps the factor shaped (..., 1, 1) so it broadcasts per
    # (batch, channel) slice rather than against the spatial (H, W) dimensions.
    scale_factor = torch.nan_to_num(original_sum / resized_sum, nan=0.0, posinf=0.0, neginf=0.0)
    return resized * scale_factor
def init_seeds(seed: int) -> None:
    """Seed the global RNGs of Python's ``random``, NumPy and PyTorch with *seed*."""
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)
# Per-channel statistics handed to TF.normalize in `transform`.
mean, std = (0.485, 0.456, 0.406), (0.229, 0.224, 0.225)
# Blend weight of the coloured density map when it is overlaid on the input in `predict`.
alpha = 0.8

# Seed all global RNGs once, at import time.
init_seeds(42)
# -----------------------------
# Model configuration
# -----------------------------
# `None` selects plain regression (no bins / anchor points); otherwise this picks the
# entry of the reduction config file that holds the bins and anchor points.
truncation = 4
reduction = 8  # selects configs/reduction_{reduction}.json
granularity = "fine"  # key into the "bins" / "anchor_points" tables of that config
anchor_points = "average"  # anchor-point set: "average", anything else means "middle"
input_size = 224  # passed to get_model; also the minimum image side used by `transform`

# CLIP prompt settings. `load_model` passes all four to `get_model` unconditionally,
# so they must remain defined even when experimenting with other backbones.
prompt_type = "word"
num_vpt = 32
vpt_drop = 0.0
deep_vpt = True

# Hugging Face Hub repository from which the model weights are downloaded.
repo_id = "Yiming-M/CLIP-EBC"
# Dropdown label -> backbone name for `get_model` and weight file inside `repo_id`.
# Insertion order is the order shown in the model dropdown.
model_configs = dict(
    CLIP_EBC_ViT_L_14=dict(
        model_name="clip_vit_l_14",
        filename="nwpu_weights/CLIP_EBC_ViT_L_14/model.safetensors",
    ),
    CLIP_EBC_ViT_B_16=dict(
        model_name="clip_vit_b_16",
        filename="nwpu_weights/CLIP_EBC_ViT_B_16/model.safetensors",
    ),
)

# Inference is hard-wired to CUDA (no CPU fallback). To run without a GPU use
# torch.device("cuda" if torch.cuda.is_available() else "cpu") instead.
device = "cuda"
# Resolve the bin edges and anchor points for this configuration once, at import time.
# Results go into `bins` / `anchor_point_values`. The `anchor_points` setting itself must
# stay the selector string, because `load_model` compares it against "average";
# overwriting it with a list would silently switch the model to the "middle" anchors.
if truncation is None:  # regression, no truncation.
    bins, anchor_point_values = None, None
else:
    with open(os.path.join("configs", f"reduction_{reduction}.json"), "r") as f:
        config = json.load(f)[str(truncation)]["nwpu"]
    anchor_key = "average" if anchor_points == "average" else "middle"
    bins = [(float(b[0]), float(b[1])) for b in config["bins"][granularity]]
    anchor_point_values = [float(p) for p in config["anchor_points"][granularity][anchor_key]]
# Module-level cache of the most recently built model; populated by `load_model`.
loaded_model = None


def load_model(model_choice: str):
    """Build the model selected by *model_choice*, load its weights and cache it globally.

    The backbone name and weight file come from ``model_configs[model_choice]``. When
    ``truncation`` is set (classification), bin edges and anchor points are read from the
    "nwpu" section of ``configs/reduction_{reduction}.json``; otherwise both are ``None``.
    The weights are downloaded from ``repo_id`` on the Hugging Face Hub and every
    ``"model."`` substring is stripped from the checkpoint keys. The model is then put in
    eval mode and stored in the global ``loaded_model``. Nothing is returned.

    Raises:
        KeyError: If *model_choice* is not a key of ``model_configs``.
    """
    global loaded_model
    spec = model_configs[model_choice]
    backbone_name, weights_filename = spec["model_name"], spec["filename"]

    # Classification needs bin edges and anchor points; regression needs neither.
    # NOTE(review): the anchor set is chosen by comparing the module-level `anchor_points`
    # to "average". If that global has been reassigned to something else (e.g. a list),
    # the comparison is False and the "middle" anchors are used.
    model_bins = model_anchors = None
    if truncation is not None:
        config_path = os.path.join("configs", f"reduction_{reduction}.json")
        with open(config_path, "r") as config_file:
            nwpu_config = json.load(config_file)[str(truncation)]["nwpu"]
        raw_bins = nwpu_config["bins"][granularity]
        anchor_table = nwpu_config["anchor_points"][granularity]
        raw_anchors = anchor_table["average"] if anchor_points == "average" else anchor_table["middle"]
        model_bins = [(float(edge[0]), float(edge[1])) for edge in raw_bins]
        model_anchors = [float(value) for value in raw_anchors]

    model = get_model(
        backbone=backbone_name,
        input_size=input_size,
        reduction=reduction,
        bins=model_bins,
        anchor_points=model_anchors,
        prompt_type=prompt_type,
        num_vpt=num_vpt,
        vpt_drop=vpt_drop,
        deep_vpt=deep_vpt,
    )

    checkpoint = load_file(hf_hub_download(repo_id, weights_filename))
    # str.replace strips "model." wherever it occurs in a key, not only as a prefix.
    model.load_state_dict({key.replace("model.", ""): tensor for key, tensor in checkpoint.items()})
    model.eval()
    loaded_model = model
# -----------------------------
# Preprocessing function
# -----------------------------
def transform(image: Image.Image):
    """Convert a PIL image into a normalized, batched tensor for the model.

    The image is converted to RGB, so grayscale, palette, RGBA and CMYK inputs all yield
    the three channels that ``mean``/``std`` describe. It is scaled to [0, 1] by
    ``TF.to_tensor``. If either side is shorter than ``input_size``, it is upscaled with
    bicubic interpolation, preserving the aspect ratio. Larger images keep their size.

    Args:
        image: The input image.

    Returns:
        A tensor of shape (1, 3, H, W).

    Raises:
        TypeError: If *image* is not a ``PIL.Image.Image``.
    """
    if not isinstance(image, Image.Image):
        raise TypeError("Input must be a PIL Image")
    # TF.normalize below applies 3-channel statistics in place, so the tensor must have
    # exactly 3 channels; 1- or 4-channel inputs would fail there.
    image_tensor = TF.to_tensor(image.convert("RGB"))
    image_height, image_width = image_tensor.shape[-2:]
    if image_height < input_size or image_width < input_size:
        # Scale so the shorter side reaches input_size; the +1 guards against int()
        # truncation leaving a side just below input_size.
        ratio = max(input_size / image_height, input_size / image_width)
        new_height = int(image_height * ratio) + 1
        new_width = int(image_width * ratio) + 1
        image_tensor = TF.resize(image_tensor, (new_height, new_width), interpolation=TF.InterpolationMode.BICUBIC, antialias=True)
    image_tensor = TF.normalize(image_tensor, mean=mean, std=std)
    return image_tensor.unsqueeze(0)  # add batch dimension
# -----------------------------
# Inference function
# -----------------------------
# The model_choice that `loaded_model` was built for. `predict` uses it to rebuild the
# model only when nothing is cached yet or a different model has been selected.
_loaded_model_choice = None


# NOTE(review): `spaces` is imported but `@spaces.GPU` is not applied here; on ZeroGPU
# hardware CUDA work is normally expected to run inside such a decorated function — confirm.
def predict(image: Image.Image, model_choice: str = "CLIP_EBC_ViT_B_16"):
    """Count the crowd in *image* and render its density map as an overlay.

    Args:
        image: Input image from the Gradio component; ``None`` when nothing was uploaded.
        model_choice: Key of ``model_configs`` selecting the model to run.

    Returns:
        A tuple ``(image, overlay, text)``:
        - ``image`` is the unchanged input image.
        - ``overlay`` is the input blended (weight ``alpha``) with a "jet"-coloured,
          min-max normalized density map resized to the input size.
        - ``text`` is ``"Predicted Count: <count>"`` with two decimals.

    Raises:
        gr.Error: If no image was provided or *model_choice* is not a known model.
    """
    global loaded_model, _loaded_model_choice
    if image is None:
        raise gr.Error("Please upload an image first.")
    if model_choice not in model_configs:
        raise gr.Error(f"Unknown model: {model_choice!r}")

    # Rebuild only when needed: rebuilding re-creates the network and reloads its weights.
    if loaded_model is None or _loaded_model_choice != model_choice:
        load_model(model_choice)
        _loaded_model_choice = model_choice
    loaded_model.to(device)

    input_width, input_height = image.size  # PIL reports (width, height)
    input_tensor = transform(image).to(device)  # shape: (1, 3, H, W)
    with torch.no_grad():
        density_map = loaded_model(input_tensor)  # expected shape: (1, 1, h, w) — TODO confirm
    # The count is taken from the model's native-resolution output, before any resizing.
    total_count = density_map.sum().item()

    # Resize back to the original image size for display purposes only.
    resized_density_map = resize_density_map(density_map, (input_height, input_width)).cpu().squeeze().numpy()

    # Min-max normalize to [0, 1] for colouring; eps avoids division by zero on a flat map.
    eps = 1e-8
    map_min = resized_density_map.min()
    density_map_norm = (resized_density_map - map_min) / (resized_density_map.max() - map_min + eps)

    # `cm.get_cmap` was removed in Matplotlib 3.9; the `colormaps` registry is the supported API.
    colormap = colormaps["jet"]
    # The colormap yields RGBA floats in [0, 1]; scale to [0, 255] and convert to uint8.
    density_map_color = (colormap(density_map_norm) * 255).astype(np.uint8)
    density_map_color_img = Image.fromarray(density_map_color).convert("RGBA")

    # Image.blend requires both images to share mode and size.
    image_rgba = image.convert("RGBA")
    overlayed_image = Image.blend(image_rgba, density_map_color_img, alpha=alpha)
    return image, overlayed_image, f"Predicted Count: {total_count:.2f}"
# -----------------------------
# Gradio interface: inputs in the left column, results in the right column
# -----------------------------
example_images = [
    ["example1.jpg"],
    ["example2.jpg"],
    ["example3.jpg"],
    ["example4.jpg"],
    ["example5.jpg"],
]

with gr.Blocks() as demo:
    gr.Markdown("# Crowd Counting by CLIP-EBC (Pre-trained on NWPU-Crowd)")
    gr.Markdown("Upload an image or select an example below to see the predicted crowd density map and total count.")

    with gr.Row():
        # Left: model selection, image input and the button that triggers `predict`.
        with gr.Column():
            model_choice = gr.Dropdown(
                label="Select Model",
                choices=list(model_configs),
                value="CLIP_EBC_ViT_B_16",
            )
            input_img = gr.Image(type="pil", sources=["upload", "clipboard"], label="Input Image")
            submit_btn = gr.Button("Predict")
        # Right: the density overlay and the count text returned by `predict`.
        with gr.Column():
            output_img = gr.Image(type="pil", label="Predicted Density Map")
            output_text = gr.Textbox(label="Total Count")

    # `predict` returns (input image, overlay, count text); the input image is written back too.
    submit_btn.click(
        fn=predict,
        inputs=[input_img, model_choice],
        outputs=[input_img, output_img, output_text],
    )

    gr.Examples(label="Try an example", examples=example_images, inputs=input_img)

demo.launch()