Amitz244 committed on
Commit f583d37 · verified · 1 Parent(s): 31209fb

Create app.py

Files changed (1): app.py +83 -0
app.py ADDED
@@ -0,0 +1,83 @@
+ import torch
+ import gradio as gr
+ from PIL import Image
+ from huggingface_hub import hf_hub_download
+ import importlib.util
+ from torchvision import transforms
+
+ # Load model
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+ # Download model code
+ class_path = hf_hub_download(repo_id="PerceptCLIP/PerceptCLIP_Emotions", filename="modeling.py")
+ spec = importlib.util.spec_from_file_location("modeling", class_path)
+ modeling = importlib.util.module_from_spec(spec)
+ spec.loader.exec_module(modeling)
+
+ # Initialize the model
+ clip_lora_model = modeling.clip_lora_model  # module was loaded dynamically, so take the class from it directly
+ model = clip_lora_model().to(device)
+
+ # Load pretrained weights
+ model_path = hf_hub_download(repo_id="PerceptCLIP/PerceptCLIP_Emotions", filename="perceptCLIP_Emotions.pth")
+ model.load_state_dict(torch.load(model_path, map_location=device))
+ model.eval()
+
+ # Emotion label mapping
+ idx2label = {
+     0: "amusement",
+     1: "awe",
+     2: "contentment",
+     3: "excitement",
+     4: "anger",
+     5: "disgust",
+     6: "fear",
+     7: "sadness"
+ }
+
+ # Emoji mapping
+ emotion_emoji = {
+     "amusement": "😂",
+     "awe": "😲",
+     "contentment": "😊",
+     "excitement": "😃",
+     "anger": "😠",
+     "disgust": "🤢",
+     "fear": "😱",
+     "sadness": "😞"
+ }
+
+ # Image preprocessing
+ def emo_preprocess(image):
+     transform = transforms.Compose([
+         transforms.Resize(224),
+         transforms.CenterCrop(224),
+         transforms.ToTensor(),
+         transforms.Normalize(mean=(0.4814, 0.4578, 0.4082), std=(0.2686, 0.2613, 0.2758)),
+     ])
+     return transform(image).unsqueeze(0).to(device)
+
+ # Inference function
+ def predict_emotion(image):
+     image = Image.open(image).convert("RGB")
+     image = emo_preprocess(image)
+
+     with torch.no_grad():
+         outputs = model(image)
+         predicted = outputs.argmax(1).item()
+
+     emotion = idx2label[predicted]
+     emoji = emotion_emoji.get(emotion, "❓")  # Default to "?" if no emoji found
+     return f"{emotion} {emoji}"
+
+ # Create Gradio interface
+ iface = gr.Interface(
+     fn=predict_emotion,
+     inputs=gr.Image(type="filepath", label="Upload an Image"),
+     outputs=gr.Textbox(label="Emotion + Emoji"),
+     title="PerceptCLIP-Emotions",
+     description="This model predicts the emotion evoked by an image and returns the corresponding emoji along with the emotion name."
+ )
+
+ if __name__ == "__main__":
+     iface.launch()
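
A quick local sanity check of the new app (a minimal sketch, not part of this commit; it assumes you run it from the Space's root directory and that "sample.jpg" is any local test image):

    from app import predict_emotion  # importing app downloads the weights and loads the model
    print(predict_emotion("sample.jpg"))  # e.g. "amusement 😂"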