gyigit commited on
Commit
0e903f9
·
1 Parent(s): 54e8a79

upload evaluate

Browse files
Files changed (1) hide show
  1. evaluate.py +228 -0
evaluate.py ADDED
@@ -0,0 +1,228 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import logging
3
+ import torch.nn.functional as F
4
+ from PIL import Image
5
+ from transformers import AutoTokenizer, AutoModel, Swinv2Model
6
+ from torchvision import transforms
7
+ from src.model.model import MisinformationDetectionModel
8
+
9
+ logger = logging.getLogger(__name__)
10
+
11
+
12
+ class MisinformationPredictor:
13
+ def __init__(
14
+ self,
15
+ model_path,
16
+ device="cuda" if torch.cuda.is_available() else "cpu",
17
+ embed_dim=256,
18
+ num_heads=8,
19
+ dropout=0.1,
20
+ hidden_dim=64,
21
+ num_classes=3,
22
+ mlp_ratio=4.0,
23
+ text_input_dim=384,
24
+ image_input_dim=1024,
25
+ fused_attn=False,
26
+ text_encoder="microsoft/deberta-v3-xsmall",
27
+ ):
28
+ """
29
+ Initialize the predictor with a trained model and required encoders.
30
+
31
+ Args:
32
+ model_path: Path to the saved model checkpoint
33
+ text_encoder: Name/path of the text encoder model
34
+ device: Device to run inference on
35
+ Other args: Model architecture parameters
36
+ """
37
+ self.device = torch.device(device)
38
+
39
+ # Initialize tokenizer and encoders
40
+ logger.info("Loading encoders...")
41
+ self.tokenizer = AutoTokenizer.from_pretrained(text_encoder)
42
+ self.text_encoder = AutoModel.from_pretrained(text_encoder).to(self.device)
43
+ self.image_encoder = Swinv2Model.from_pretrained(
44
+ "microsoft/swinv2-base-patch4-window8-256"
45
+ ).to(self.device)
46
+
47
+ # Set encoders to eval mode
48
+ self.text_encoder.eval()
49
+ self.image_encoder.eval()
50
+
51
+ # Initialize model
52
+ self.model = MisinformationDetectionModel(
53
+ text_input_dim=text_input_dim,
54
+ image_input_dim=image_input_dim,
55
+ embed_dim=embed_dim,
56
+ num_heads=num_heads,
57
+ dropout=dropout,
58
+ hidden_dim=hidden_dim,
59
+ num_classes=num_classes,
60
+ mlp_ratio=mlp_ratio,
61
+ fused_attn=fused_attn,
62
+ ).to(self.device)
63
+
64
+ # Load model weights
65
+ logger.info(f"Loading model from {model_path}")
66
+ checkpoint = torch.load(model_path, map_location=self.device)
67
+ self.model.load_state_dict(checkpoint["model_state_dict"])
68
+ self.model.eval()
69
+
70
+ # Image preprocessing
71
+ self.image_transform = transforms.Compose(
72
+ [
73
+ transforms.Resize((256, 256)),
74
+ transforms.ToTensor(),
75
+ transforms.Normalize(
76
+ mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
77
+ ),
78
+ ]
79
+ )
80
+
81
+ # Class mapping
82
+ self.idx_to_label = {0: "support", 1: "not_enough_information", 2: "refute"}
83
+
84
+ def process_image(self, image_path):
85
+ """Process image from path to tensor."""
86
+ try:
87
+ image = Image.open(image_path).convert("RGB")
88
+ image = self.image_transform(image).unsqueeze(0) # Add batch dimension
89
+ return image.to(self.device)
90
+ except Exception as e:
91
+ logger.error(f"Error processing image {image_path}: {e}")
92
+ return None
93
+
94
+ @torch.no_grad()
95
+ def evaluate(
96
+ self, claim_text, claim_image_path, evidence_text, evidence_image_path
97
+ ):
98
+ """
99
+ Evaluate a single claim-evidence pair.
100
+
101
+ Args:
102
+ claim_text (str): The claim text
103
+ claim_image_path (str): Path to the claim image
104
+ evidence_text (str): The evidence text
105
+ evidence_image_path (str): Path to the evidence image
106
+
107
+ Returns:
108
+ dict: Dictionary containing predictions from all modality combinations
109
+ """
110
+ try:
111
+ # Process text inputs
112
+ claim_text_inputs = self.tokenizer(
113
+ claim_text,
114
+ truncation=True,
115
+ padding="max_length",
116
+ max_length=512,
117
+ return_tensors="pt",
118
+ ).to(self.device)
119
+
120
+ evidence_text_inputs = self.tokenizer(
121
+ evidence_text,
122
+ truncation=True,
123
+ padding="max_length",
124
+ max_length=512,
125
+ return_tensors="pt",
126
+ ).to(self.device)
127
+
128
+ # Get text embeddings
129
+ claim_text_embeds = self.text_encoder(**claim_text_inputs).last_hidden_state
130
+ evidence_text_embeds = self.text_encoder(
131
+ **evidence_text_inputs
132
+ ).last_hidden_state
133
+
134
+ # Process image inputs
135
+ claim_image = self.process_image(claim_image_path)
136
+ evidence_image = self.process_image(evidence_image_path)
137
+
138
+ # Process claim image
139
+ if claim_image is not None:
140
+ claim_image_embeds = self.image_encoder(claim_image).last_hidden_state
141
+ else:
142
+ logger.warning(
143
+ "Claim image processing failed, setting embedding to None"
144
+ )
145
+ claim_image_embeds = None
146
+
147
+ # Process evidence image
148
+ if evidence_image is not None:
149
+ evidence_image_embeds = self.image_encoder(
150
+ evidence_image
151
+ ).last_hidden_state
152
+ else:
153
+ logger.warning(
154
+ "Evidence image processing failed, setting embedding to None"
155
+ )
156
+ evidence_image_embeds = None
157
+
158
+ # Get model predictions
159
+ (y_t_t, y_t_i), (y_i_t, y_i_i) = self.model(
160
+ X_t=claim_text_embeds,
161
+ X_i=claim_image_embeds,
162
+ E_t=evidence_text_embeds,
163
+ E_i=evidence_image_embeds,
164
+ )
165
+
166
+ # Process predictions with confidence scores
167
+ predictions = {}
168
+
169
+ def process_output(output, path_name):
170
+ if output is not None:
171
+ probs = F.softmax(output, dim=-1)
172
+ pred_idx = probs.argmax(dim=-1).item()
173
+ confidence = probs[0][pred_idx].item()
174
+ return {
175
+ "label": self.idx_to_label[pred_idx],
176
+ "confidence": confidence,
177
+ "probabilities": {
178
+ self.idx_to_label[i]: p.item()
179
+ for i, p in enumerate(probs[0])
180
+ },
181
+ }
182
+ return None
183
+
184
+ predictions["text_text"] = process_output(y_t_t, "text_text")
185
+ predictions["text_image"] = process_output(y_t_i, "text_image")
186
+ predictions["image_text"] = process_output(y_i_t, "image_text")
187
+ predictions["image_image"] = process_output(y_i_i, "image_image")
188
+
189
+ return {
190
+ path: pred["label"] if pred else None
191
+ for path, pred in predictions.items()
192
+ }
193
+
194
+ except Exception as e:
195
+ logger.error(f"Error during evaluation: {e}")
196
+ return None
197
+
198
+
199
+ if __name__ == "__main__":
200
+ # Example usage
201
+ logging.basicConfig(level=logging.INFO)
202
+
203
+ predictor = MisinformationPredictor(model_path="ckpts/model.pt", device="cpu")
204
+
205
+ # Example prediction
206
+ predictions = predictor.evaluate(
207
+ claim_text="Musician Kodak Black was shot outside of a nightclub in Florida in December 2016.",
208
+ claim_image_path="./data/raw/factify/extracted/images/test/0_claim.jpg",
209
+ evidence_text="On 26 December 2016, the web site Gummy Post published an article claiming \
210
+ that musician Kodak Black was shot outside a nightclub in Florida. \
211
+ This article is a hoax. While Gummy Post cited a 'police report', no records exist \
212
+ of any shooting involving Kodak Black (real name Dieuson Octave) in Florida during December 2016. \
213
+ Additionally, the video Gummy Post shared as evidence showed an unrelated crime scene.",
214
+ evidence_image_path="./data/raw/factify/extracted/images/test/0_evidence.jpg",
215
+ )
216
+
217
+ print(predictions)
218
+ # Print predictions
219
+ # if predictions:
220
+ # print("\nPredictions:")
221
+ # for path, pred in predictions.items():
222
+ # if pred:
223
+ # print(f"\n{path}:")
224
+ # print(f" Label: {pred['label']}")
225
+ # print(f" Confidence: {pred['confidence']:.4f}")
226
+ # print(" Probabilities:")
227
+ # for label, prob in pred["probabilities"].items():
228
+ # print(f" {label}: {prob:.4f}")