import importlib.util

import torch
import gradio as gr
from PIL import Image
from huggingface_hub import hf_hub_download
from torchvision import transforms

# Select device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Download the model code from the Hub and load it as a module
class_path = hf_hub_download(repo_id="PerceptCLIP/PerceptCLIP_Emotions", filename="modeling.py")
spec = importlib.util.spec_from_file_location("modeling", class_path)
modeling = importlib.util.module_from_spec(spec)
spec.loader.exec_module(modeling)
# The downloaded file is not on sys.path, so take the class from the loaded module
# instead of `from modeling import clip_lora_model`
clip_lora_model = modeling.clip_lora_model

# Emotions model
emotion_model = clip_lora_model().to(device)
emotion_model_path = hf_hub_download(repo_id="PerceptCLIP/PerceptCLIP_Emotions", filename="perceptCLIP_Emotions.pth")
emotion_model.load_state_dict(torch.load(emotion_model_path, map_location=device))
emotion_model.eval()

# Memorability model
mem_model = clip_lora_model(output_dim=1).to(device)
mem_model_path = hf_hub_download(repo_id="PerceptCLIP/PerceptCLIP_Memorability", filename="perceptCLIP_Memorability.pth")
mem_model.load_state_dict(torch.load(mem_model_path, map_location=device))
mem_model.eval()

# Emotion label mapping
idx2label = {
    0: "amusement",
    1: "awe",
    2: "contentment",
    3: "excitement",
    4: "anger",
    5: "disgust",
    6: "fear",
    7: "sadness",
}

# Emoji mapping
emotion_emoji = {
    "amusement": "😂",
    "awe": "😲",
    "contentment": "😊",
    "excitement": "😃",
    "anger": "😠",
    "disgust": "🤢",
    "fear": "😱",
    "sadness": "😞",
}

# Image preprocessing (CLIP normalization statistics)
def emo_preprocess(image):
    transform = transforms.Compose([
        transforms.Resize(224),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.4814, 0.4578, 0.4082), std=(0.2686, 0.2613, 0.2758)),
    ])
    return transform(image).unsqueeze(0).to(device)

# Inference function
def predict_emotion(image):
    # Accept either a PIL image (as passed by Gradio) or a file path
    if isinstance(image, Image.Image):
        img = image.convert("RGB")
    else:
        img = Image.open(image).convert("RGB")

    img = emo_preprocess(img)
    with torch.no_grad():
        mem_score = mem_model(img).item()
        outputs = emotion_model(img)
        predicted = outputs.argmax(1).item()

    emotion = idx2label[predicted]
    emoji = emotion_emoji.get(emotion, "❓")
    return f"{emotion} {emoji}", f"{mem_score:.4f}"

# Example images
example_images = [
    "https://img.freepik.com/free-photo/emotive-excited-female-with-dark-skin-crisp-hair-keeps-hands-clenched-fists-exclaims-with-positiveness-as-achieved-success-her-career-opens-mouth-widely-isolated-white-wall_273609-16443.jpg",
    "https://t4.ftcdn.net/jpg/01/18/44/59/360_F_118445958_NtP7tIsD0CBPyG7Uad7Z2KxVWrsfCPjP.jpg",
    "https://apnapestcontrol.ca/wp-content/uploads/2019/02/9.jpg",
    "https://images.pexels.com/photos/1107717/pexels-photo-1107717.jpeg?cs=srgb&dl=pexels-fotios-photos-1107717.jpg&fm=jpg",
]

# Create Gradio interface
iface = gr.Interface(
    fn=predict_emotion,
    inputs=gr.Image(type="pil", label="Upload an Image"),
    outputs=[gr.Textbox(label="Emotion"), gr.Textbox(label="Memorability Score")],
    title="PerceptCLIP-Emotions",
    description="Predicts the emotion evoked by an image and its memorability score.",
    examples=example_images,
)

if __name__ == "__main__":
    iface.launch()
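
# Hedged usage sketch (not part of the app): calling the inference function directly,
# bypassing the Gradio UI. Assumes a local file named "example.jpg" exists; note that
# predict_emotion accepts either a PIL image or a file path.
#
#   from PIL import Image
#   emotion, mem_score = predict_emotion(Image.open("example.jpg"))
#   print(emotion, mem_score)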