Amitz244 commited on
Commit
06b2bc8
·
verified ·
1 Parent(s): 5eeb043

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +10 -7
app.py CHANGED
@@ -14,15 +14,21 @@ spec = importlib.util.spec_from_file_location("modeling", class_path)
14
  modeling = importlib.util.module_from_spec(spec)
15
  spec.loader.exec_module(modeling)
16
 
17
- # Initialize the model
18
  from modeling import clip_lora_model
19
- emotion_model = clip_lora_model().to(device)
20
 
21
- # Load pretrained weights
 
22
  emotion_model_path = hf_hub_download(repo_id="PerceptCLIP/PerceptCLIP_Emotions", filename="perceptCLIP_Emotions.pth")
23
  emotion_model.load_state_dict(torch.load(emotion_model_path, map_location=device))
24
  emotion_model.eval()
25
 
 
 
 
 
 
 
 
26
  # Emotion label mapping
27
  idx2label = {
28
  0: "amusement",
@@ -76,10 +82,7 @@ def predict_emotion(image):
76
  emoji = emotion_emoji.get(emotion, "❓")
77
  return f"{emotion} {emoji}", f"{mem_score:.4f}"
78
 
79
- mem_model = clip_lora_model(output_dim=1).to(device)
80
- mem_model_path = hf_hub_download(repo_id="PerceptCLIP/PerceptCLIP_Memorability", filename="perceptCLIP_Memorability.pth")
81
- mem_model.load_state_dict(torch.load(mem_model_path, map_location=device))
82
- mem_model.eval()
83
 
84
  # Example images
85
  example_images = [
 
14
  modeling = importlib.util.module_from_spec(spec)
15
  spec.loader.exec_module(modeling)
16
 
 
17
  from modeling import clip_lora_model
 
18
 
19
+ # Emotions model
20
+ emotion_model = clip_lora_model().to(device)
21
  emotion_model_path = hf_hub_download(repo_id="PerceptCLIP/PerceptCLIP_Emotions", filename="perceptCLIP_Emotions.pth")
22
  emotion_model.load_state_dict(torch.load(emotion_model_path, map_location=device))
23
  emotion_model.eval()
24
 
25
+ # Memorability model
26
+ mem_model = clip_lora_model(output_dim=1).to(device)
27
+ mem_model_path = hf_hub_download(repo_id="PerceptCLIP/PerceptCLIP_Memorability", filename="perceptCLIP_Memorability.pth")
28
+ mem_model.load_state_dict(torch.load(mem_model_path, map_location=device))
29
+ mem_model.eval()
30
+
31
+
32
  # Emotion label mapping
33
  idx2label = {
34
  0: "amusement",
 
82
  emoji = emotion_emoji.get(emotion, "❓")
83
  return f"{emotion} {emoji}", f"{mem_score:.4f}"
84
 
85
+
 
 
 
86
 
87
  # Example images
88
  example_images = [