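# Gradio demo for PerceptCLIP: given an uploaded image, it predicts the evoked
# emotion and a memorability score using two LoRA-fine-tuned CLIP models.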
import torch
import gradio as gr
from PIL import Image
from huggingface_hub import hf_hub_download
import importlib.util
from torchvision import transforms

# Select device (GPU if available)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Download model code
class_path = hf_hub_download(repo_id="PerceptCLIP/PerceptCLIP_Emotions", filename="modeling.py")
spec = importlib.util.spec_from_file_location("modeling", class_path)
modeling = importlib.util.module_from_spec(spec)
spec.loader.exec_module(modeling)

# The dynamically loaded module is not registered in sys.modules, so a plain
# "from modeling import ..." can fail; take the class from the loaded module instead.
clip_lora_model = modeling.clip_lora_model

# Emotion classification model (8 classes)
emotion_model = clip_lora_model().to(device)
emotion_model_path = hf_hub_download(repo_id="PerceptCLIP/PerceptCLIP_Emotions", filename="perceptCLIP_Emotions.pth")
emotion_model.load_state_dict(torch.load(emotion_model_path, map_location=device))
emotion_model.eval()

# Memorability model (single-output regression head)
mem_model = clip_lora_model(output_dim=1).to(device)
mem_model_path = hf_hub_download(repo_id="PerceptCLIP/PerceptCLIP_Memorability", filename="perceptCLIP_Memorability.pth")
mem_model.load_state_dict(torch.load(mem_model_path, map_location=device))
mem_model.eval()


# Emotion label mapping
idx2label = {
    0: "amusement",
    1: "awe",
    2: "contentment",
    3: "excitement",
    4: "anger",
    5: "disgust",
    6: "fear",
    7: "sadness"
}

# Emoji mapping
emotion_emoji = {
    "amusement": "πŸ˜‚",
    "awe": "😲",
    "contentment": "😊",
    "excitement": "πŸ˜ƒ",
    "anger": "😠",
    "disgust": "🀒",
    "fear": "😱",
    "sadness": "😞"
}

# Image preprocessing
def emo_preprocess(image):
    transform = transforms.Compose([ 
        transforms.Resize(224),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
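        # CLIP's standard image normalization statistics (rounded)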
        transforms.Normalize(mean=(0.4814, 0.4578, 0.4082), std=(0.2686, 0.2613, 0.2758)),
    ])
    return transform(image).unsqueeze(0).to(device)

# Inference: returns the predicted emotion (with an emoji) and the memorability score
def predict_emotion(image):
    # If the image is passed as a PIL Image
    if isinstance(image, Image.Image):  
        img = image.convert("RGB")
    else:
        img = Image.open(image).convert("RGB") 
    
    img = emo_preprocess(img)

    with torch.no_grad():
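        # mem_model regresses a single memorability score; emotion_model returns logits over the 8 classes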
        mem_score = mem_model(img).item()
        outputs = emotion_model(img)
        predicted = outputs.argmax(1).item()

    emotion = idx2label[predicted]
    emoji = emotion_emoji.get(emotion, "❓")  
    return f"{emotion} {emoji}", f"{mem_score:.4f}"
    


# Example images 
example_images = [
    "https://img.freepik.com/free-photo/emotive-excited-female-with-dark-skin-crisp-hair-keeps-hands-clenched-fists-exclaims-with-positiveness-as-achieved-success-her-career-opens-mouth-widely-isolated-white-wall_273609-16443.jpg",
    "https://t4.ftcdn.net/jpg/01/18/44/59/360_F_118445958_NtP7tIsD0CBPyG7Uad7Z2KxVWrsfCPjP.jpg", 
    "https://apnapestcontrol.ca/wp-content/uploads/2019/02/9.jpg",  
    "https://images.pexels.com/photos/1107717/pexels-photo-1107717.jpeg?cs=srgb&dl=pexels-fotios-photos-1107717.jpg&fm=jpg"
]
css = """
    .gradio-container {
        display: flex;
        flex-direction: column;
        align-items: center;
        justify-content: center;
    }
    .gradio-title, .gradio-description {
        text-align: center;
    }
"""

# Create Gradio interface with custom CSS
iface = gr.Interface(
    fn=predict_emotion,
    inputs=gr.Image(type="pil", label="Upload an Image"),
    outputs=[gr.Textbox(label="Emotion"), gr.Textbox(label="Memorability Score")],
    title="PerceptCLIP",
    description="This is an official demo of PerceptCLIP from the paper: [Don’t Judge Before You CLIP: A Unified Approach for Perceptual Tasks](https://arxiv.org/pdf/2503.13260). For each task, we fine-tune CLIP with LoRA and an MLP head, achieving state-of-the-art performance. \nThis demo shows results from two models, one per task - visual emotion analysis and memorability prediction.",
    examples=example_images,
    css=css  # Inject the custom CSS
)


if __name__ == "__main__":
    iface.launch()
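    # launch() serves the demo locally (by default at http://127.0.0.1:7860);
    # pass share=True to create a temporary public link.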