Amitz244 committed on
Commit 9084867 · verified · 1 Parent(s): 57c275a

Update app.py

Files changed (1)
  1. app.py +34 -8
app.py CHANGED
@@ -28,6 +28,11 @@ mem_model_path = hf_hub_download(repo_id="PerceptCLIP/PerceptCLIP_Memorability",
 mem_model.load_state_dict(torch.load(mem_model_path, map_location=device))
 mem_model.eval()
 
+# IQA model
+iqa_model = clip_lora_model(output_dim=1).to(device)
+iqa_model_path = hf_hub_download(repo_id="PerceptCLIP/PerceptCLIP_IQA", filename="perceptCLIP_IQA.pth")
+iqa_model.load_state_dict(torch.load(iqa_model_path, map_location=device))
+iqa_model.eval()
 
 # Emotion label mapping
 idx2label = {
@@ -54,7 +59,7 @@ emotion_emoji = {
 }
 
 # Image preprocessing
-def emo_preprocess(image):
+def emo_mem_preprocess(image):
     transform = transforms.Compose([
         transforms.Resize(224),
         transforms.CenterCrop(224),
@@ -63,24 +68,45 @@ def emo_preprocess(image):
     ])
     return transform(image).unsqueeze(0).to(device)
 
+def IQA_preprocess():
+    random.seed(3407)
+    transform = transforms.Compose([
+        transforms.Resize((512, 384)),
+        transforms.RandomCrop(size=(224, 224)),
+        transforms.ToTensor(),
+        transforms.Normalize(mean=(0.48145466, 0.4578275, 0.40821073),
+                             std=(0.26862954, 0.26130258, 0.27577711))
+    ])
+    return transform
+
 # Inference function
-def predict_emotion(image):
+def predict_percept(image):
     # If the image is passed as a PIL Image
     if isinstance(image, Image.Image):
         img = image.convert("RGB")
     else:
-        img = Image.open(image).convert("RGB")
-
-    img = emo_preprocess(img)
+        img = Image.open(image).convert("RGB")
+
+    batch = torch.stack([IQA_preprocess()(img) for _ in range(15)]).to(device)  # Shape: (15, 3, 224, 224)
+    img = emo_mem_preprocess(img)
 
     with torch.no_grad():
+        iqa_scores = iqa_model(batch).cpu().numpy()  # one score per random crop
         mem_score = mem_model(img).item()
        outputs = emotion_model(img)
         predicted = outputs.argmax(1).item()
+
+    iqa_score = np.mean(iqa_scores)  # average over the 15 crops
+    min_iqa_pred = -100
+    max_iqa_pred = 100
+    max_iqa_score = 0
+    min_iqa_score = 5
+
+    normalized_iqa_score = ((iqa_score - min_iqa_pred) / (max_iqa_pred - min_iqa_pred)) * (max_iqa_score - min_iqa_score) + min_iqa_score
 
     emotion = idx2label[predicted]
     emoji = emotion_emoji.get(emotion, "❓")
-    return f"{emotion} {emoji}", f"{mem_score:.4f}"
+    return f"{emotion} {emoji}", f"{mem_score:.4f}", f"{normalized_iqa_score:.4f}"
 
 
 
@@ -94,9 +120,9 @@ example_images = [
 
 # Create Gradio interface with custom CSS
 iface = gr.Interface(
-    fn=predict_emotion,
+    fn=predict_percept,
     inputs=gr.Image(type="pil", label="Upload an Image"),
-    outputs=[gr.Textbox(label="Emotion"), gr.Textbox(label="Memorability Score")],
+    outputs=[gr.Textbox(label="Emotion"), gr.Textbox(label="Memorability Score"), gr.Textbox(label="IQA Score")],
     title="PerceptCLIP",
     description="This is an official demo of PerceptCLIP from the paper: [Don’t Judge Before You CLIP: A Unified Approach for Perceptual Tasks](https://arxiv.org/pdf/2503.13260). For each specific task, we fine-tune CLIP with LoRA and an MLP head. Our models achieve state-of-the-art performance. \nThis demo shows results from three models, each corresponding to a different task - visual emotion analysis, memorability prediction, and image quality assessment.",
     examples=example_images
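
For reference, a minimal standalone sketch (not part of the commit) of the IQA score normalization added above, assuming, as the constants in the diff suggest, raw predictions in roughly [-100, 100] mapped onto a 0-5 display scale; the helper name normalize_iqa is illustrative only:

import numpy as np

def normalize_iqa(raw_crop_scores, min_pred=-100, max_pred=100, max_score=0, min_score=5):
    # Same affine mapping as in predict_percept: the mean crop score is rescaled so that
    # min_pred maps to min_score (5) and max_pred maps to max_score (0); with the committed
    # constants, higher raw predictions therefore produce lower displayed scores.
    mean_pred = np.mean(raw_crop_scores)
    return ((mean_pred - min_pred) / (max_pred - min_pred)) * (max_score - min_score) + min_score

# Example: crop scores [10.0, 20.0, 30.0] average to 20.0, which maps to 5 - 120/40 = 2.0
print(normalize_iqa([10.0, 20.0, 30.0]))  # 2.0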