# CLIP-EBC/app.py
import torch
import torch.nn.functional as F
from torch import Tensor
import spaces
import numpy as np
from PIL import Image
import json, os, random
import gradio as gr
import torchvision.transforms.functional as TF
from safetensors.torch import load_file # Import the load_file function from safetensors
from matplotlib import cm
from huggingface_hub import hf_hub_download
from typing import Tuple
from models import get_model

def resize_density_map(x: Tensor, size: Tuple[int, int]) -> Tensor:
    """Resize a density map to `size` while preserving its total (the predicted count)."""
    x_sum = torch.sum(x, dim=(-1, -2), keepdim=True)
    x = F.interpolate(x, size=size, mode="bilinear")
    # Rescale so that the resized map sums back to the original count.
    scale_factor = torch.nan_to_num(x_sum / torch.sum(x, dim=(-1, -2), keepdim=True), nan=0.0, posinf=0.0, neginf=0.0)
    return x * scale_factor

def init_seeds(seed: int) -> None:
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)

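# ImageNet normalization statistics used by transform(); alpha is the opacity of the
# density-map overlay when it is blended onto the input image.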
mean = (0.485, 0.456, 0.406)
std = (0.229, 0.224, 0.225)
alpha = 0.8
init_seeds(42)
# -----------------------------
# Define the model architecture
# -----------------------------
truncation = 4
reduction = 8
granularity = "fine"
anchor_points_type = "average"  # strategy for picking anchor points: "average" or "middle"
input_size = 224
# Comment the lines below to test non-CLIP models.
prompt_type = "word"
num_vpt = 32
vpt_drop = 0.
deep_vpt = True
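# prompt_type, num_vpt, vpt_drop and deep_vpt are the visual prompt tuning (VPT) options
# forwarded to get_model() for the CLIP-based backbones.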
repo_id = "Yiming-M/CLIP-EBC"
model_configs = {
    "CLIP_EBC_ViT_L_14": {
        "model_name": "clip_vit_l_14",
        "filename": "nwpu_weights/CLIP_EBC_ViT_L_14/model.safetensors",
    },
    "CLIP_EBC_ViT_B_16": {
        "model_name": "clip_vit_b_16",
        "filename": "nwpu_weights/CLIP_EBC_ViT_B_16/model.safetensors",
    },
}
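# Pretrained NWPU-Crowd checkpoints are downloaded from the repo_id above via hf_hub_download().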
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
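# "cuda" is hardcoded here on the assumption that the @spaces.GPU decorator on predict()
# below makes a GPU available for the duration of each call on the Space.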
device = "cuda"
if truncation is None:  # regression, no truncation.
    bins, anchor_points = None, None
else:
    with open(os.path.join("configs", f"reduction_{reduction}.json"), "r") as f:
        config = json.load(f)[str(truncation)]["nwpu"]
    bins = config["bins"][granularity]
    anchor_points = config["anchor_points"][granularity]["average"] if anchor_points_type == "average" else config["anchor_points"][granularity]["middle"]
    bins = [(float(b[0]), float(b[1])) for b in bins]
    anchor_points = [float(p) for p in anchor_points]
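# For reference, the lookups above assume configs/reduction_{reduction}.json is laid out roughly as
# (structure inferred from the keys used here, not copied from the file itself):
#   {"4": {"nwpu": {"bins": {"fine": [...], ...},
#                   "anchor_points": {"fine": {"average": [...], "middle": [...]}, ...}}}}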

# Global references to the currently loaded model and which config entry it was built from.
loaded_model = None
loaded_model_name = None

def load_model(model_choice: str):
    """Build the selected model, load its pretrained weights from the Hub, and cache it globally."""
    global loaded_model, loaded_model_name
    config = model_configs[model_choice]
    model_name = config["model_name"]
    filename = config["filename"]

    # Prepare bins and anchor points if using the classification (EBC) head.
    if truncation is None:
        bins_, anchor_points_ = None, None
    else:
        with open(os.path.join("configs", f"reduction_{reduction}.json"), "r") as f:
            config_json = json.load(f)[str(truncation)]["nwpu"]
        bins_ = config_json["bins"][granularity]
        anchor_points_ = config_json["anchor_points"][granularity]["average"] if anchor_points_type == "average" else config_json["anchor_points"][granularity]["middle"]
        bins_ = [(float(b[0]), float(b[1])) for b in bins_]
        anchor_points_ = [float(p) for p in anchor_points_]

    # Build the model.
    model = get_model(
        backbone=model_name,
        input_size=input_size,
        reduction=reduction,
        bins=bins_,
        anchor_points=anchor_points_,
        prompt_type=prompt_type,
        num_vpt=num_vpt,
        vpt_drop=vpt_drop,
        deep_vpt=deep_vpt,
    )

    weights_path = hf_hub_download(repo_id, filename)
    state_dict = load_file(weights_path)
    # Checkpoint keys carry a "model." prefix; strip it to match this module's parameter names.
    new_state_dict = {k.replace("model.", ""): v for k, v in state_dict.items()}
    model.load_state_dict(new_state_dict)
    model.eval()

    loaded_model = model
    loaded_model_name = model_choice

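# load_model() is called lazily from predict() below, so the first request (and any request
# that switches models) pays the checkpoint download and initialization cost.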
# -----------------------------
# Preprocessing function
# -----------------------------
# Adjust the image transforms to match what your model expects.
def transform(image: Image.Image):
    assert isinstance(image, Image.Image), "Input must be a PIL Image"
    image_tensor = TF.to_tensor(image)
    image_height, image_width = image_tensor.shape[-2:]
    if image_height < input_size or image_width < input_size:
        # Find the ratio to resize the image while maintaining the aspect ratio
        ratio = max(input_size / image_height, input_size / image_width)
        new_height = int(image_height * ratio) + 1
        new_width = int(image_width * ratio) + 1
        image_tensor = TF.resize(image_tensor, (new_height, new_width), interpolation=TF.InterpolationMode.BICUBIC, antialias=True)
    image_tensor = TF.normalize(image_tensor, mean=mean, std=std)
    return image_tensor.unsqueeze(0)  # Add batch dimension

# -----------------------------
# Inference function
# -----------------------------
@spaces.GPU(duration=120)
def predict(image: Image.Image, model_choice: str = "CLIP_EBC_ViT_B_16"):
    """
    Given an input image, preprocess it, run the model to obtain a density map,
    compute the total crowd count, and prepare the density map for display.
    """
    global loaded_model
    if loaded_model is None or loaded_model_name != model_choice:
        load_model(model_choice)
    loaded_model.to(device)

    # Preprocess the image.
    input_width, input_height = image.size
    input_tensor = transform(image).to(device)  # shape: (1, 3, H, W)

    with torch.no_grad():
        density_map = loaded_model(input_tensor)  # expected shape: (1, 1, H', W')
    total_count = density_map.sum().item()
    resized_density_map = resize_density_map(density_map, (input_height, input_width)).cpu().squeeze().numpy()

    # Normalize the density map for display purposes.
    eps = 1e-8
    density_map_norm = (resized_density_map - resized_density_map.min()) / (resized_density_map.max() - resized_density_map.min() + eps)

    # Apply a colormap (e.g., 'jet') to get an RGBA image.
    colormap = cm.get_cmap("jet")
    # The colormap returns values in [0, 1]. Scale to [0, 255] and convert to uint8.
    density_map_color = (colormap(density_map_norm) * 255).astype(np.uint8)
    density_map_color_img = Image.fromarray(density_map_color).convert("RGBA")

    # Ensure the original image is in RGBA format, then blend the density map on top of it.
    image_rgba = image.convert("RGBA")
    overlayed_image = Image.blend(image_rgba, density_map_color_img, alpha=alpha)

    return image, overlayed_image, f"Predicted Count: {total_count:.2f}"
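
# Rough local smoke test (a sketch: assumes a CUDA device is available and that an
# "example1.jpg" file sits next to this script):
#   img = Image.open("example1.jpg").convert("RGB")
#   _, overlay, count_text = predict(img, "CLIP_EBC_ViT_B_16")
#   print(count_text)  # e.g. "Predicted Count: 123.45"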
# -----------------------------
# Build Gradio Interface using Blocks for a two-column layout
# -----------------------------
with gr.Blocks() as demo:
    gr.Markdown("# Crowd Counting by CLIP-EBC (Pre-trained on NWPU-Crowd)")
    gr.Markdown("Upload an image or select an example below to see the predicted crowd density map and total count.")

    with gr.Row():
        with gr.Column():
            model_choice = gr.Dropdown(
                choices=list(model_configs.keys()),
                value="CLIP_EBC_ViT_B_16",
                label="Select Model"
            )
            input_img = gr.Image(label="Input Image", sources=["upload", "clipboard"], type="pil")
            submit_btn = gr.Button("Predict")
        with gr.Column():
            output_img = gr.Image(label="Predicted Density Map", type="pil")
            output_text = gr.Textbox(label="Total Count")

    submit_btn.click(fn=predict, inputs=[input_img, model_choice], outputs=[input_img, output_img, output_text])

    gr.Examples(
        examples=[
            ["example1.jpg"],
            ["example2.jpg"],
            ["example3.jpg"],
            ["example4.jpg"],
            ["example5.jpg"],
        ],
        inputs=input_img,
        label="Try an example"
    )

demo.launch()