"""Gradio demo for PerceptCLIP: visual emotion, memorability and image quality.

Three ``clip_lora_model`` instances (class code downloaded from the Hugging
Face Hub) are loaded with task-specific weights and run on each uploaded image.
"""

import importlib.util
import random
import sys

import gradio as gr
import numpy as np
import torch
from huggingface_hub import hf_hub_download
from PIL import Image
from torchvision import transforms


def set_seed(seed):
    """Seed Python's ``random``, NumPy and PyTorch (CPU and, if present, CUDA) RNGs."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)


# Load model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Download the model code and import it as module ``modeling``. Registering it
# in sys.modules before executing it (the importlib "import a source file
# directly" recipe) guarantees we use the downloaded file, instead of a later
# ``from modeling import ...`` searching sys.path for some other module.
class_path = hf_hub_download(repo_id="PerceptCLIP/PerceptCLIP_Emotions", filename="modeling.py")
spec = importlib.util.spec_from_file_location("modeling", class_path)
modeling = importlib.util.module_from_spec(spec)
sys.modules["modeling"] = modeling
spec.loader.exec_module(modeling)
clip_lora_model = modeling.clip_lora_model


def _load_percept_model(repo_id, filename, **model_kwargs):
    """Build ``clip_lora_model(**model_kwargs)`` on ``device``, load Hub weights, set eval mode.

    Args:
        repo_id: Hugging Face Hub repository holding the checkpoint.
        filename: checkpoint file name inside ``repo_id``.
        **model_kwargs: forwarded to ``clip_lora_model`` (e.g. ``output_dim=1``).

    Returns:
        The model in eval mode.
    """
    model = clip_lora_model(**model_kwargs).to(device)
    weights_path = hf_hub_download(repo_id=repo_id, filename=filename)
    # NOTE(review): torch.load unpickles the file. It comes from a fixed Hub repo;
    # if that repo is not fully trusted, consider weights_only=True (torch >= 1.13).
    model.load_state_dict(torch.load(weights_path, map_location=device))
    model.eval()
    return model


# Emotions model (8-way classifier, see idx2label below)
emotion_model = _load_percept_model("PerceptCLIP/PerceptCLIP_Emotions", "perceptCLIP_Emotions.pth")
# Memorability model (single scalar output)
mem_model = _load_percept_model(
    "PerceptCLIP/PerceptCLIP_Memorability", "perceptCLIP_Memorability.pth", output_dim=1
)
# IQA model (single scalar output)
iqa_model = _load_percept_model("PerceptCLIP/PerceptCLIP_IQA", "perceptCLIP_IQA.pth", output_dim=1)

# Emotion label mapping (index of the emotion model's argmax -> label)
idx2label = {
    0: "amusement",
    1: "awe",
    2: "contentment",
    3: "excitement",
    4: "anger",
    5: "disgust",
    6: "fear",
    7: "sadness",
}

# Emoji mapping
emotion_emoji = {
    "amusement": "😂",
    "awe": "😲",
    "contentment": "😊",
    "excitement": "😃",
    "anger": "😠",
    "disgust": "🤢",
    "fear": "😱",
    "sadness": "😞",
}

# Number of random crops averaged for the IQA score, and the seed that fixes them.
IQA_NUM_CROPS = 15
IQA_CROP_SEED = 3407
# Bounds used to min-max rescale the averaged raw IQA prediction
# (presumably the model's observed output range — confirm with the authors).
MIN_IQA_PRED = -6.52
MAX_IQA_PRED = 3.11

# Image preprocessing. Transforms are built once rather than on every request.
# NOTE(review): the emotion/memorability path uses CLIP statistics rounded to 4
# decimals while the IQA path uses full precision; kept as-is in case the
# checkpoints were trained with these exact values.
_EMO_MEM_TRANSFORM = transforms.Compose([
    transforms.Resize(224),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.4814, 0.4578, 0.4082), std=(0.2686, 0.2613, 0.2758)),
])

_IQA_TRANSFORM = transforms.Compose([
    transforms.Resize((512, 384)),
    transforms.RandomCrop(size=(224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.48145466, 0.4578275, 0.40821073),
                         std=(0.26862954, 0.26130258, 0.27577711)),
])


def emo_mem_preprocess(image):
    """Resize/center-crop a PIL image to 224x224, normalize, and return a (1, 3, 224, 224) tensor on ``device``."""
    return _EMO_MEM_TRANSFORM(image).unsqueeze(0).to(device)


def IQA_preprocess():
    """Return the IQA transform (resize to 512x384, random 224x224 crop, normalize).

    The crop position is drawn from torch's global RNG (``RandomCrop``), not
    Python's ``random``; callers wanting reproducible crops must seed torch.
    """
    return _IQA_TRANSFORM


set_seed(3407)


def _iqa_crops(img):
    """Return ``IQA_NUM_CROPS`` transformed crops of ``img`` stacked as (N, 3, 224, 224) on ``device``.

    The CPU RNG is reseeded inside a forked RNG scope, so every request gets the
    same crop positions (those a freshly seeded process would draw) while the
    global RNG state seen by other code is left untouched.
    """
    with torch.random.fork_rng(devices=[]):
        torch.default_generator.manual_seed(IQA_CROP_SEED)
        crops = [_IQA_TRANSFORM(img) for _ in range(IQA_NUM_CROPS)]
    return torch.stack(crops).to(device)


# Inference function
def predict_percept(image):
    """Run the emotion, memorability and IQA models on one image.

    Args:
        image: a PIL image (what the Gradio input supplies), or a path /
            file object accepted by ``PIL.Image.open``.

    Returns:
        Three strings: ``"<emotion> <emoji>"``, the memorability score, and the
        min-max rescaled IQA score, both formatted to 4 decimals.

    Raises:
        gr.Error: if no image was provided.
    """
    if image is None:
        raise gr.Error("Please upload an image.")
    if isinstance(image, Image.Image):
        img = image.convert("RGB")
    else:
        # Close the underlying file once the RGB copy is made.
        with Image.open(image) as opened:
            img = opened.convert("RGB")

    # Both branches use the RGB-converted image (the raw input may be RGBA/L).
    batch = _iqa_crops(img)
    emo_mem_input = emo_mem_preprocess(img)

    with torch.no_grad():
        iqa_scores = iqa_model(batch).cpu().numpy()
        mem_score = mem_model(emo_mem_input).item()
        predicted = emotion_model(emo_mem_input).argmax(1).item()

    iqa_score = np.mean(iqa_scores)
    normalized_iqa_score = (iqa_score - MIN_IQA_PRED) / (MAX_IQA_PRED - MIN_IQA_PRED)

    emotion = idx2label[predicted]
    emoji = emotion_emoji.get(emotion, "❓")
    return f"{emotion} {emoji}", f"{mem_score:.4f}", f"{normalized_iqa_score:.4f}"


# Example images
example_images = [
    "https://webneel.com/daily/sites/default/files/images/daily/02-2013/3-motion-blur-speed-photography.jpg",
    "https://img.freepik.com/free-photo/emotive-excited-female-with-dark-skin-crisp-hair-keeps-hands-clenched-fists-exclaims-with-positiveness-as-achieved-success-her-career-opens-mouth-widely-isolated-white-wall_273609-16443.jpg",
    "https://t4.ftcdn.net/jpg/01/18/44/59/360_F_118445958_NtP7tIsD0CBPyG7Uad7Z2KxVWrsfCPjP.jpg",
    "https://apnapestcontrol.ca/wp-content/uploads/2019/02/9.jpg",
    "https://images.pexels.com/photos/1107717/pexels-photo-1107717.jpeg?cs=srgb&dl=pexels-fotios-photos-1107717.jpg&fm=jpg",
    "https://cdn.prod.website-files.com/60e4d0d0155e62117f4faef3/61fab92edbb1ccbc7d12c167_Brian-Matiash-Puppy.jpeg",
]

# Create the Gradio interface: one image input, three text outputs.
iface = gr.Interface(
    fn=predict_percept,
    inputs=gr.Image(type="pil", label="Upload an Image"),
    outputs=[gr.Textbox(label="Emotion"), gr.Textbox(label="Memorability Score"), gr.Textbox(label="IQA Score")],
    title="PerceptCLIP",
    description="This is an official demo of PerceptCLIP from the paper: [Don’t Judge Before You CLIP: A Unified Approach for Perceptual Tasks](https://arxiv.org/pdf/2503.13260). For each specific task, we fine-tune CLIP with LoRA and an MLP head. Our models achieve state-of-the-art performance. \nThis demo shows results from three models, each corresponding to a different task - visual emotion analysis, memorability prediction, and image quality assessment.",
    examples=example_images,
)

if __name__ == "__main__":
    iface.launch()