cloneofsimo committed
Commit b72fefd · verified · 1 Parent(s): 7e0c0e8

Upload folder using huggingface_hub

Files changed (7)
  1. example.py +203 -0
  2. stage_1.py +199 -0
  3. stage_2.py +292 -0
  4. stage_4.py +506 -0
  5. streamlit_evaluation_app.py +695 -0
  6. upload_to_hf.py +112 -0
  7. validation_runner.py +286 -0
example.py ADDED
@@ -0,0 +1,203 @@
+ #!/usr/bin/env python3
+ """
+ Example inference script for the Multi-Head SigLIP2 Classifier from Hugging Face Hub.
+
+ Usage examples:
+ # Multiple images, single text
+ python example.py --image img1.png --image img2.jpg --repo fal/multihead_cls --text "an example caption"
+
+ # N images, N texts (returns an N x N similarity matrix)
+ python example.py \
+     --image img1.png --image img2.jpg \
+     --text "a cat" --text "a dog" --repo fal/multihead_cls
+
+ Requires: torch, transformers, huggingface_hub, Pillow, click
+ """
+
+ import json
+ import click
+ import torch
+ from PIL import Image
+ from transformers import AutoProcessor
+ from huggingface_hub import hf_hub_download
+
+ # Local model definition replicated from training for easy inference
+ import torch.nn as nn
+ from transformers import SiglipModel
+ import torch.nn.functional as F
+
+ CKPT = "google/siglip-base-patch16-256"
+
+ class MultiHeadSiglipClassifier(nn.Module):
+     """Dynamic multi-head classifier based on task configuration"""
+     def __init__(self, task_config: dict, model_name: str = CKPT):
+         super().__init__()
+         self.task_config = task_config
+         self.siglip = SiglipModel.from_pretrained(model_name)
+
+         # Freeze SigLIP parameters
+         for param in self.siglip.parameters():
+             param.requires_grad = False
+
+         # Create classification heads dynamically based on task config
+         hidden_size = self.siglip.config.vision_config.hidden_size
+         self.classification_heads = nn.ModuleDict()
+
+         for task in task_config['tasks']:
+             task_key = task['key']
+             num_classes = len(task['labels'])
+
+             # Create linear layer for this task
+             head = nn.Linear(hidden_size, num_classes)
+             self.classification_heads[task_key] = head
+
+     def forward(self, pixel_values):
+         # Get SigLIP image embeddings only
+         combined_embeds = self.siglip.get_image_features(pixel_values=pixel_values)
+
+         # Apply all classification heads
+         outputs = {}
+         for task_key, head in self.classification_heads.items():
+             outputs[task_key] = head(combined_embeds)
+
+         return outputs
+
+
+ def load_model_from_hf(repo_id: str):
+     """Load model, processor, and task config from Hugging Face Hub"""
+     # Download task configuration
+     try:
+         task_config_path = hf_hub_download(repo_id=repo_id, filename="task_config.json", repo_type="model")
+         with open(task_config_path, 'r') as f:
+             task_config = json.load(f)
+     except Exception as e:
+         raise RuntimeError(f"Could not load task_config.json from {repo_id}: {e}")
+
+     # Load processor
+     processor = AutoProcessor.from_pretrained(CKPT)
+
+     # Create model with task config
+     model = MultiHeadSiglipClassifier(task_config)
+
+     # Load trained weights
+     try:
+         ckpt_path = hf_hub_download(repo_id=repo_id, filename="model.pth", repo_type="model")
+         state_dict = torch.load(ckpt_path, map_location="cpu")
+         model.load_state_dict(state_dict)
+     except Exception as e:
+         raise RuntimeError(f"Could not load model.pth from {repo_id}: {e}")
+
+     model.eval()
+     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+     model.to(device)
+
+     return model, processor, device, task_config
+
+
+ def predict_batch(model, processor, device, task_config, image_paths, texts: list[str] | None = None):
+     """Run predictions on a batch of images using dynamic task configuration"""
+     images = [Image.open(p).convert("RGB") for p in image_paths]
+     if texts is not None and len(texts) == 0:
+         texts = None
+
+     # Process images
+     image_inputs = processor(images=images, return_tensors="pt")
+     pixel_values = image_inputs["pixel_values"].to(device)
+
+     with torch.no_grad():
+         outputs = model(pixel_values)
+         # Compute image embeddings for similarity
+         image_embeds = model.siglip.get_image_features(pixel_values=pixel_values)
+         image_embeds = F.normalize(image_embeds, p=2, dim=-1)
+
+         # Prepare text inputs if provided
+         text_embeds = None
+         input_ids = None
+         attention_mask = None
+         if texts is not None:
+             text_inputs = processor(text=texts, padding="max_length", return_tensors="pt")
+             input_ids = text_inputs["input_ids"].to(device)
+             attention_mask = text_inputs.get("attention_mask")
+             attention_mask = attention_mask.to(device) if attention_mask is not None else None
+             text_embeds = model.siglip.get_text_features(input_ids=input_ids, attention_mask=attention_mask)
+             text_embeds = F.normalize(text_embeds, p=2, dim=-1)
+
+     # Create task mappings
+     tasks = {task['key']: task for task in task_config['tasks']}
+
+     batch_results = []
+     batch_size = pixel_values.shape[0]
+
+     for i in range(batch_size):
+         item = {"image": str(image_paths[i])}
+
+         # Process each task dynamically
+         for task_key, task_info in tasks.items():
+             logits = outputs[task_key][i]
+             probs = torch.softmax(logits, dim=0)
+             pred_idx = torch.argmax(probs).item()
+
+             if task_info['type'] == 'binary':
+                 # Binary classification
+                 item[f"{task_key}_prediction"] = task_info['labels'][pred_idx]
+                 item[f"{task_key}_confidence"] = float(probs[pred_idx].item())
+                 item[f"{task_key}_prob_yes"] = float(probs[1].item()) if len(task_info['labels']) > 1 else 0.0
+                 item[f"{task_key}_prob_no"] = float(probs[0].item())
+
+             elif task_info['type'] == 'multi_class':
+                 # Multi-class classification
+                 item[f"{task_key}_prediction"] = task_info['labels'][pred_idx]
+                 item[f"{task_key}_confidence"] = float(probs[pred_idx].item())
+
+                 # Add probabilities for all classes
+                 for idx, label in enumerate(task_info['labels']):
+                     item[f"{task_key}_prob_{label}"] = float(probs[idx].item())
+
+         batch_results.append(item)
+
+     cosine_matrix = None
+
+     if input_ids is not None:
+         # Embeddings were L2-normalized above, so this matmul gives cosine similarity
+         cosine = torch.matmul(image_embeds, text_embeds.T)
+         cosine_matrix = cosine.cpu().tolist()
+
+     return {
+         "images": [str(p) for p in image_paths],
+         "texts": texts or [],
+         "task_config": task_config,
+         "predictions": batch_results,
+         "cosine_similarity": cosine_matrix,
+     }
+
+
+ @click.command()
+ @click.option("--image", "images", multiple=True, type=click.Path(exists=True, dir_okay=False, readable=True), help="Path(s) to image file(s). Can be passed multiple times.")
+ @click.option("--repo", default="fal/multihead_cls", show_default=True, help="Hugging Face repo id with model checkpoint.")
+ @click.option("--text", "texts", multiple=True, help="Text prompt(s). Can be passed multiple times to build an N x N image-text similarity matrix.")
+ @click.option("--show-tasks", is_flag=True, help="Show available classification tasks and exit.")
+ def cli(images, repo, texts, show_tasks):
+     """Multi-head SigLIP2 classifier inference from Hugging Face Hub"""
+
+     # Load model and task config
+     model, processor, device, task_config = load_model_from_hf(repo)
+
+     if show_tasks:
+         click.echo("Available classification tasks:")
+         for i, task in enumerate(task_config['tasks'], 1):
+             click.echo(f" {i}. {task['name']} ({task['key']})")
+             click.echo(f"    Type: {task['type']}")
+             click.echo(f"    Labels: {', '.join(task['labels'])}")
+             click.echo(f"    Description: {task['description']}")
+             click.echo()
+         return
+
+     if not images:
+         images = ("img.png",)
+
+     results = predict_batch(model, processor, device, task_config, list(images), texts=list(texts) if texts else None)
+     click.echo(json.dumps(results, indent=2))
+
+
+ if __name__ == "__main__":
+     cli()
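
Note: both example.py and stage_4.py consume a task_config.json that is not included in this commit. Judging only from the fields the code reads (each task needs 'key', 'name', 'type', 'labels', and 'description', with binary labels ordered so that index 0 is "no" and index 1 is "yes"), the file is shaped roughly like the illustrative sketch below; the task names and descriptions here are assumptions, not the actual config, and the category labels are taken from the Gemini prompt in stage_2.py.

    {
      "tasks": [
        {
          "key": "is_professional",
          "name": "Professional Photo",
          "type": "binary",
          "labels": ["no", "yes"],
          "description": "Whether the image looks like it was taken by a professional photographer."
        },
        {
          "key": "category",
          "name": "Image Category",
          "type": "multi_class",
          "labels": ["animals", "artifacts", "people", "outdoor_scenes", "illustrations", "vehicles", "food_and_beverage", "arts", "abstract", "produce_and_plants", "indoor_scenes"],
          "description": "Coarse content category of the image."
        }
      ]
    }
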
stage_1.py ADDED
@@ -0,0 +1,199 @@
+ #!/usr/bin/env python3
+ """
+ Stage 1: Data Loading and Image Downloading
+ Downloads and preprocesses the top N images from a parquet file
+ """
+
+ import os
+ import json
+ import requests
+ import pandas as pd
+ from PIL import Image
+ from io import BytesIO
+ import concurrent.futures
+ from pathlib import Path
+ import time
+ import logging
+ import numpy as np
+ from typing import Tuple
+
+ # Set up logging
+ logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
+ logger = logging.getLogger(__name__)
+
+ def setup_environment():
+     """Setup data directory"""
+     os.makedirs('./data', exist_ok=True)
+     os.makedirs('./data/images', exist_ok=True)
+     os.makedirs('./data/metadata', exist_ok=True)
+     return True
+
+ def load_and_sample_data(parquet_path: str, n_samples: int = 2000) -> pd.DataFrame:
+     """Load parquet file and sample top N rows"""
+     logger.info(f"Loading data from {parquet_path}")
+     df = pd.read_parquet(parquet_path)
+     logger.info(f"Loaded {len(df)} rows, sampling top {n_samples}")
+     return df.head(n_samples)
+
+ def has_white_edges(img: Image.Image, threshold: int = 240) -> bool:
+     """Check if image has 3 or more white edges (mean RGB > threshold)"""
+     try:
+         img_array = np.array(img)
+         height, width = img_array.shape[:2]
+
+         # Define edge thickness (check 5 pixels from each edge)
+         edge_thickness = 5
+
+         # Get edges
+         top_edge = img_array[:edge_thickness, :].mean(axis=(0, 1))
+         bottom_edge = img_array[-edge_thickness:, :].mean(axis=(0, 1))
+         left_edge = img_array[:, :edge_thickness].mean(axis=(0, 1))
+         right_edge = img_array[:, -edge_thickness:].mean(axis=(0, 1))
+
+         # Check if edge is white (all RGB channels > threshold)
+         edges = [top_edge, bottom_edge, left_edge, right_edge]
+         white_edges = sum(1 for edge in edges if np.all(edge > threshold))
+
+         return white_edges >= 3
+     except Exception as e:
+         logger.debug(f"Error checking white edges: {e}")
+         return False
+
+ def download_and_process_image(url: str, target_size: int = 256) -> Image.Image:
+     """Download image and resize with center crop, skip if has white edges"""
+     try:
+         response = requests.get(url, timeout=10, headers={'User-Agent': 'Mozilla/5.0'})
+         response.raise_for_status()
+
+         img = Image.open(BytesIO(response.content)).convert('RGB')
+
+         # Check for white edges before processing
+         if has_white_edges(img):
+             logger.debug(f"Skipping image with white edges: {url}")
+             return None
+
+         # Resize and center crop to target_size x target_size
+         width, height = img.size
+         min_side = min(width, height)
+         scale = target_size / min_side
+
+         new_width = int(width * scale)
+         new_height = int(height * scale)
+
+         img = img.resize((new_width, new_height), Image.Resampling.LANCZOS)
+
+         # Center crop
+         left = (new_width - target_size) // 2
+         top = (new_height - target_size) // 2
+         right = left + target_size
+         bottom = top + target_size
+
+         img = img.crop((left, top, right, bottom))
+
+         # Double-check after processing
+         if has_white_edges(img):
+             logger.debug(f"Skipping processed image with white edges: {url}")
+             return None
+
+         return img
+     except Exception as e:
+         logger.error(f"Error downloading {url}: {e}")
+         return None
+
+ def process_single_image(args: Tuple[int, str, str, str]) -> bool:
+     """Download and save a single image"""
+     idx, url, hash_val, caption = args
+
+     try:
+         # Download and process image
+         image = download_and_process_image(url)
+         if image is None:
+             logger.debug(f"Skipped image {idx} (white edges or download error)")
+             return False
+
+         # Save image
+         image_path = f'./data/images/img_{idx}.png'
+         image.save(image_path)
+
+         # Save metadata for next stage
+         metadata = {
+             "idx": idx,
+             "caption": caption,
+             "url": url,
+             "hash": hash_val,
+             "image_path": image_path
+         }
+
+         metadata_path = f'./data/metadata/meta_{idx}.json'
+         with open(metadata_path, 'w') as f:
+             json.dump(metadata, f, indent=2)
+
+         logger.info(f"Downloaded and saved image {idx}")
+         return True
+
+     except Exception as e:
+         logger.error(f"Error processing image {idx}: {e}")
+         return False
+
+ def download_images(df: pd.DataFrame, max_workers: int = 20):
+     """Download all images with parallel processing"""
+     logger.info(f"Starting image download with {max_workers} workers...")
+
+     args_list = [(i, row['url'], row['hash'], row['caption'])
+                  for i, (_, row) in enumerate(df.iterrows())]
+
+     successful = 0
+     skipped_white_edges = 0
+
+     with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
+         futures = [executor.submit(process_single_image, args) for args in args_list]
+
+         for i, future in enumerate(concurrent.futures.as_completed(futures)):
+             if future.result():
+                 successful += 1
+             else:
+                 skipped_white_edges += 1
+
+             # Progress logging every 100 images
+             if (i + 1) % 100 == 0:
+                 logger.info(f"Processed {i + 1}/{len(args_list)} images (successful: {successful}, skipped: {skipped_white_edges})")
+
+             # Minimal rate limiting for high concurrency
+             time.sleep(0.01)
+
+     logger.info(f"Download complete: {successful}/{len(args_list)} images downloaded, {skipped_white_edges} skipped (white edges or errors)")
+
+     # Save summary
+     summary = {
+         "total_images": len(args_list),
+         "successful_downloads": successful,
+         "skipped_white_edges": skipped_white_edges,
+         "download_rate": f"{successful/len(args_list)*100:.1f}%",
+         "stage": "download_complete"
+     }
+
+     with open('./data/stage1_summary.json', 'w') as f:
+         json.dump(summary, f, indent=2)
+
+ def main():
+     """Main execution for Stage 1"""
+     logger.info("Starting Stage 1: Data Loading and Image Downloading...")
+
+     # Setup
+     setup_environment()
+
+     # Load data
+     parquet_path = '/home/fal/partiprompt_clip/curated_part_00000.parquet'
+     df = load_and_sample_data(parquet_path, n_samples=5000)
+
+     # Save the dataframe for other stages
+     df.to_pickle('./data/sampled_data.pkl')
+
+     # Download images with optimized settings
+     download_images(df, max_workers=30)
+
+     logger.info("Stage 1 completed successfully!")
+
+ if __name__ == "__main__":
+     main()
stage_2.py ADDED
@@ -0,0 +1,292 @@
+ #!/usr/bin/env python3
+ """
+ Stage 2: Gemini Vision Classification
+ Classifies images using Google Gemini with 5 classification tasks
+ """
+
+ import os
+ import json
+ from PIL import Image
+ from io import BytesIO
+ import concurrent.futures
+ from pathlib import Path
+ import time
+ import logging
+ from typing import Dict, Any
+ import mimetypes
+ import random
+
+ # Gemini SDK
+ from google import genai
+ from google.genai.errors import ServerError
+ from google.genai.types import (
+     Blob, Part, Content, GenerateContentConfig,
+ )
+
+ # Set up logging
+ logging.basicConfig(level=logging.DEBUG, format='%(asctime)s - %(levelname)s - %(message)s')
+ logger = logging.getLogger(__name__)
+
+ GEMINI_API_KEY_FALLBACK = "<REDACTED_GEMINI_API_KEY>"
+
+
+ def check_api_key():
+     """Ensure Google API key is set for Gemini client."""
+     if not os.getenv('GOOGLE_API_KEY'):
+         # Use provided key as fallback if env not set
+         os.environ['GOOGLE_API_KEY'] = GEMINI_API_KEY_FALLBACK
+     return True
+
+ def _guess_mime_type(image_path: str) -> str:
+     guessed, _ = mimetypes.guess_type(image_path)
+     if guessed:
+         return guessed
+     try:
+         with Image.open(image_path) as im:
+             fmt = (im.format or '').lower()
+             if fmt in ('jpeg', 'jpg'):
+                 return 'image/jpeg'
+             if fmt == 'png':
+                 return 'image/png'
+             if fmt == 'webp':
+                 return 'image/webp'
+             if fmt == 'gif':
+                 return 'image/gif'
+     except Exception:
+         pass
+     return 'application/octet-stream'
+
+ def _gemini_call_with_retry(contents, cfg, max_attempts: int = 5):
+     """Call Gemini with retries on server/errors."""
+     api_key = os.getenv('GOOGLE_API_KEY') or GEMINI_API_KEY_FALLBACK
+     for attempt in range(max_attempts):
+         try:
+             client = genai.Client(api_key=api_key)
+             return client.models.generate_content(
+                 model="models/gemini-2.5-flash",
+                 contents=contents,
+                 config=cfg,
+             )
+         except ServerError as e:
+             sleep_s = (2 ** attempt) + random.random()
+             logger.warning(f"Gemini server error attempt {attempt+1}/{max_attempts}: {e}; retrying in {sleep_s:.1f}s")
+             time.sleep(sleep_s)
+         except Exception as e:
+             sleep_s = (2 ** attempt) + random.random()
+             logger.warning(f"Gemini error attempt {attempt+1}/{max_attempts}: {e}; retrying in {sleep_s:.1f}s")
+             time.sleep(sleep_s)
+     raise RuntimeError(f"Persistent Gemini error after {max_attempts} tries")
+
+
+ def classify_image_with_gemini(image_path: str, caption: str, max_retries: int = 3) -> Dict[str, Any]:
+     """Use Google Gemini to classify an image with structured JSON output."""
+     prompt = f"""
+     Analyze this image with caption: "{caption}"
+
+     Please answer the following classification questions and respond ONLY with valid JSON:
+
+     1. Overall Description.
+     2. Is the image product display / low quality advertisement / e-commerce product? Answer: "yes" or "no"
+     3. Is the image computer screenshot with many text overlays? Answer: "yes" or "no"
+     4. In what category is the image? Choose one from: "animals", "artifacts", "people", "outdoor_scenes", "illustrations", "vehicles", "food_and_beverage", "arts", "abstract", "produce_and_plants", "indoor_scenes"
+     5. Would you say the image is interesting? Answer: "yes" or "no"
+     6. Do you think the photo/image was made by a professional photographer? Answer: "yes" or "no"
+
+     IMPORTANT: Respond with ONLY a valid JSON object with these exact keys. Do not include any other text or explanation:
+
+     {{
+         "overall_description": "...",
+         "is_product_advertisement": "yes",
+         "is_screenshot_with_text": "no",
+         "category": "animals",
+         "is_interesting": "no",
+         "is_professional": "yes"
+     }}
+     """
+
+     default_response = {
+         "overall_description": "...",
+         "is_product_advertisement": "...",
+         "is_screenshot_with_text": "...",
+         "category": "...",
+         "is_interesting": "...",
+         "is_professional": "..."
+     }
+
+     try:
+         with open(image_path, 'rb') as f:
+             image_bytes = f.read()
+         mime_type = _guess_mime_type(image_path)
+
+         image_blob = Blob(mime_type=mime_type, data=image_bytes)
+         user_content = Content(
+             role="user",
+             parts=[
+                 Part(text=prompt),
+                 Part(inline_data=image_blob),
+             ],
+         )
+         contents = [user_content]
+         cfg = GenerateContentConfig(max_output_tokens=2500, temperature=0)
+
+         resp = _gemini_call_with_retry(contents, cfg, max_attempts=max_retries)
+         logger.debug(f"Gemini response type: {type(resp)}")
+
+         # Detailed debugging of response structure
+         logger.debug(f"Response.text: {getattr(resp, 'text', 'NO_TEXT_ATTR')}")
+         logger.debug(f"Response.candidates: {getattr(resp, 'candidates', 'NO_CANDIDATES_ATTR')}")
+         if hasattr(resp, 'candidates') and resp.candidates:
+             logger.debug(f"Number of candidates: {len(resp.candidates)}")
+             for i, candidate in enumerate(resp.candidates):
+                 logger.debug(f"Candidate {i}: {candidate}")
+                 if hasattr(candidate, 'content'):
+                     logger.debug(f"Candidate {i} content: {candidate.content}")
+                     if hasattr(candidate.content, 'parts'):
+                         logger.debug(f"Candidate {i} parts: {candidate.content.parts}")
+
+         # Check for prompt_feedback which might indicate filtering
+         if hasattr(resp, 'prompt_feedback'):
+             logger.debug(f"Prompt feedback: {resp.prompt_feedback}")
+
+         # Extract text from Gemini response
+         content_text = ""
+         try:
+             # Try the .text property first
+             if hasattr(resp, 'text') and resp.text:
+                 content_text = resp.text
+                 logger.debug(f"Got text from .text property: {content_text[:100]}...")
+             else:
+                 # Fallback: extract from candidates
+                 if resp.candidates and len(resp.candidates) > 0:
+                     candidate = resp.candidates[0]
+                     if hasattr(candidate, 'content') and candidate.content:
+                         if hasattr(candidate.content, 'parts') and candidate.content.parts:
+                             for part in candidate.content.parts:
+                                 if hasattr(part, 'text') and part.text:
+                                     content_text += part.text
+                                     logger.debug(f"Got text from candidate part: {part.text[:100]}...")
+         except Exception as e:
+             logger.error(f"Error extracting text from Gemini response: {e}")
+             raise e
+
+         if not content_text:
+             logger.error("Empty response from Gemini")
+             return default_response
+
+         content_text = content_text.strip()
+         start_idx = content_text.find('{')
+         end_idx = content_text.rfind('}') + 1
+         if start_idx == -1 or end_idx == 0:
+             logger.error(f"No JSON found in response: {content_text}")
+             return default_response
+
+         json_content = content_text[start_idx:end_idx]
+         classification = json.loads(json_content)
+
+         required_keys = [
+             "overall_description",
+             "is_product_advertisement",
+             "is_screenshot_with_text",
+             "category",
+             "is_interesting",
+             "is_professional",
+         ]
+         missing_keys = [key for key in required_keys if key not in classification]
+         if missing_keys:
+             logger.warning(f"Missing keys in classification: {missing_keys}")
+             for key in missing_keys:
+                 classification[key] = default_response[key]
+
+         return classification
+     except json.JSONDecodeError as e:
+         logger.error(f"JSON parsing error: {e}")
+         return default_response
+     except Exception as e:
+         logger.error(f"Gemini classification error: {e}")
+         return default_response
+
+ def classify_single_image(metadata_file: Path) -> bool:
+     """Classify a single image and save results"""
+     try:
+         # Load metadata
+         with open(metadata_file, 'r', encoding='utf-8') as f:
+             metadata = json.load(f)
+
+         idx = metadata['idx']
+         image_path = metadata['image_path']
+         caption = metadata['caption']
+
+         # Check if image exists
+         if not os.path.exists(image_path):
+             logger.error(f"Image not found: {image_path}")
+             return False
+
+         # Classify with Gemini
+         classification = classify_image_with_gemini(image_path, caption)
+
+         # Add classification to metadata
+         metadata['classification'] = classification
+         metadata['stage2_complete'] = True
+
+         # Save updated metadata
+         new_metadata_file = metadata_file.with_name(f'meta_{idx}_stage2.json')
+         with open(new_metadata_file, 'w', encoding='utf-8') as f:
+             json.dump(metadata, f, indent=2, ensure_ascii=False)
+
+         logger.info(f"Classified image {idx}")
+         return True
+
+     except Exception as e:
+         logger.error(f"Error classifying {metadata_file}: {e}")
+         return False
+
+ def classify_all_images(max_workers: int = 2):
+     """Classify all downloaded images with parallel processing"""
+     logger.info("Starting image classification...")
+
+     # Get all metadata files
+     metadata_dir = Path('./data/metadata')
+     metadata_files = list(metadata_dir.glob('meta_*.json'))
+
+     if not metadata_files:
+         logger.error("No metadata files found. Run stage 1 first.")
+         return
+
+     successful = 0
+     with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
+         futures = [executor.submit(classify_single_image, meta_file) for meta_file in metadata_files]
+
+         for future in concurrent.futures.as_completed(futures):
+             if future.result():
+                 successful += 1
+
+             # Rate limiting for API calls
+             time.sleep(1.0)  # 1 second between API calls to avoid rate limits
+
+     logger.info(f"Successfully classified {successful}/{len(metadata_files)} images")
+
+     # Save summary
+     summary = {
+         "total_images": len(metadata_files),
+         "successful_classifications": successful,
+         "stage": "classification_complete"
+     }
+
+     with open('./data/stage2_summary.json', 'w') as f:
+         json.dump(summary, f, indent=2)
+
+ def main():
+     """Main execution for Stage 2"""
+     logger.info("Starting Stage 2: Gemini Vision Classification...")
+
+     # Check API key
+     if not check_api_key():
+         return
+
+     # Classify images
+     classify_all_images(max_workers=64)
+
+     logger.info("Stage 2 completed successfully!")
+
+ if __name__ == "__main__":
+     main()
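
For orientation, each stage-2 output file (./data/metadata/meta_{idx}_stage2.json) ends up holding the stage-1 metadata plus the Gemini labels and a completion flag; this is what stage_4.py later consumes. The illustrative example below uses made-up values; only the keys follow the code above.

    {
      "idx": 42,
      "caption": "a cat sitting on a windowsill",
      "url": "https://example.com/cat.jpg",
      "hash": "abc123",
      "image_path": "./data/images/img_42.png",
      "classification": {
        "overall_description": "A cat sits on a sunny windowsill.",
        "is_product_advertisement": "no",
        "is_screenshot_with_text": "no",
        "category": "animals",
        "is_interesting": "yes",
        "is_professional": "no"
      },
      "stage2_complete": true
    }
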
stage_4.py ADDED
@@ -0,0 +1,506 @@
+ #!/usr/bin/env python3
+ """
+ Stage 4: SigLIP v2 Multi-Head Classifier Training
+ Trains a SigLIP v2-based multi-head classifier on pseudo-labeled data
+ """
+
+ import os
+ import json
+ import torch
+ import torch.nn as nn
+ import torch.optim as optim
+ from torch.utils.data import Dataset, DataLoader
+ from transformers import SiglipModel, AutoProcessor
+ import numpy as np
+ from PIL import Image
+ from pathlib import Path
+ import logging
+ from typing import Dict, List, Any
+ import pickle
+ import matplotlib.pyplot as plt
+ from torch.optim.lr_scheduler import LambdaLR
+
+ # Set up logging
+ logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
+ logger = logging.getLogger(__name__)
+
+ CKPT = "google/siglip-base-patch16-256"
+
+ def load_task_config(config_path: str = './task_config.json'):
+     """Load task configuration from JSON file"""
+     if not os.path.exists(config_path):
+         raise FileNotFoundError(f"Task configuration not found: {config_path}")
+
+     with open(config_path, 'r') as f:
+         config = json.load(f)
+
+     logger.info(f"Loaded task configuration with {len(config['tasks'])} tasks")
+     return config
+
+ class MultiHeadDataset(Dataset):
+     """Dataset for multi-head classification with configurable tasks"""
+     def __init__(self, data_dir: str, processor, task_config: Dict):
+         self.data_dir = Path(data_dir)
+         self.processor = processor
+         self.task_config = task_config
+
+         # Load all metadata files from stage 2 (with _stage2 suffix)
+         metadata_dir = self.data_dir / 'metadata'
+         if not metadata_dir.exists():
+             raise FileNotFoundError("Metadata directory not found. Run stages 1 and 2 first.")
+
+         metadata_files = list(metadata_dir.glob('meta_*_stage2.json'))
+         if not metadata_files:
+             raise FileNotFoundError("No stage 2 metadata files found. Run stage 2 first.")
+
+         # Load all samples
+         self.samples = []
+         skipped_incomplete = 0
+
+         for meta_file in metadata_files:
+             try:
+                 with open(meta_file, 'r') as f:
+                     metadata = json.load(f)
+
+                 # Check if classification is complete
+                 if not metadata.get('stage2_complete', False):
+                     logger.warning(f"Skipping {meta_file} - classification not complete")
+                     skipped_incomplete += 1
+                     continue
+
+                 # Check if classification contains incomplete data (empty or "..." values)
+                 classification = metadata.get('classification', {})
+                 if not classification or self._is_incomplete_classification(classification):
+                     logger.warning(f"Skipping {meta_file} - incomplete classification data")
+                     skipped_incomplete += 1
+                     continue
+
+                 # Check if image exists
+                 image_path = metadata['image_path']
+                 if not os.path.exists(image_path):
+                     logger.warning(f"Image not found: {image_path}")
+                     skipped_incomplete += 1
+                     continue
+
+                 self.samples.append(metadata)
+
+             except Exception as e:
+                 logger.error(f"Error loading {meta_file}: {e}")
+                 skipped_incomplete += 1
+
+         # Create label mappings from task config
+         self.label_mappings = {}
+         for task in self.task_config['tasks']:
+             if task['type'] == 'multi_class':
+                 self.label_mappings[task['key']] = {
+                     label: idx for idx, label in enumerate(task['labels'])
+                 }
+
+         if skipped_incomplete > 0:
+             logger.warning(f"Skipped {skipped_incomplete} incomplete samples")
+         logger.info(f"Loaded {len(self.samples)} valid samples for training")
+
+     def _is_incomplete_classification(self, classification: Dict) -> bool:
+         """Check if classification contains incomplete data (empty or '...' values)"""
+         required_tasks = [task['key'] for task in self.task_config['tasks']]
+
+         for task_key in required_tasks:
+             if task_key not in classification:
+                 return True
+
+             value = classification[task_key]
+             # Check for incomplete markers
+             if not value or value == "..." or value == "" or value is None:
+                 return True
+
+         return False
+
+     def __len__(self):
+         return len(self.samples)
+
+     def __getitem__(self, idx):
+         sample = self.samples[idx]
+
+         # Load image
+         image = Image.open(sample['image_path']).convert('RGB')
+
+         # Process image only
+         inputs = self.processor(
+             images=image,
+             return_tensors="pt"
+         )
+
+         # Convert classifications to labels based on task config
+         classification = sample['classification']
+         labels = {}
+
+         for task in self.task_config['tasks']:
+             task_key = task['key']
+             if task['type'] == 'binary':
+                 # Binary tasks: convert yes/no to 1/0
+                 labels[task_key] = 1 if classification[task_key] == 'yes' else 0
+             elif task['type'] == 'multi_class':
+                 # Multi-class tasks: convert to index
+                 label_str = classification[task_key]
+                 labels[task_key] = self.label_mappings[task_key].get(label_str, 0)  # default to first class
+
+         return {
+             'pixel_values': inputs['pixel_values'].squeeze(0),
+             'labels': labels,
+             'metadata': {
+                 'idx': sample['idx'],
+                 'caption': sample['caption'],
+                 'image_path': sample['image_path']
+             }
+         }
+
+ class MultiHeadSiglipClassifier(nn.Module):
+     """SigLIP-based multi-head classifier with configurable tasks"""
+     def __init__(self, task_config: Dict, model_name: str = CKPT):
+         super().__init__()
+
+         self.task_config = task_config
+         self.siglip = SiglipModel.from_pretrained(model_name)
+
+         # Freeze SigLIP parameters initially
+         for param in self.siglip.parameters():
+             param.requires_grad = False
+
+         # Create classification heads dynamically based on task config
+         hidden_size = self.siglip.config.vision_config.hidden_size
+         self.classification_heads = nn.ModuleDict()
+
+         for task in task_config['tasks']:
+             task_key = task['key']
+             num_classes = len(task['labels'])
+
+             # Create linear layer for this task
+             head = nn.Linear(hidden_size, num_classes)
+
+             # Initialize with zeros
+             head.weight.data.zero_()
+             head.bias.data.zero_()
+
+             self.classification_heads[task_key] = head
+
+         logger.info(f"Created {len(self.classification_heads)} classification heads")
+
+     def forward(self, pixel_values):
+         # Get SigLIP image embeddings only
+         combined_embeds = self.siglip.get_image_features(pixel_values=pixel_values)
+
+         # Apply all classification heads
+         outputs = {}
+         for task_key, head in self.classification_heads.items():
+             outputs[task_key] = head(combined_embeds)
+
+         return outputs
+
+ def calculate_accuracy(predictions, labels):
+     """Calculate accuracy for binary/multi-class predictions"""
+     pred_classes = torch.argmax(predictions, dim=1)
+     correct = (pred_classes == labels).float()
+     return correct.mean().item()
+
+ def plot_validation_accuracies(history, task_config, save_path='./checkpoints/validation_accuracies.png'):
+     """Create and save validation accuracy plots"""
+     tasks = [task['key'] for task in task_config['tasks']]
+     task_names = [task['name'] for task in task_config['tasks']]
+
+     # Calculate grid size
+     n_tasks = len(tasks)
+     n_cols = 3
+     n_rows = (n_tasks + n_cols - 1) // n_cols  # Ceiling division
+
+     fig, axes = plt.subplots(n_rows, n_cols, figsize=(18, 6 * n_rows))
+     fig.suptitle('Training Progress Dashboard', fontsize=16, fontweight='bold')
+
+     # Flatten axes for easier indexing
+     if n_rows == 1:
+         axes = [axes] if n_cols == 1 else axes
+     else:
+         axes = axes.flatten()
+
+     epochs = range(1, len(history['val_accuracy'][tasks[0]]) + 1)
+     colors = plt.cm.Set1(np.linspace(0, 1, n_tasks))
+
+     # Plot individual validation accuracies
+     for i, (task_key, task_name, color) in enumerate(zip(tasks, task_names, colors)):
+         if i < len(axes):
+             axes[i].plot(epochs, history['val_accuracy'][task_key],
+                          label=task_name, marker='o', color=color, linewidth=2, markersize=4)
+             axes[i].set_xlabel('Epoch')
+             axes[i].set_ylabel('Validation Accuracy')
+             axes[i].set_title(f'{task_name} Validation Accuracy')
+             axes[i].grid(True, alpha=0.3)
+             axes[i].set_ylim(0, 1)
+
+     # Hide unused subplots
+     for i in range(n_tasks, len(axes)):
+         axes[i].set_visible(False)
+
+     plt.tight_layout()
+     plt.savefig(save_path, dpi=300, bbox_inches='tight')
+     plt.close()
+
+     logger.info(f"Validation accuracy plots saved to {save_path}")
+
+     # Calculate summary statistics
+     best_accs = [max(history['val_accuracy'][task]) for task in tasks]
+     final_accs = [history['val_accuracy'][task][-1] for task in tasks]
+
+     return best_accs, final_accs
+
+ def train_multi_head_classifier(data_dir: str, task_config_path: str = './task_config.json',
+                                 epochs: int = 30, batch_size: int = 4):
+     """Train the multi-head SigLIP v2 classifier"""
+     logger.info("Starting multi-head classifier training...")
+
+     # Load task configuration
+     task_config = load_task_config(task_config_path)
+
+     # Create checkpoints directory
+     checkpoint_dir = Path('./checkpoints')
+     checkpoint_dir.mkdir(exist_ok=True)
+     logger.info(f"Checkpoints will be saved to: {checkpoint_dir}")
+
+     # Save task config to checkpoints for inference
+     with open(checkpoint_dir / 'task_config.json', 'w') as f:
+         json.dump(task_config, f, indent=2)
+
+     # Load processor and model
+     processor = AutoProcessor.from_pretrained(CKPT)
+     model = MultiHeadSiglipClassifier(task_config, model_name=CKPT)
+
+     # Dataset and dataloader
+     dataset = MultiHeadDataset(data_dir, processor, task_config)
+     if len(dataset) == 0:
+         logger.error("No training data found!")
+         return
+
+     # Split dataset (simple train/val split)
+     train_size = int(0.8 * len(dataset))
+     val_size = len(dataset) - train_size
+     train_dataset, val_dataset = torch.utils.data.random_split(dataset, [train_size, val_size])
+
+     train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
+     val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
+
+     # Setup training
+     device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+     logger.info(f"Using device: {device}")
+     model.to(device)
+
+     # Optimizer and loss functions
+     # Get model parameters that require gradients (only classification heads)
+     params = []
+     for name, param in model.named_parameters():
+         if param.requires_grad:
+             params.append(param)
+
+     optimizer = optim.AdamW(params, lr=1e-2)
+
+     # Linear cooldown LR scheduler
+     def linear_cooldown(epoch):
+         return max(0.1, 1.0 - (epoch / epochs))
+
+     scheduler = LambdaLR(optimizer, lr_lambda=linear_cooldown)
+     criterion = nn.CrossEntropyLoss()
+
+     # Initialize training history
+     history = {
+         'train_loss': [],
+         'val_loss': [],
+         'learning_rates': [],
+         'val_accuracy': {task['key']: [] for task in task_config['tasks']},
+         'epoch_val_accuracy': []
+     }
+
+     # Training loop
+     for epoch in range(epochs):
+         # Training phase
+         model.train()
+         total_train_loss = 0
+
+         for batch_idx, batch in enumerate(train_loader):
+             optimizer.zero_grad()
+
+             # Move to device
+             pixel_values = batch['pixel_values'].to(device)
+
+             # Forward pass
+             outputs = model(pixel_values)
+
+             # Calculate losses for each task
+             losses = []
+             for task in task_config['tasks']:
+                 task_key = task['key']
+                 labels = batch['labels'][task_key].to(device)
+                 loss = criterion(outputs[task_key], labels)
+                 losses.append(loss)
+
+             # Total loss
+             total_batch_loss = sum(losses)
+             total_batch_loss.backward()
+             optimizer.step()
+
+             total_train_loss += total_batch_loss.item()
+
+             if batch_idx % 10 == 0:
+                 logger.info(f"Epoch {epoch+1}/{epochs}, Batch {batch_idx}/{len(train_loader)}, Loss: {total_batch_loss.item():.4f}")
+
+         avg_train_loss = total_train_loss / len(train_loader)
+         history['train_loss'].append(avg_train_loss)
+
+         # Record learning rate
+         current_lr = optimizer.param_groups[0]['lr']
+         history['learning_rates'].append(current_lr)
+
+         # Validation phase
+         model.eval()
+         total_val_loss = 0
+         val_accuracies = {task['key']: [] for task in task_config['tasks']}
+
+         with torch.no_grad():
+             for batch in val_loader:
+                 pixel_values = batch['pixel_values'].to(device)
+
+                 outputs = model(pixel_values)
+
+                 # Calculate validation losses and accuracies
+                 losses = []
+                 for task in task_config['tasks']:
+                     task_key = task['key']
+                     labels = batch['labels'][task_key].to(device)
+                     loss = criterion(outputs[task_key], labels)
+                     losses.append(loss)
+
+                     # Calculate accuracy
+                     acc = calculate_accuracy(outputs[task_key], labels)
+                     val_accuracies[task_key].append(acc)
+
+                 total_val_loss += sum(losses).item()
+
+         avg_val_loss = total_val_loss / len(val_loader)
+         history['val_loss'].append(avg_val_loss)
+
+         # Calculate average accuracies
+         epoch_accuracies = {}
+         for task in task_config['tasks']:
+             task_key = task['key']
+             avg_acc = np.mean(val_accuracies[task_key])
+             epoch_accuracies[task_key] = avg_acc
+             history['val_accuracy'][task_key].append(avg_acc)
+
+         history['epoch_val_accuracy'].append(epoch_accuracies.copy())
+
+         logger.info(f"Epoch {epoch+1}/{epochs}")
+         logger.info(f"  Train Loss: {avg_train_loss:.4f}, Val Loss: {avg_val_loss:.4f}")
+         logger.info(f"  Learning Rate: {current_lr:.6f}")
+         logger.info(f"  Val Accuracies: {epoch_accuracies}")
+
+         # Step the learning rate scheduler
+         scheduler.step()
+
+     # Create comprehensive checkpoint
+     checkpoint = {
+         'epoch': epochs,
+         'model_state_dict': model.state_dict(),
+         'optimizer_state_dict': optimizer.state_dict(),
+         'scheduler_state_dict': scheduler.state_dict(),
+         'history': history,
+         'final_accuracies': epoch_accuracies,
+         'task_config': task_config
+     }
+
+     # Save the trained model and checkpoint
+     torch.save(model.state_dict(), checkpoint_dir / 'multi_head_siglip2_classifier.pth')
+     torch.save(checkpoint, checkpoint_dir / 'training_checkpoint.pth')
+     logger.info(f"Model saved to {checkpoint_dir / 'multi_head_siglip2_classifier.pth'}")
+     logger.info(f"Full checkpoint saved to {checkpoint_dir / 'training_checkpoint.pth'}")
+
+     # Save processor for inference
+     processor.save_pretrained(checkpoint_dir / 'siglip2_processor')
+     logger.info(f"Processor saved to {checkpoint_dir / 'siglip2_processor'}")
+
+     # Save training history as JSON
+     with open(checkpoint_dir / 'training_history.json', 'w') as f:
+         json_history = {}
+         for key, value in history.items():
+             if key == 'val_accuracy':
+                 json_history[key] = {task: [float(acc) for acc in accs] for task, accs in value.items()}
+             elif key == 'epoch_val_accuracy':
+                 json_history[key] = [{task: float(acc) for task, acc in epoch.items()} for epoch in value]
+             else:
+                 json_history[key] = [float(x) for x in value]
+         json.dump(json_history, f, indent=2)
+     logger.info(f"Training history saved to {checkpoint_dir / 'training_history.json'}")
+
+     # Generate and save validation accuracy plots
+     best_accs, final_accs = plot_validation_accuracies(history, task_config, checkpoint_dir / 'validation_accuracies.png')
+
+     # Save detailed validation accuracy summary
+     val_summary = {
+         'best_accuracies': {
+             task['key']: float(max(history['val_accuracy'][task['key']]))
+             for task in task_config['tasks']
+         },
+         'final_accuracies': {task: float(acc) for task, acc in epoch_accuracies.items()},
+         'average_best_accuracy': float(np.mean(best_accs)),
+         'average_final_accuracy': float(np.mean(final_accs)),
+         'improvement_per_task': {
+             task['key']: float(history['val_accuracy'][task['key']][-1] - history['val_accuracy'][task['key']][0])
+             for task in task_config['tasks']
+         }
+     }
+
+     with open(checkpoint_dir / 'validation_summary.json', 'w') as f:
+         json.dump(val_summary, f, indent=2)
+     logger.info(f"Validation summary saved to {checkpoint_dir / 'validation_summary.json'}")
+
+     # Save final training summary
+     final_summary = {
+         "model_type": "SigLIP2 Multi-Head Classifier",
+         "training_samples": len(train_dataset),
+         "validation_samples": len(val_dataset),
+         "epochs": epochs,
+         "final_train_loss": avg_train_loss,
+         "final_val_loss": avg_val_loss,
+         "final_accuracies": epoch_accuracies,
+         "task_config": task_config,
+         "classification_heads": {
+             task['key']: f"{task['type']} - {task['description']}"
+             for task in task_config['tasks']
+         }
+     }
+
+     with open(checkpoint_dir / 'stage4_summary.json', 'w') as f:
+         json.dump(final_summary, f, indent=2)
+     logger.info(f"Stage 4 summary saved to {checkpoint_dir / 'stage4_summary.json'}")
+
+     # Log summary of saved artifacts
+     logger.info("="*60)
+     logger.info("TRAINING COMPLETE - ARTIFACTS SAVED:")
+     logger.info(f"📁 Checkpoint Directory: {checkpoint_dir}")
+     logger.info(f"🤖 Model Weights: multi_head_siglip2_classifier.pth")
+     logger.info(f"💾 Full Checkpoint: training_checkpoint.pth")
+     logger.info(f"🔧 Processor: siglip2_processor/")
+     logger.info(f"⚙️ Task Config: task_config.json")
+     logger.info(f"📊 Training History: training_history.json")
+     logger.info(f"📈 Validation Plots: validation_accuracies.png")
+     logger.info(f"📋 Validation Summary: validation_summary.json")
+     logger.info(f"📄 Stage Summary: stage4_summary.json")
+     logger.info("="*60)
+
+ def main():
+     """Main execution for Stage 4"""
+     logger.info("Starting Stage 4: SigLIP v2 Multi-Head Training...")
+
+     # Train classifier
+     train_multi_head_classifier('./data', epochs=10, batch_size=2)
+
+     logger.info("Stage 4 completed successfully!")
+     logger.info("🎉 Complete pipeline finished! Check ./checkpoints/ for all training artifacts.")
+
+ if __name__ == "__main__":
+     main()
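
As a quick sanity check after training, the saved weights can be reloaded with the same class definition; a minimal sketch, assuming stage_4.py is importable from the working directory and the default ./checkpoints/ artifacts produced above:

    # Minimal reload sketch (assumes the artifacts written by stage_4.py)
    import json
    import torch
    from transformers import AutoProcessor
    from stage_4 import MultiHeadSiglipClassifier, CKPT

    with open("./checkpoints/task_config.json") as f:
        task_config = json.load(f)

    model = MultiHeadSiglipClassifier(task_config, model_name=CKPT)
    state_dict = torch.load("./checkpoints/multi_head_siglip2_classifier.pth", map_location="cpu")
    model.load_state_dict(state_dict)
    model.eval()

    # The processor was saved alongside the weights and can be reloaded from disk
    processor = AutoProcessor.from_pretrained("./checkpoints/siglip2_processor")
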
streamlit_evaluation_app.py ADDED
@@ -0,0 +1,695 @@
1
+ #!/usr/bin/env python3
2
+ """
3
+ Streamlit Data Viewer and Model Evaluation System
4
+ Interactive dashboard for exploring validation results with threshold filtering
5
+ """
6
+
7
+ import streamlit as st
8
+ import json
9
+ import pandas as pd
10
+ import numpy as np
11
+ from PIL import Image
12
+ import plotly.express as px
13
+ import plotly.graph_objects as go
14
+ from plotly.subplots import make_subplots
15
+ import os
16
+ from pathlib import Path
17
+ import subprocess
18
+ import sys
19
+ from rapidocr import RapidOCR
20
+ from matplotlib import pyplot as plt
21
+
22
+ # Page config
23
+ st.set_page_config(
24
+ page_title="Pseudoable Classifier Evaluation Dashboard",
25
+ page_icon="🔍",
26
+ layout="wide",
27
+ initial_sidebar_state="expanded"
28
+ )
29
+
30
+ # Custom CSS for better styling
31
+ st.markdown("""
32
+ <style>
33
+ .main-header {
34
+ font-size: 3rem;
35
+ font-weight: bold;
36
+ color: #1f77b4;
37
+ text-align: center;
38
+ margin-bottom: 2rem;
39
+ padding: 1rem;
40
+ background: linear-gradient(90deg, #f0f8ff, #e6f3ff);
41
+ border-radius: 10px;
42
+ }
43
+ .metric-card {
44
+ background-color: #f8f9fa;
45
+ padding: 1rem;
46
+ border-radius: 8px;
47
+ border-left: 4px solid #1f77b4;
48
+ margin: 0.5rem 0;
49
+ }
50
+ .filter-section {
51
+ background-color: #ffffff;
52
+ padding: 1.5rem;
53
+ border-radius: 10px;
54
+ box-shadow: 0 2px 4px rgba(0,0,0,0.1);
55
+ margin-bottom: 2rem;
56
+ }
57
+ .image-container {
58
+ border: 2px solid #e6e6e6;
59
+ border-radius: 8px;
60
+ padding: 10px;
61
+ margin: 10px 0;
62
+ background-color: #fafafa;
63
+ }
64
+ .prediction-badge {
65
+ display: inline-block;
66
+ padding: 0.25rem 0.5rem;
67
+ border-radius: 15px;
68
+ font-size: 0.8rem;
69
+ font-weight: bold;
70
+ margin: 0.2rem;
71
+ }
72
+ .correct-prediction {
73
+ background-color: #d4edda;
74
+ color: #155724;
75
+ }
76
+ .incorrect-prediction {
77
+ background-color: #f8d7da;
78
+ color: #721c24;
79
+ }
80
+ </style>
81
+ """, unsafe_allow_html=True)
82
+
83
+ @st.cache_data
84
+ def load_task_config(config_path: str = './task_config.json'):
85
+ """Load task configuration from JSON file"""
86
+ if not os.path.exists(config_path):
87
+ # Try to load from checkpoints directory
88
+ checkpoint_config = './checkpoints/task_config.json'
89
+ if os.path.exists(checkpoint_config):
90
+ config_path = checkpoint_config
91
+ else:
92
+ return None
93
+
94
+ with open(config_path, 'r') as f:
95
+ config = json.load(f)
96
+ return config
97
+
98
+ @st.cache_data
99
+ def load_validation_results(file_path: str = './validation_results.json'):
100
+ """Load validation results from JSON file"""
101
+ if not os.path.exists(file_path):
102
+ return None
103
+
104
+ with open(file_path, 'r') as f:
105
+ data = json.load(f)
106
+ return data
107
+
108
+ @st.cache_resource
109
+ def get_ocr_engine():
110
+ """Initialize and cache OCR engine"""
111
+ return RapidOCR()
112
+
113
+ @st.cache_data
114
+ def extract_text_from_image(image_path: str):
115
+ """Extract text from image using OCR"""
116
+ try:
117
+ engine = get_ocr_engine()
118
+ result = engine(image_path)
119
+
120
+ # Handle new RapidOCR output format
121
+ if result and hasattr(result, 'txts') and result.txts:
122
+ texts = result.txts
123
+ return {
124
+ 'text': ' '.join(texts) if texts else '',
125
+ 'num_text_blocks': len(texts),
126
+ 'has_text': len(texts) > 0
127
+ }
128
+ elif result and isinstance(result, (list, tuple)) and len(result) > 0:
129
+ # Fallback for older format
130
+ texts = []
131
+ for item in result:
132
+ if len(item) >= 2:
133
+ texts.append(item[1])
134
+
135
+ return {
136
+ 'text': ' '.join(texts) if texts else '',
137
+ 'num_text_blocks': len(texts),
138
+ 'has_text': len(texts) > 0
139
+ }
140
+ else:
141
+ return {
142
+ 'text': '',
143
+ 'num_text_blocks': 0,
144
+ 'has_text': False
145
+ }
146
+ except Exception as e:
147
+ return {
148
+ 'text': f'OCR Error: {str(e)}',
149
+ 'num_text_blocks': 0,
150
+ 'has_text': False
151
+ }
152
+
153
+ @st.cache_data
154
+ def process_validation_data(validation_data, task_config):
155
+ """Process validation data into DataFrame for easier filtering"""
156
+ if not validation_data or not task_config:
157
+ return None
158
+
159
+ rows = []
160
+ tasks = {task['key']: task for task in task_config['tasks']}
161
+
162
+ for result in validation_data['results']:
163
+ row = {
164
+ 'idx': result['idx'],
165
+ 'caption': result['caption'],
166
+ 'image_path': result['image_path'],
167
+ 'url': result['url'],
168
+ 'hash': result['hash']
169
+ }
170
+
171
+ # Ground truth and predictions
172
+ gt = result['ground_truth']
173
+ pred = result['predictions']
174
+
175
+ # Process each task dynamically
176
+ for task_key, task_info in tasks.items():
177
+ # Ground truth
178
+ row[f'gt_{task_key}'] = gt[task_key]
179
+
180
+ # Predictions
181
+ row[f'pred_{task_key}'] = pred[f'{task_key}_prediction']
182
+ row[f'conf_{task_key}'] = pred[f'{task_key}_confidence']
183
+
184
+ # For binary tasks, also include probability for 'yes'
185
+ if task_info['type'] == 'binary':
186
+ row[f'prob_{task_key}_yes'] = pred.get(f'{task_key}_prob_yes', 0.5)
187
+
188
+ # Correctness
189
+ row[f'correct_{task_key}'] = gt[task_key] == pred[f'{task_key}_prediction']
190
+
191
+ rows.append(row)
192
+
193
+ return pd.DataFrame(rows)
194
+
195
+ def run_validation_if_needed():
196
+ """Run validation if results don't exist"""
197
+ if not os.path.exists('./validation_results.json'):
198
+ st.warning("Validation results not found. Running validation...")
199
+
200
+ # Check if model exists
201
+ if not os.path.exists('./checkpoints/multi_head_siglip2_classifier.pth'):
202
+ st.error("❌ Trained model not found! Please run the training pipeline first.")
203
+ st.code("python stage_4.py")
204
+ return False
205
+
206
+ # Run validation
207
+ with st.spinner("Running model on validation set... This may take a few minutes."):
208
+ try:
209
+ result = subprocess.run([sys.executable, 'validation_runner.py'],
210
+ capture_output=True, text=True)
211
+ if result.returncode == 0:
212
+ st.success("✅ Validation completed successfully!")
213
+ st.rerun()
214
+ else:
215
+ st.error(f"❌ Validation failed: {result.stderr}")
216
+ return False
217
+ except Exception as e:
218
+ st.error(f"❌ Error running validation: {e}")
219
+ return False
220
+
221
+ return True
222
+
223
+ def create_overview_metrics(df, validation_data, task_config):
224
+ """Create overview metrics section"""
225
+ st.markdown("## 📊 Overview Metrics")
226
+
227
+ tasks = [task['key'] for task in task_config['tasks']]
228
+
229
+ # Basic stats
230
+ col1, col2, col3, col4 = st.columns(4)
231
+
232
+ with col1:
233
+ st.metric("Total Samples", len(df))
234
+
235
+ with col2:
236
+ avg_confidence = np.mean([df[f'conf_{task}'].mean() for task in tasks])
237
+ st.metric("Avg Confidence", f"{avg_confidence:.3f}")
238
+
239
+ with col3:
240
+ overall_accuracy = np.mean([df[f'correct_{task}'].mean() for task in tasks])
241
+ st.metric("Overall Accuracy", f"{overall_accuracy:.3f}")
242
+
243
+ with col4:
244
+ if validation_data and 'metadata' in validation_data:
245
+ model_accuracies = validation_data['metadata']['validation_accuracies']
246
+ model_avg = np.mean(list(model_accuracies.values()))
247
+ st.metric("Model Accuracy", f"{model_avg:.3f}")
248
+
249
+ # Detailed accuracies
250
+ st.markdown("### 🎯 Accuracy per Classification Task")
251
+
252
+ # Create dynamic columns based on number of tasks
253
+ n_tasks = len(tasks)
254
+ n_cols = min(5, n_tasks) # Max 5 columns
255
+ acc_cols = st.columns(n_cols)
256
+
257
+ for i, task in enumerate(tasks):
258
+ task_info = next(t for t in task_config['tasks'] if t['key'] == task)
259
+ with acc_cols[i % n_cols]:
260
+ accuracy = df[f'correct_{task}'].mean()
261
+ st.metric(task_info['name'], f"{accuracy:.3f}")
262
+
263
+ def create_confidence_distribution_plot(df, task_config):
264
+ """Create confidence distribution plots"""
265
+ tasks = [task['key'] for task in task_config['tasks']]
266
+ task_names = [task['name'] for task in task_config['tasks']]
267
+
268
+ n_tasks = len(tasks)
269
+ n_cols = 3
270
+ n_rows = (n_tasks + n_cols - 1) // n_cols
271
+
272
+ fig = make_subplots(
273
+ rows=n_rows, cols=n_cols,
274
+ subplot_titles=task_names,
275
+ specs=[[{"secondary_y": False} for _ in range(n_cols)] for _ in range(n_rows)]
276
+ )
277
+
278
+ colors = plt.cm.Set1(np.linspace(0, 1, n_tasks))
279
+
280
+ for i, (task_key, color) in enumerate(zip(tasks, colors)):
281
+ row = (i // n_cols) + 1
282
+ col = (i % n_cols) + 1
283
+
284
+ fig.add_trace(
285
+ go.Histogram(
286
+ x=df[f'conf_{task_key}'],
287
+ nbinsx=20,
288
+ name=f'{task_key}',
289
+ marker_color=f'rgba({color[0]*255:.0f},{color[1]*255:.0f},{color[2]*255:.0f},0.7)',
290
+ opacity=0.7
291
+ ),
292
+ row=row, col=col
293
+ )
294
+
295
+ fig.update_layout(
296
+ title="Confidence Score Distributions",
297
+ showlegend=False,
298
+ height=200 * n_rows + 100
299
+ )
300
+
301
+ return fig
302
+
303
+ def apply_filters(df, task_config):
304
+ """Apply user-defined filters to the dataframe"""
305
+ st.markdown("## 🔍 Filter Data")
306
+
307
+ tasks = {task['key']: task for task in task_config['tasks']}
308
+
309
+ # Create filter sidebar
310
+ with st.sidebar:
311
+ st.markdown("### Task Confidence Filters")
312
+
313
+ # Confidence thresholds for each task
314
+ confidence_filters = {}
315
+ for task_key, task_info in tasks.items():
316
+ if task_info['type'] == 'multi_class':
317
+ # Only show confidence threshold for multi-class tasks
318
+ confidence_filters[task_key] = st.slider(
319
+ f"{task_info['name']} Confidence",
320
+ 0.0, 1.0, 0.5, 0.01,
321
+ key=f"conf_{task_key}"
322
+ )
323
+
324
+ st.markdown("### Content Filters")
325
+
326
+ # Category filters (for multi-class tasks)
327
+ category_filters = {}
328
+ for task_key, task_info in tasks.items():
329
+ if task_info['type'] == 'multi_class':
330
+ available_values = df[f'gt_{task_key}'].unique().tolist()
331
+ selected_values = st.multiselect(
332
+ f"Ground Truth {task_info['name']}",
333
+ available_values,
334
+ default=available_values,
335
+ key=f"gt_{task_key}_filter"
336
+ )
337
+ category_filters[task_key] = selected_values
338
+
339
+ # Binary prediction filters
340
+ st.markdown("**Filter by Predictions:**")
341
+ prediction_filters = {}
342
+ for task_key, task_info in tasks.items():
343
+ if task_info['type'] == 'binary':
344
+ filter_value = st.selectbox(
345
+ f"{task_info['name']}:",
346
+ ["All", "Yes only", "No only"],
347
+ key=f"pred_{task_key}_filter"
348
+ )
349
+ prediction_filters[task_key] = filter_value
350
+
351
+ # Correctness filter
352
+ st.markdown("**Filter by Correctness:**")
353
+ correctness_filter = st.selectbox(
354
+ "Show only:",
355
+ ["All predictions", "Correct predictions", "Incorrect predictions"]
356
+ )
357
+
358
+ # OCR filters (if screenshot task exists)
359
+ has_screenshot_task = any(task['key'] == 'is_screenshot_with_text' for task in task_config['tasks'])
360
+ if has_screenshot_task:
361
+ st.markdown("**Filter by Text Content:**")
362
+ ocr_filter = st.selectbox(
363
+ "Text Content:",
364
+ ["All images", "Images with text", "Images without text"],
365
+ key="ocr_filter"
366
+ )
367
+ enable_ocr = st.checkbox("Enable OCR text extraction", value=True)
368
+ else:
369
+ ocr_filter = "All images"
370
+ enable_ocr = False
371
+
372
+ # Apply filters
373
+ filtered_df = df.copy()
374
+
375
+ # Confidence filters
376
+ for task_key, threshold in confidence_filters.items():
377
+ filtered_df = filtered_df[filtered_df[f'conf_{task_key}'] >= threshold]
378
+
379
+ # Category filters
380
+ for task_key, selected_values in category_filters.items():
381
+ filtered_df = filtered_df[filtered_df[f'gt_{task_key}'].isin(selected_values)]
382
+
383
+ # Binary prediction filters
384
+ for task_key, filter_value in prediction_filters.items():
385
+ if filter_value == "Yes only":
386
+ filtered_df = filtered_df[filtered_df[f'pred_{task_key}'] == 'yes']
387
+ elif filter_value == "No only":
388
+ filtered_df = filtered_df[filtered_df[f'pred_{task_key}'] == 'no']
389
+
390
+ # Correctness filter: a sample counts as correct only when every task's prediction matches its ground truth
391
+ if correctness_filter == "Correct predictions":
392
+ correct_mask = True
393
+ for task_key in tasks.keys():
394
+ correct_mask = correct_mask & filtered_df[f'correct_{task_key}']
395
+ filtered_df = filtered_df[correct_mask]
396
+ elif correctness_filter == "Incorrect predictions":
397
+ correct_mask = True
398
+ for task_key in tasks.keys():
399
+ correct_mask = correct_mask & filtered_df[f'correct_{task_key}']
400
+ filtered_df = filtered_df[~correct_mask]
401
+
402
+ # Show filter results
403
+ st.info(f"Filtered to {len(filtered_df)} samples (from {len(df)} total)")
404
+
405
+ return filtered_df, ocr_filter, enable_ocr
406
+
407
+ def display_sample_images(df, task_config, ocr_filter="All images", enable_ocr=True):
408
+ """Display sample images with predictions and ground truth"""
409
+ st.markdown("## 🖼️ Sample Images")
410
+
411
+ if len(df) == 0:
412
+ st.warning("No images match the current filters.")
413
+ return
414
+
415
+ tasks = {task['key']: task for task in task_config['tasks']}
416
+
417
+ # Add controls for image display
418
+ col1, col2, col3 = st.columns([2, 1, 1])
419
+
420
+ with col1:
421
+ max_images = st.slider(
422
+ "Number of images to display",
423
+ min_value=10,
424
+ max_value=min(200, len(df)),
425
+ value=min(50, len(df)),
426
+ step=10
427
+ )
428
+
429
+ with col2:
430
+ sort_by = st.selectbox(
431
+ "Sort by:",
432
+ ["Original order", "Confidence (low to high)", "Confidence (high to low)"]
433
+ )
434
+
435
+ with col3:
436
+ cols_per_row = st.selectbox("Images per row:", [2, 3, 4], index=1)
437
+
438
+ # Sort dataframe if requested
439
+ task_keys = list(tasks.keys())
440
+ if sort_by == "Confidence (low to high)":
441
+ avg_conf = sum(df[f'conf_{task}'] for task in task_keys) / len(task_keys)
442
+ display_df = df.iloc[avg_conf.argsort()].head(max_images)
443
+ elif sort_by == "Confidence (high to low)":
444
+ avg_conf = sum(df[f'conf_{task}'] for task in task_keys) / len(task_keys)
445
+ display_df = df.iloc[avg_conf.argsort()[::-1]].head(max_images)
446
+ else:
447
+ display_df = df.head(max_images)
448
+
449
+ # Apply OCR filtering if needed
450
+ if enable_ocr and ocr_filter != "All images":
451
+ st.info("🔍 Applying OCR filtering... This may take a moment for many images.")
452
+
453
+ ocr_results = []
454
+ progress_bar = st.progress(0)
455
+
456
+ for idx, (_, row) in enumerate(display_df.iterrows()):
457
+ if os.path.exists(row['image_path']):
458
+ ocr_result = extract_text_from_image(row['image_path'])
459
+ ocr_results.append(ocr_result['has_text'])
460
+ else:
461
+ ocr_results.append(False)
462
+
463
+ progress_bar.progress((idx + 1) / len(display_df))
464
+
465
+ progress_bar.empty()
466
+
467
+ # Filter based on OCR results
468
+ if ocr_filter == "Images with text":
469
+ mask = ocr_results
470
+ else: # "Images without text"
471
+ mask = [not has_text for has_text in ocr_results]
472
+
473
+ display_df = display_df[mask].reset_index(drop=True)
474
+ st.success(f"OCR filtering complete. Found {len(display_df)} images matching criteria.")
475
+
476
+ # Display images
477
+ for i in range(0, len(display_df), cols_per_row):
478
+ cols = st.columns(cols_per_row)
479
+
480
+ for j in range(cols_per_row):
481
+ if i + j < len(display_df):
482
+ row = display_df.iloc[i + j]
483
+
484
+ with cols[j]:
485
+ # Load and display image
486
+ try:
487
+ if os.path.exists(row['image_path']):
488
+ img = Image.open(row['image_path'])
489
+ st.image(img, caption=f"Sample {row['idx']}", use_column_width=True)
490
+ else:
491
+ st.error(f"Image not found: {row['image_path']}")
492
+ continue
493
+ except Exception as e:
494
+ st.error(f"Error loading image: {e}")
495
+ continue
496
+
497
+ # Caption
498
+ st.markdown(f"**Caption:** {row['caption'][:100]}...")
499
+
500
+ # OCR Text Extraction
501
+ if enable_ocr and 'is_screenshot_with_text' in tasks:
502
+ with st.expander("🔍 Extracted Text (OCR)", expanded=False):
503
+ ocr_result = extract_text_from_image(row['image_path'])
504
+ if ocr_result['has_text']:
505
+ st.markdown(f"**Text Blocks Found:** {ocr_result['num_text_blocks']}")
506
+ st.text_area(
507
+ "Extracted Text:",
508
+ value=ocr_result['text'],
509
+ height=100,
510
+ key=f"ocr_text_{row['idx']}",
511
+ help="Text extracted from the image using OCR"
512
+ )
513
+
514
+ text_length = len(ocr_result['text'])
515
+ word_count = len(ocr_result['text'].split())
516
+ st.caption(f"📊 Text stats: {text_length} chars, {word_count} words")
517
+
518
+ if row['pred_is_screenshot_with_text'] == 'yes':
519
+ st.success("✅ Screenshot prediction correlates with text presence")
520
+ elif ocr_result['num_text_blocks'] > 5:
521
+ st.warning("⚠️ High text content but not predicted as screenshot")
522
+ else:
523
+ st.info("No text detected in this image")
524
+ if row['pred_is_screenshot_with_text'] == 'yes':
525
+ st.warning("⚠️ Predicted as screenshot but no text found")
526
+
527
+ # Predictions vs Ground Truth
528
+ st.markdown("**Predictions vs Ground Truth:**")
529
+
530
+ # Display all tasks dynamically
531
+ for task_key, task_info in tasks.items():
532
+ pred_val = row[f'pred_{task_key}']
533
+ gt_val = row[f'gt_{task_key}']
534
+ conf_val = row[f'conf_{task_key}']
535
+ correct = pred_val == gt_val
536
+
537
+ badge_class = "correct-prediction" if correct else "incorrect-prediction"
538
+ st.markdown(f"""
539
+ <div class="prediction-badge {badge_class}">
540
+ {task_info['name']}: {pred_val} | GT: {gt_val} | Conf: {conf_val:.3f}
541
+ </div>
542
+ """, unsafe_allow_html=True)
543
+
544
+ st.markdown("---")
545
+
546
+ if len(df) > max_images:
547
+ st.info(f"Showing {max_images} of {len(df)} filtered images. Use the slider above to show more images.")
548
+
549
+ def create_confusion_matrices(df, task_config):
550
+ """Create confusion matrices for each classification task"""
551
+ st.markdown("## 📊 Model Performance Analysis")
552
+
553
+ tasks = {task['key']: task for task in task_config['tasks']}
554
+ binary_tasks = [t for t in tasks.values() if t['type'] == 'binary']
555
+ multi_class_tasks = [t for t in tasks.values() if t['type'] == 'multi_class']
556
+
557
+ tab1, tab2, tab3 = st.tabs(["Confusion Matrices", "Confidence Analysis", "Task Performance"])
558
+
559
+ with tab1:
560
+ # Binary classification confusion matrices
561
+ if binary_tasks:
562
+ st.markdown("### Binary Classification Tasks")
563
+ n_binary = len(binary_tasks)
564
+ n_cols = min(2, n_binary)
565
+
566
+ for i in range(0, n_binary, n_cols):
567
+ cols = st.columns(n_cols)
568
+ for j in range(n_cols):
569
+ if i + j < n_binary:
570
+ task = binary_tasks[i + j]
571
+ task_key = task['key']
572
+
573
+ with cols[j]:
574
+ confusion_data = pd.crosstab(
575
+ df[f'gt_{task_key}'],
576
+ df[f'pred_{task_key}'],
577
+ margins=True
578
+ )
579
+ st.markdown(f"**{task['name']} Confusion Matrix**")
580
+ st.dataframe(confusion_data, use_container_width=True)
581
+
582
+ # Multi-class confusion matrices
583
+ if multi_class_tasks:
584
+ st.markdown("### Multi-class Classification Tasks")
585
+ for task in multi_class_tasks:
586
+ task_key = task['key']
587
+ st.markdown(f"**{task['name']} Confusion Matrix**")
588
+ confusion_data = pd.crosstab(
589
+ df[f'gt_{task_key}'],
590
+ df[f'pred_{task_key}'],
591
+ margins=True
592
+ )
593
+ st.dataframe(confusion_data, use_container_width=True)
594
+
595
+ with tab2:
596
+ # Confidence analysis plots
597
+ fig1 = create_confidence_distribution_plot(df, task_config)
598
+ st.plotly_chart(fig1, use_container_width=True)
599
+
600
+ with tab3:
601
+ # Task-wise performance
602
+ st.markdown("**Performance by Task**")
603
+
604
+ performance_data = []
605
+ for task_key, task_info in tasks.items():
606
+ accuracy = df[f'correct_{task_key}'].mean()
607
+ confidence = df[f'conf_{task_key}'].mean()
608
+
609
+ performance_data.append({
610
+ 'Task': task_info['name'],
611
+ 'Type': task_info['type'],
612
+ 'Accuracy': accuracy,
613
+ 'Avg Confidence': confidence
614
+ })
615
+
616
+ performance_df = pd.DataFrame(performance_data)
617
+ st.dataframe(performance_df, use_container_width=True)
618
+
619
+ # Performance visualization
620
+ fig = px.scatter(performance_df, x='Avg Confidence', y='Accuracy',
621
+ color='Type', text='Task',
622
+ title="Task Performance: Accuracy vs Confidence")
623
+ fig.update_traces(textposition="top center")
624
+ st.plotly_chart(fig, use_container_width=True)
625
+
626
+ def main():
627
+ """Main Streamlit application"""
628
+ # Header
629
+ st.markdown('<div class="main-header">🔍 Pseudoable Classifier Evaluation Dashboard</div>',
630
+ unsafe_allow_html=True)
631
+
632
+ # Load task configuration
633
+ task_config = load_task_config()
634
+ if not task_config:
635
+ st.error("❌ Could not load task configuration. Please ensure task_config.json exists.")
636
+ st.info("Expected location: ./task_config.json or ./checkpoints/task_config.json")
637
+ return
638
+
639
+ st.success(f"✅ Loaded task configuration with {len(task_config['tasks'])} tasks")
640
+
641
+ # Display task information
642
+ with st.expander("📋 Task Configuration", expanded=False):
643
+ for task in task_config['tasks']:
644
+ st.markdown(f"**{task['name']}** ({task['type']})")
645
+ st.markdown(f"- *Description:* {task['description']}")
646
+ st.markdown(f"- *Labels:* {', '.join(task['labels'])}")
647
+ st.markdown("---")
648
+
649
+ # Check and run validation if needed
650
+ if not run_validation_if_needed():
651
+ return
652
+
653
+ # Load validation results
654
+ validation_data = load_validation_results()
655
+ if not validation_data:
656
+ st.error("❌ Could not load validation results. Please check if validation_results.json exists.")
657
+ return
658
+
659
+ # Process data
660
+ df = process_validation_data(validation_data, task_config)
661
+ if df is None or len(df) == 0:
662
+ st.error("❌ No validation data found or data processing failed.")
663
+ return
664
+
665
+ # Show basic info
666
+ st.success(f"✅ Loaded {len(df)} validation samples successfully!")
667
+
668
+ # Overview metrics
669
+ create_overview_metrics(df, validation_data, task_config)
670
+
671
+ # Apply filters
672
+ filtered_df, ocr_filter, enable_ocr = apply_filters(df, task_config)
673
+
674
+ # Display results
675
+ if len(filtered_df) > 0:
676
+ # Performance analysis
677
+ create_confusion_matrices(filtered_df, task_config)
678
+
679
+ # Sample images
680
+ display_sample_images(filtered_df, task_config, ocr_filter, enable_ocr)
681
+ else:
682
+ st.warning("⚠️ No samples match the current filter criteria. Please adjust your filters.")
683
+
684
+ # Footer
685
+ st.markdown("---")
686
+ st.markdown("**📝 Instructions:**")
687
+ st.markdown("1. Use the sidebar to filter by task confidence and prediction classes")
688
+ st.markdown("2. Filter images by text content using OCR (if screenshot detection task is configured)")
689
+ st.markdown("3. Adjust the number of images to display and sorting order")
690
+ st.markdown("4. View model performance metrics and confusion matrices")
691
+ st.markdown("5. Browse sample images with predictions vs ground truth")
692
+ st.markdown("6. Green badges indicate correct predictions, red badges indicate incorrect predictions")
693
+
694
+ if __name__ == "__main__":
695
+ main()
upload_to_hf.py ADDED
@@ -0,0 +1,112 @@
1
+ #!/usr/bin/env python3
2
+ """
3
+ Upload the trained multi-head SigLIP2 classifier to the Hugging Face Hub (private).
4
+
5
+ This script will create/update the repo `fal/multihead_cls` and push:
6
+ - model weights: checkpoints/multi_head_siglip2_classifier.pth
7
+ - full training checkpoint: checkpoints/training_checkpoint.pth (optional)
8
+ - processor folder: checkpoints/siglip2_processor/
9
+ - README.md with usage
10
+
11
+ Auth: Set HUGGINGFACE_TOKEN environment variable or run `huggingface-cli login`.
12
+ """
13
+
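+ # Typical invocation (assumes the training artifacts listed above already exist under ./checkpoints):
+ #   HUGGINGFACE_TOKEN=hf_xxx python upload_to_hf.py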
14
+ import os
15
+ import shutil
16
+ from pathlib import Path
17
+ from typing import Optional
18
+
19
+ from huggingface_hub import HfApi, HfFolder, create_repo, upload_folder, upload_file
20
+
21
+
22
+ REPO_ID = "fal/multihead_cls"
23
+
24
+
25
+ def ensure_logged_in() -> Optional[str]:
26
+ token = os.getenv("HUGGINGFACE_TOKEN") or HfFolder.get_token()  # falls back to the token cached by `huggingface-cli login`
27
+ if not token:
28
+ raise RuntimeError(
29
+ "No Hugging Face token found. Set HUGGINGFACE_TOKEN or run `huggingface-cli login`."
30
+ )
31
+ return token
32
+
33
+
34
+ def prepare_staging_dir() -> Path:
35
+ root = Path(__file__).parent
36
+ ckpt_dir = root / "checkpoints"
37
+ if not ckpt_dir.exists():
38
+ raise FileNotFoundError("checkpoints/ directory not found. Train the model first.")
39
+
40
+ required = [
41
+ ckpt_dir / "multi_head_siglip2_classifier.pth",
42
+ ckpt_dir / "siglip2_processor",
43
+ ]
44
+ for path in required:
45
+ if not path.exists():
46
+ raise FileNotFoundError(f"Missing required artifact: {path}")
47
+
48
+ # Check for task_config.json in checkpoints or root directory
49
+ task_config_path = ckpt_dir / "task_config.json"
50
+ if not task_config_path.exists():
51
+ task_config_path = root / "task_config.json"
52
+ if not task_config_path.exists():
53
+ raise FileNotFoundError("Missing required artifact: task_config.json (checked both checkpoints/ and root directory)")
54
+
55
+ staging = root / "hf_export"
56
+ if staging.exists():
57
+ shutil.rmtree(staging)
58
+ staging.mkdir(parents=True)
59
+
60
+ # Copy artifacts
61
+ shutil.copy2(ckpt_dir / "multi_head_siglip2_classifier.pth", staging / "model.pth")
62
+ shutil.copy2(task_config_path, staging / "task_config.json")
63
+
64
+ # Optional: training checkpoint and other metadata
65
+ train_ckpt = ckpt_dir / "training_checkpoint.pth"
66
+ if train_ckpt.exists():
67
+ shutil.copy2(train_ckpt, staging / "training_checkpoint.pth")
68
+
69
+ # Optional: training history and validation summary
70
+ for optional_file in ["training_history.json", "validation_summary.json", "stage4_summary.json"]:
71
+ optional_path = ckpt_dir / optional_file
72
+ if optional_path.exists():
73
+ shutil.copy2(optional_path, staging / optional_file)
74
+
75
+ # Processor
76
+ shutil.copytree(ckpt_dir / "siglip2_processor", staging / "processor")
77
+
78
+ # Add example and README if present
79
+ readme_src = root / "README.md"
80
+ if readme_src.exists():
81
+ shutil.copy2(readme_src, staging / "README.md")
82
+ example_src = root / "example.py"
83
+ if example_src.exists():
84
+ shutil.copy2(example_src, staging / "example.py")
85
+
86
+ return staging
87
+
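+ # Resulting hf_export/ layout (when every optional artifact is present): model.pth,
+ # task_config.json, training_checkpoint.pth, training_history.json, validation_summary.json,
+ # stage4_summary.json, processor/, README.md and example.py.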
88
+
89
+ def upload_to_hub(private: bool = True) -> None:
90
+ token = ensure_logged_in()
91
+ api = HfApi(token=token)
92
+
93
+ create_repo(REPO_ID, private=private, repo_type="model", exist_ok=True, token=token)
94
+
95
+ staging = prepare_staging_dir()
96
+
97
+ # Upload all files in staging
98
+ upload_folder(
99
+ folder_path=str(staging),
100
+ repo_id=REPO_ID,
101
+ repo_type="model",
102
+ commit_message="Upload multi-head SigLIP2 classifier with dynamic task configuration",
103
+ token=token,
104
+ )
105
+
106
+ print(f"Uploaded to https://huggingface.co/{REPO_ID} (private={private})")
107
+
108
+
109
+ if __name__ == "__main__":
110
+ upload_to_hub(private=True)
111
+
112
+
validation_runner.py ADDED
@@ -0,0 +1,286 @@
1
+ #!/usr/bin/env python3
2
+ """
3
+ Validation Runner: Runs the trained model on the validation set and saves the predictions
4
+ """
5
+
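+ # Typical flow: run `python validation_runner.py` after training; results are written to
+ # ./validation_results.json, which streamlit_evaluation_app.py loads (and will trigger this
+ # script automatically via run_validation_if_needed when the file is missing).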
6
+ import os
7
+ import json
8
+ import torch
9
+ import numpy as np
10
+ from PIL import Image
11
+ from pathlib import Path
12
+ import logging
13
+ from transformers import AutoProcessor
14
+ from stage_4 import MultiHeadSiglipClassifier, CKPT, load_task_config
15
+ import pandas as pd
16
+ from tqdm import tqdm
17
+
18
+ # Set up logging
19
+ logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
20
+ logger = logging.getLogger(__name__)
21
+
22
+ def _is_incomplete_classification(classification: dict, task_config: dict) -> bool:
23
+ """Check if classification contains incomplete data (empty or '...' values)"""
24
+ if not task_config or 'tasks' not in task_config:
25
+ return True
26
+
27
+ required_tasks = [task['key'] for task in task_config['tasks']]
28
+
29
+ for task_key in required_tasks:
30
+ if task_key not in classification:
31
+ return True
32
+
33
+ value = classification[task_key]
34
+ # Check for incomplete markers
35
+ if not value or value == "...":  # "not value" already covers "" and None
36
+ return True
37
+
38
+ return False
39
+
40
+ def load_trained_model(checkpoint_dir: str = './checkpoints'):
41
+ """Load the trained model and processor"""
42
+ checkpoint_path = Path(checkpoint_dir)
43
+
44
+ # Load task configuration
45
+ task_config_path = checkpoint_path / 'task_config.json'
46
+ if not task_config_path.exists():
47
+ # Fallback to root directory
48
+ task_config_path = './task_config.json'
49
+
50
+ task_config = load_task_config(str(task_config_path))
51
+
52
+ # Load processor
53
+ processor = AutoProcessor.from_pretrained(CKPT)
54
+
55
+ # Load model with task config
56
+ model = MultiHeadSiglipClassifier(task_config)
57
+ model_state = torch.load(checkpoint_path / 'multi_head_siglip2_classifier.pth', map_location='cpu')
58
+ model.load_state_dict(model_state)
59
+
60
+ # Set to evaluation mode
61
+ model.eval()
62
+
63
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
64
+ model.to(device)
65
+
66
+ logger.info(f"Model loaded successfully on device: {device}")
67
+ return model, processor, device, task_config
68
+
69
+ def load_validation_data(data_dir: str = './data', task_config: dict = None):
70
+ """Load validation samples from stage 2 metadata files"""
71
+ data_path = Path(data_dir)
72
+
73
+ # Load from stage 2 metadata files
74
+ metadata_dir = data_path / 'metadata'
75
+ if not metadata_dir.exists():
76
+ logger.error("Metadata directory not found. Run stages 1 and 2 first.")
77
+ return []
78
+
79
+ metadata_files = list(metadata_dir.glob('meta_*_stage2.json'))
80
+ if not metadata_files:
81
+ logger.error("No stage 2 metadata files found. Run stage 2 first.")
82
+ return []
83
+
84
+ samples = []
85
+ skipped_incomplete = 0
86
+
87
+ for meta_file in tqdm(metadata_files, desc="Loading validation data"):
88
+ try:
89
+ with open(meta_file, 'r') as f:
90
+ metadata = json.load(f)
91
+
92
+ # Check if classification is complete
93
+ if not metadata.get('stage2_complete', False):
94
+ logger.warning(f"Skipping {meta_file} - classification not complete")
95
+ skipped_incomplete += 1
96
+ continue
97
+
98
+ # Check if classification contains incomplete data
99
+ classification = metadata.get('classification', {})
100
+ if not classification or _is_incomplete_classification(classification, task_config):
101
+ logger.warning(f"Skipping {meta_file} - incomplete classification data")
102
+ skipped_incomplete += 1
103
+ continue
104
+
105
+ # Check if image exists
106
+ image_path = metadata['image_path']
107
+ if not os.path.exists(image_path):
108
+ logger.warning(f"Image not found: {image_path}")
109
+ skipped_incomplete += 1
110
+ continue
111
+
112
+ samples.append({
113
+ 'idx': metadata['idx'],
114
+ 'image_path': metadata['image_path'],
115
+ 'caption': metadata['caption'],
116
+ 'url': metadata['url'],
117
+ 'hash': metadata['hash'],
118
+ 'ground_truth': metadata['classification']
119
+ })
120
+
121
+ except Exception as e:
122
+ logger.warning(f"Error loading {meta_file}: {e}")
123
+ skipped_incomplete += 1
124
+
125
+ if skipped_incomplete > 0:
126
+ logger.warning(f"Skipped {skipped_incomplete} incomplete samples")
127
+ logger.info(f"Loaded {len(samples)} valid validation samples")
128
+ return samples
129
+
130
+ def predict_batch(model, processor, images, device, task_config, batch_size=8):
131
+ """Run predictions on a batch of images"""
132
+ predictions = []
133
+ tasks = {task['key']: task for task in task_config['tasks']}
134
+
135
+ for i in range(0, len(images), batch_size):
136
+ batch_images = images[i:i+batch_size]
137
+
138
+ # Process images
139
+ inputs = processor(images=batch_images, return_tensors="pt")
140
+ pixel_values = inputs['pixel_values'].to(device)
141
+
142
+ with torch.no_grad():
143
+ outputs = model(pixel_values)
144
+
145
+ # Convert outputs to probabilities and predictions
146
+ batch_preds = []
147
+ for j in range(len(batch_images)):
148
+ pred = {}
149
+
150
+ # Process each task dynamically
151
+ for task_key, task_info in tasks.items():
152
+ logits = outputs[task_key][j]
153
+ probs = torch.softmax(logits, dim=0)
154
+ pred_class = torch.argmax(logits).item()
155
+ confidence = probs[pred_class].item()
156
+
157
+ if task_info['type'] == 'binary':
158
+ # Binary classification
159
+ pred[f'{task_key}_prediction'] = 'yes' if pred_class == 1 else 'no'
160
+ pred[f'{task_key}_confidence'] = confidence
161
+ pred[f'{task_key}_prob_yes'] = probs[1].item()
162
+ pred[f'{task_key}_prob_no'] = probs[0].item()
163
+
164
+ elif task_info['type'] == 'multi_class':
165
+ # Multi-class classification
166
+ pred_label = task_info['labels'][pred_class]
167
+ pred[f'{task_key}_prediction'] = pred_label
168
+ pred[f'{task_key}_confidence'] = confidence
169
+
170
+ # Add probabilities for all classes
171
+ for idx, label in enumerate(task_info['labels']):
172
+ pred[f'{task_key}_prob_{label}'] = probs[idx].item()
173
+
174
+ batch_preds.append(pred)
175
+
176
+ predictions.extend(batch_preds)
177
+
178
+ return predictions
179
+
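+ # Illustrative shape of one prediction dict, for a hypothetical binary task named "is_photo":
+ # {"is_photo_prediction": "yes", "is_photo_confidence": 0.91,
+ #  "is_photo_prob_yes": 0.91, "is_photo_prob_no": 0.09};
+ # multi-class tasks instead emit one "<task_key>_prob_<label>" entry per label.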
180
+ def calculate_accuracies(predictions, ground_truths, task_config):
181
+ """Calculate accuracies for each task"""
182
+ accuracies = {}
183
+ tasks = {task['key']: task for task in task_config['tasks']}
184
+
185
+ for task_key, task_info in tasks.items():
186
+ pred_key = f'{task_key}_prediction'
187
+
188
+ correct = sum(1 for pred, gt in zip(predictions, ground_truths)
189
+ if pred[pred_key] == gt[task_key])
190
+ total = len(predictions)
191
+ accuracies[f'{task_key}_accuracy'] = correct / total if total > 0 else 0
192
+
193
+ return accuracies
194
+
195
+ def run_validation(data_dir: str = './data', checkpoint_dir: str = './checkpoints',
196
+ output_file: str = './validation_results.json'):
197
+ """Run complete validation and save results"""
198
+ logger.info("Starting validation run...")
199
+
200
+ # Load model and data
201
+ model, processor, device, task_config = load_trained_model(checkpoint_dir)
202
+ samples = load_validation_data(data_dir, task_config)
203
+
204
+ if not samples:
205
+ logger.error("No validation samples found!")
206
+ return
207
+
208
+ # Prepare images for batch processing
209
+ images = []
210
+ for sample in tqdm(samples, desc="Loading images"):
211
+ try:
212
+ img = Image.open(sample['image_path']).convert('RGB')
213
+ images.append(img)
214
+ except Exception as e:
215
+ logger.error(f"Error loading image {sample['image_path']}: {e}")
216
+ images.append(None)
217
+
218
+ # Filter out failed images
219
+ valid_samples = []
220
+ valid_images = []
221
+ for sample, img in zip(samples, images):
222
+ if img is not None:
223
+ valid_samples.append(sample)
224
+ valid_images.append(img)
225
+
226
+ logger.info(f"Running predictions on {len(valid_samples)} valid samples...")
227
+
228
+ # Run predictions
229
+ predictions = predict_batch(model, processor, valid_images, device, task_config)
230
+
231
+ # Calculate accuracies
232
+ ground_truths = [sample['ground_truth'] for sample in valid_samples]
233
+ accuracies = calculate_accuracies(predictions, ground_truths, task_config)
234
+
235
+ # Combine results
236
+ validation_results = []
237
+ for sample, prediction in zip(valid_samples, predictions):
238
+ result = {
239
+ **sample,
240
+ 'predictions': prediction
241
+ }
242
+ validation_results.append(result)
243
+
244
+ # Create final output
245
+ output_data = {
246
+ 'metadata': {
247
+ 'total_samples': len(validation_results),
248
+ 'model_checkpoint': checkpoint_dir,
249
+ 'validation_accuracies': accuracies,
250
+ 'task_config': task_config,
251
+ 'timestamp': pd.Timestamp.now().isoformat()
252
+ },
253
+ 'results': validation_results
254
+ }
255
+
256
+ # Save results
257
+ output_path = Path(output_file)
258
+ with open(output_path, 'w') as f:
259
+ json.dump(output_data, f, indent=2)
260
+
261
+ logger.info(f"Validation results saved to {output_path}")
262
+ logger.info("Validation Accuracies:")
263
+ for key, value in accuracies.items():
264
+ logger.info(f" {key}: {value:.4f}")
265
+
266
+ return output_data
267
+
268
+ def main():
269
+ """Main execution"""
270
+ logger.info("Starting validation runner...")
271
+
272
+ # Check if model exists
273
+ if not Path('./checkpoints/multi_head_siglip2_classifier.pth').exists():
274
+ logger.error("Trained model not found! Run stage 4 first.")
275
+ return
276
+
277
+ # Run validation
278
+ results = run_validation()
279
+
280
+ if results:
281
+ logger.info("Validation completed successfully!")
282
+ else:
283
+ logger.error("Validation failed!")
284
+
285
+ if __name__ == "__main__":
286
+ main()