Amitz244 commited on
Commit
5eeb043
·
verified ·
1 Parent(s): 5e301d5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +13 -7
app.py CHANGED
@@ -16,12 +16,12 @@ spec.loader.exec_module(modeling)
16
 
17
  # Initialize the model
18
  from modeling import clip_lora_model
19
- model = clip_lora_model().to(device)
20
 
21
  # Load pretrained weights
22
- model_path = hf_hub_download(repo_id="PerceptCLIP/PerceptCLIP_Emotions", filename="perceptCLIP_Emotions.pth")
23
- model.load_state_dict(torch.load(model_path, map_location=device))
24
- model.eval()
25
 
26
  # Emotion label mapping
27
  idx2label = {
@@ -68,12 +68,18 @@ def predict_emotion(image):
68
  img = emo_preprocess(img)
69
 
70
  with torch.no_grad():
71
- outputs = model(img)
 
72
  predicted = outputs.argmax(1).item()
73
 
74
  emotion = idx2label[predicted]
75
  emoji = emotion_emoji.get(emotion, "❓")
76
- return f"{emotion} {emoji}"
 
 
 
 
 
77
 
78
  # Example images
79
  example_images = [
@@ -87,7 +93,7 @@ example_images = [
87
  iface = gr.Interface(
88
  fn=predict_emotion,
89
  inputs=gr.Image(type="pil", label="Upload an Image"),
90
- outputs=gr.Textbox(label="Emotion"),
91
  title="PerceptCLIP-Emotions",
92
  description="This model predicts the emotion evoked by an image.",
93
  examples=example_images
 
16
 
17
  # Initialize the model
18
  from modeling import clip_lora_model
19
+ emotion_model = clip_lora_model().to(device)
20
 
21
  # Load pretrained weights
22
+ emotion_model_path = hf_hub_download(repo_id="PerceptCLIP/PerceptCLIP_Emotions", filename="perceptCLIP_Emotions.pth")
23
+ emotion_model.load_state_dict(torch.load(emotion_model_path, map_location=device))
24
+ emotion_model.eval()
25
 
26
  # Emotion label mapping
27
  idx2label = {
 
68
  img = emo_preprocess(img)
69
 
70
  with torch.no_grad():
71
+ mem_score = mem_model(img).item()
72
+ outputs = emotion_model(img)
73
  predicted = outputs.argmax(1).item()
74
 
75
  emotion = idx2label[predicted]
76
  emoji = emotion_emoji.get(emotion, "❓")
77
+ return f"{emotion} {emoji}", f"{mem_score:.4f}"
78
+
79
+ mem_model = clip_lora_model(output_dim=1).to(device)
80
+ mem_model_path = hf_hub_download(repo_id="PerceptCLIP/PerceptCLIP_Memorability", filename="perceptCLIP_Memorability.pth")
81
+ mem_model.load_state_dict(torch.load(mem_model_path, map_location=device))
82
+ mem_model.eval()
83
 
84
  # Example images
85
  example_images = [
 
93
  iface = gr.Interface(
94
  fn=predict_emotion,
95
  inputs=gr.Image(type="pil", label="Upload an Image"),
96
+ outputs=[gr.Textbox(label="Emotion"), gr.Textbox(label="Memorability Score")],
97
  title="PerceptCLIP-Emotions",
98
  description="This model predicts the emotion evoked by an image.",
99
  examples=example_images