Upload folder using huggingface_hub
- example.py +203 -0
- stage_1.py +199 -0
- stage_2.py +292 -0
- stage_4.py +506 -0
- streamlit_evaluation_app.py +695 -0
- upload_to_hf.py +112 -0
- validation_runner.py +286 -0
example.py
ADDED
@@ -0,0 +1,203 @@
#!/usr/bin/env python3
"""
Example inference script for the Multi-Head SigLIP2 Classifier from Hugging Face Hub.

Usage examples:
  # Multiple images, single text
  python example.py --image img1.png --image img2.jpg --repo fal/multihead_cls --text "an example caption"

  # N images, N texts (returns an N x N similarity matrix)
  python example.py \
      --image img1.png --image img2.jpg \
      --text "a cat" --text "a dog" --repo fal/multihead_cls

Requires: torch, transformers, huggingface_hub, Pillow, click
"""

import json
import click
import torch
from PIL import Image
from transformers import AutoProcessor
from huggingface_hub import hf_hub_download

# Local model definition replicated from training for easy inference
import torch.nn as nn
from transformers import SiglipModel
import torch.nn.functional as F

CKPT = "google/siglip-base-patch16-256"


class MultiHeadSiglipClassifier(nn.Module):
    """Dynamic multi-head classifier based on task configuration"""

    def __init__(self, task_config: dict, model_name: str = CKPT):
        super().__init__()
        self.task_config = task_config
        self.siglip = SiglipModel.from_pretrained(model_name)

        # Freeze SigLIP parameters
        for param in self.siglip.parameters():
            param.requires_grad = False

        # Create classification heads dynamically based on task config
        hidden_size = self.siglip.config.vision_config.hidden_size
        self.classification_heads = nn.ModuleDict()

        for task in task_config['tasks']:
            task_key = task['key']
            num_classes = len(task['labels'])

            # Create a linear layer for this task
            head = nn.Linear(hidden_size, num_classes)
            self.classification_heads[task_key] = head

    def forward(self, pixel_values):
        # Get SigLIP image embeddings only
        combined_embeds = self.siglip.get_image_features(pixel_values=pixel_values)

        # Apply all classification heads
        outputs = {}
        for task_key, head in self.classification_heads.items():
            outputs[task_key] = head(combined_embeds)

        return outputs


def load_model_from_hf(repo_id: str):
    """Load model, processor, and task config from Hugging Face Hub"""
    # Download task configuration
    try:
        task_config_path = hf_hub_download(repo_id=repo_id, filename="task_config.json", repo_type="model")
        with open(task_config_path, 'r') as f:
            task_config = json.load(f)
    except Exception as e:
        raise RuntimeError(f"Could not load task_config.json from {repo_id}: {e}")

    # Load processor
    processor = AutoProcessor.from_pretrained(CKPT)

    # Create model with task config
    model = MultiHeadSiglipClassifier(task_config)

    # Load trained weights
    try:
        ckpt_path = hf_hub_download(repo_id=repo_id, filename="model.pth", repo_type="model")
        state_dict = torch.load(ckpt_path, map_location="cpu")
        model.load_state_dict(state_dict)
    except Exception as e:
        raise RuntimeError(f"Could not load model.pth from {repo_id}: {e}")

    model.eval()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)

    return model, processor, device, task_config


def predict_batch(model, processor, device, task_config, image_paths, texts: list[str] | None = None):
    """Run predictions on a batch of images using dynamic task configuration"""
    images = [Image.open(p).convert("RGB") for p in image_paths]
    if texts is not None and len(texts) == 0:
        texts = None

    # Process images
    image_inputs = processor(images=images, return_tensors="pt")
    pixel_values = image_inputs["pixel_values"].to(device)

    with torch.no_grad():
        outputs = model(pixel_values)
        # Compute image embeddings for similarity
        image_embeds = model.siglip.get_image_features(pixel_values=pixel_values)
        image_embeds = F.normalize(image_embeds, p=2, dim=-1)

        # Prepare text inputs if provided
        text_embeds = None
        input_ids = None
        attention_mask = None
        if texts is not None:
            text_inputs = processor(text=texts, padding="max_length", return_tensors="pt")
            input_ids = text_inputs["input_ids"].to(device)
            attention_mask = text_inputs.get("attention_mask")
            attention_mask = attention_mask.to(device) if attention_mask is not None else None
            text_embeds = model.siglip.get_text_features(input_ids=input_ids, attention_mask=attention_mask)
            text_embeds = F.normalize(text_embeds, p=2, dim=-1)

    # Create task mappings
    tasks = {task['key']: task for task in task_config['tasks']}

    batch_results = []
    batch_size = pixel_values.shape[0]

    for i in range(batch_size):
        item = {"image": str(image_paths[i])}

        # Process each task dynamically
        for task_key, task_info in tasks.items():
            logits = outputs[task_key][i]
            probs = torch.softmax(logits, dim=0)
            pred_idx = torch.argmax(probs).item()

            if task_info['type'] == 'binary':
                # Binary classification
                item[f"{task_key}_prediction"] = task_info['labels'][pred_idx]
                item[f"{task_key}_confidence"] = float(probs[pred_idx].item())
                item[f"{task_key}_prob_yes"] = float(probs[1].item()) if len(task_info['labels']) > 1 else 0.0
                item[f"{task_key}_prob_no"] = float(probs[0].item())

            elif task_info['type'] == 'multi_class':
                # Multi-class classification
                item[f"{task_key}_prediction"] = task_info['labels'][pred_idx]
                item[f"{task_key}_confidence"] = float(probs[pred_idx].item())

                # Add probabilities for all classes
                for idx, label in enumerate(task_info['labels']):
                    item[f"{task_key}_prob_{label}"] = float(probs[idx].item())

        batch_results.append(item)

    cosine_matrix = None

    if input_ids is not None:
        # Image and text embeddings were L2-normalized above, so the dot product is cosine similarity
        cosine = torch.matmul(image_embeds, text_embeds.T)
        cosine_matrix = cosine.cpu().tolist()

    return {
        "images": [str(p) for p in image_paths],
        "texts": texts or [],
        "task_config": task_config,
        "predictions": batch_results,
        "cosine_similarity": cosine_matrix,
    }


@click.command()
@click.option("--image", "images", multiple=True, type=click.Path(exists=True, dir_okay=False, readable=True), help="Path(s) to image file(s). Can be passed multiple times.")
@click.option("--repo", default="fal/multihead_cls", show_default=True, help="Hugging Face repo id with model checkpoint.")
@click.option("--text", "texts", multiple=True, help="Text prompt(s). Can be passed multiple times to build an N x N image-text similarity matrix.")
@click.option("--show-tasks", is_flag=True, help="Show available classification tasks and exit.")
def cli(images, repo, texts, show_tasks):
    """Multi-head SigLIP2 classifier inference from Hugging Face Hub"""

    # Load model and task config
    model, processor, device, task_config = load_model_from_hf(repo)

    if show_tasks:
        click.echo("Available classification tasks:")
        for i, task in enumerate(task_config['tasks'], 1):
            click.echo(f"  {i}. {task['name']} ({task['key']})")
            click.echo(f"     Type: {task['type']}")
            click.echo(f"     Labels: {', '.join(task['labels'])}")
            click.echo(f"     Description: {task['description']}")
            click.echo()
        return

    if not images:
        images = ("img.png",)

    results = predict_batch(model, processor, device, task_config, list(images), texts=list(texts) if texts else None)
    click.echo(json.dumps(results, indent=2))


if __name__ == "__main__":
    cli()
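The same helpers can also be called from another Python program instead of the CLI. The snippet below is a minimal sketch, not part of the upload: it assumes example.py is importable from the working directory, that the fal/multihead_cls repo exposes task_config.json and model.pth as example.py expects, and that img.png is a placeholder local image.

# Programmatic use of example.py's helpers (illustrative sketch; paths are placeholders)
from example import load_model_from_hf, predict_batch

model, processor, device, task_config = load_model_from_hf("fal/multihead_cls")
results = predict_batch(
    model, processor, device, task_config,
    image_paths=["img.png"],        # hypothetical local image
    texts=["a cat", "a dog"],       # optional: yields a 1 x 2 image-text similarity matrix
)
for item in results["predictions"]:
    print(item)                     # per-task predictions and probabilities
print(results["cosine_similarity"])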
stage_1.py
ADDED
@@ -0,0 +1,199 @@
#!/usr/bin/env python3
"""
Stage 1: Data Loading and Image Downloading
Downloads and preprocesses the top N images (2,000 by default) from a parquet file
"""

import os
import json
import requests
import pandas as pd
from PIL import Image
from io import BytesIO
import concurrent.futures
from pathlib import Path
import time
import logging
import numpy as np
from typing import Tuple

# Set up logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)


def setup_environment():
    """Set up the data directories"""
    os.makedirs('./data', exist_ok=True)
    os.makedirs('./data/images', exist_ok=True)
    os.makedirs('./data/metadata', exist_ok=True)
    return True


def load_and_sample_data(parquet_path: str, n_samples: int = 2000) -> pd.DataFrame:
    """Load the parquet file and take the top N rows"""
    logger.info(f"Loading data from {parquet_path}")
    df = pd.read_parquet(parquet_path)
    logger.info(f"Loaded {len(df)} rows, sampling top {n_samples}")
    return df.head(n_samples)


def has_white_edges(img: Image.Image, threshold: int = 240) -> bool:
    """Check if the image has 3 or more white edges (mean RGB > threshold)"""
    try:
        img_array = np.array(img)
        height, width = img_array.shape[:2]

        # Define edge thickness (check 5 pixels from each edge)
        edge_thickness = 5

        # Get edges
        top_edge = img_array[:edge_thickness, :].mean(axis=(0, 1))
        bottom_edge = img_array[-edge_thickness:, :].mean(axis=(0, 1))
        left_edge = img_array[:, :edge_thickness].mean(axis=(0, 1))
        right_edge = img_array[:, -edge_thickness:].mean(axis=(0, 1))

        # Check if an edge is white (all RGB channels > threshold)
        edges = [top_edge, bottom_edge, left_edge, right_edge]
        white_edges = sum(1 for edge in edges if np.all(edge > threshold))

        return white_edges >= 3
    except Exception as e:
        logger.debug(f"Error checking white edges: {e}")
        return False


def download_and_process_image(url: str, target_size: int = 256) -> Image.Image:
    """Download an image and resize with a center crop; skip it if it has white edges"""
    try:
        response = requests.get(url, timeout=10, headers={'User-Agent': 'Mozilla/5.0'})
        response.raise_for_status()

        img = Image.open(BytesIO(response.content)).convert('RGB')

        # Check for white edges before processing
        if has_white_edges(img):
            logger.debug(f"Skipping image with white edges: {url}")
            return None

        # Resize and center crop to target_size x target_size
        width, height = img.size
        min_side = min(width, height)
        scale = target_size / min_side

        new_width = int(width * scale)
        new_height = int(height * scale)

        img = img.resize((new_width, new_height), Image.Resampling.LANCZOS)

        # Center crop
        left = (new_width - target_size) // 2
        top = (new_height - target_size) // 2
        right = left + target_size
        bottom = top + target_size

        img = img.crop((left, top, right, bottom))

        # Double-check after processing
        if has_white_edges(img):
            logger.debug(f"Skipping processed image with white edges: {url}")
            return None

        return img
    except Exception as e:
        logger.error(f"Error downloading {url}: {e}")
        return None


def process_single_image(args: Tuple[int, str, str, str]) -> bool:
    """Download and save a single image"""
    idx, url, hash_val, caption = args

    try:
        # Download and process image
        image = download_and_process_image(url)
        if image is None:
            logger.debug(f"Skipped image {idx} (white edges or download error)")
            return False

        # Save image
        image_path = f'./data/images/img_{idx}.png'
        image.save(image_path)

        # Save metadata for the next stage
        metadata = {
            "idx": idx,
            "caption": caption,
            "url": url,
            "hash": hash_val,
            "image_path": image_path
        }

        metadata_path = f'./data/metadata/meta_{idx}.json'
        with open(metadata_path, 'w') as f:
            json.dump(metadata, f, indent=2)

        logger.info(f"Downloaded and saved image {idx}")
        return True

    except Exception as e:
        logger.error(f"Error processing image {idx}: {e}")
        return False


def download_images(df: pd.DataFrame, max_workers: int = 20):
    """Download all images with parallel processing"""
    logger.info(f"Starting image download with {max_workers} workers...")

    args_list = [(i, row['url'], row['hash'], row['caption'])
                 for i, (_, row) in enumerate(df.iterrows())]

    successful = 0
    skipped_white_edges = 0

    with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
        futures = [executor.submit(process_single_image, args) for args in args_list]

        for i, future in enumerate(concurrent.futures.as_completed(futures)):
            if future.result():
                successful += 1
            else:
                skipped_white_edges += 1

            # Progress logging every 100 images
            if (i + 1) % 100 == 0:
                logger.info(f"Processed {i + 1}/{len(args_list)} images (successful: {successful}, skipped: {skipped_white_edges})")

            # Minimal rate limiting for high concurrency
            time.sleep(0.01)

    logger.info(f"Download complete: {successful}/{len(args_list)} images downloaded, {skipped_white_edges} skipped (white edges)")

    # Save summary
    summary = {
        "total_images": len(args_list),
        "successful_downloads": successful,
        "skipped_white_edges": skipped_white_edges,
        "download_rate": f"{successful/len(args_list)*100:.1f}%",
        "stage": "download_complete"
    }

    with open('./data/stage1_summary.json', 'w') as f:
        json.dump(summary, f, indent=2)


def main():
    """Main execution for Stage 1"""
    logger.info("Starting Stage 1: Data Loading and Image Downloading...")

    # Setup
    setup_environment()

    # Load data
    parquet_path = '/home/fal/partiprompt_clip/curated_part_00000.parquet'
    df = load_and_sample_data(parquet_path, n_samples=5000)

    # Save the dataframe for other stages
    df.to_pickle('./data/sampled_data.pkl')

    # Download images with optimized settings
    download_images(df, max_workers=30)

    logger.info("Stage 1 completed successfully!")


if __name__ == "__main__":
    main()
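As a quick sanity check of the white-edge filter, the sketch below builds two synthetic images and runs has_white_edges on them. It is purely illustrative, not part of the pipeline, and assumes stage_1.py is importable from the working directory.

# Illustrative check of has_white_edges (assumes stage_1.py is on the import path)
from PIL import Image
from stage_1 import has_white_edges

white = Image.new("RGB", (256, 256), (255, 255, 255))  # all four edges are white
dark = Image.new("RGB", (256, 256), (30, 30, 30))      # no edge exceeds the threshold

print(has_white_edges(white))  # True: 4 of 4 edge means exceed the default threshold of 240
print(has_white_edges(dark))   # False: 0 of 4 edges are white, below the 3-edge cutoff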
stage_2.py
ADDED
@@ -0,0 +1,292 @@
#!/usr/bin/env python3
"""
Stage 2: Gemini Vision Classification
Classifies images with Google Gemini across five classification tasks plus an overall description
"""

import os
import json
from PIL import Image
from io import BytesIO
import concurrent.futures
from pathlib import Path
import time
import logging
from typing import Dict, Any
import mimetypes
import random

# Gemini SDK
from google import genai
from google.genai.errors import ServerError
from google.genai.types import (
    Blob, Part, Content, GenerateContentConfig,
)

# Set up logging
logging.basicConfig(level=logging.DEBUG, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

GEMINI_API_KEY_FALLBACK = "AIzaSyBCgkB2nRaRNgbl06MBu1I_xHiuXSUQMHA"


def check_api_key():
    """Ensure a Google API key is set for the Gemini client."""
    if not os.getenv('GOOGLE_API_KEY'):
        # Use the provided key as a fallback if the env variable is not set
        os.environ['GOOGLE_API_KEY'] = GEMINI_API_KEY_FALLBACK
    return True


def _guess_mime_type(image_path: str) -> str:
    guessed, _ = mimetypes.guess_type(image_path)
    if guessed:
        return guessed
    try:
        with Image.open(image_path) as im:
            fmt = (im.format or '').lower()
            if fmt in ('jpeg', 'jpg'):
                return 'image/jpeg'
            if fmt == 'png':
                return 'image/png'
            if fmt == 'webp':
                return 'image/webp'
            if fmt == 'gif':
                return 'image/gif'
    except Exception:
        pass
    return 'application/octet-stream'


def _gemini_call_with_retry(contents, cfg, max_attempts: int = 5):
    """Call Gemini with exponential-backoff retries on server or transient errors."""
    api_key = os.getenv('GOOGLE_API_KEY') or GEMINI_API_KEY_FALLBACK
    for attempt in range(max_attempts):
        try:
            client = genai.Client(api_key=api_key)
            return client.models.generate_content(
                model="models/gemini-2.5-flash",
                contents=contents,
                config=cfg,
            )
        except ServerError as e:
            sleep_s = (2 ** attempt) + random.random()
            logger.warning(f"Gemini server error attempt {attempt+1}/{max_attempts}: {e}; retrying in {sleep_s:.1f}s")
            time.sleep(sleep_s)
        except Exception as e:
            sleep_s = (2 ** attempt) + random.random()
            logger.warning(f"Gemini error attempt {attempt+1}/{max_attempts}: {e}; retrying in {sleep_s:.1f}s")
            time.sleep(sleep_s)
    raise RuntimeError(f"Persistent Gemini error after {max_attempts} tries")


def classify_image_with_gemini(image_path: str, caption: str, max_retries: int = 3) -> Dict[str, Any]:
    """Use Google Gemini to classify an image with structured JSON output."""
    prompt = f"""
Analyze this image with caption: "{caption}"

Please answer the following questions and respond ONLY with valid JSON:

1. Overall description.
2. Is the image a product display / low-quality advertisement / e-commerce product? Answer: "yes" or "no"
3. Is the image a computer screenshot with many text overlays? Answer: "yes" or "no"
4. In what category is the image? Choose one from: "animals", "artifacts", "people", "outdoor_scenes", "illustrations", "vehicles", "food_and_beverage", "arts", "abstract", "produce_and_plants", "indoor_scenes"
5. Would you say the image is interesting? Answer: "yes" or "no"
6. Do you think the photo/image was made by a professional photographer? Answer: "yes" or "no"

IMPORTANT: Respond with ONLY a valid JSON object with these exact keys. Do not include any other text or explanation:

{{
    "overall_description": "...",
    "is_product_advertisement": "yes",
    "is_screenshot_with_text": "no",
    "category": "animals",
    "is_interesting": "no",
    "is_professional": "yes"
}}
"""

    default_response = {
        "overall_description": "...",
        "is_product_advertisement": "...",
        "is_screenshot_with_text": "...",
        "category": "...",
        "is_interesting": "...",
        "is_professional": "..."
    }

    try:
        with open(image_path, 'rb') as f:
            image_bytes = f.read()
        mime_type = _guess_mime_type(image_path)

        image_blob = Blob(mime_type=mime_type, data=image_bytes)
        user_content = Content(
            role="user",
            parts=[
                Part(text=prompt),
                Part(inline_data=image_blob),
            ],
        )
        contents = [user_content]
        cfg = GenerateContentConfig(max_output_tokens=2500, temperature=0)

        resp = _gemini_call_with_retry(contents, cfg, max_attempts=max_retries)
        logger.debug(f"Gemini response type: {type(resp)}")

        # Detailed debugging of the response structure
        logger.debug(f"Response.text: {getattr(resp, 'text', 'NO_TEXT_ATTR')}")
        logger.debug(f"Response.candidates: {getattr(resp, 'candidates', 'NO_CANDIDATES_ATTR')}")
        if hasattr(resp, 'candidates') and resp.candidates:
            logger.debug(f"Number of candidates: {len(resp.candidates)}")
            for i, candidate in enumerate(resp.candidates):
                logger.debug(f"Candidate {i}: {candidate}")
                if hasattr(candidate, 'content'):
                    logger.debug(f"Candidate {i} content: {candidate.content}")
                    if hasattr(candidate.content, 'parts'):
                        logger.debug(f"Candidate {i} parts: {candidate.content.parts}")

        # Check for prompt_feedback, which might indicate filtering
        if hasattr(resp, 'prompt_feedback'):
            logger.debug(f"Prompt feedback: {resp.prompt_feedback}")

        # Extract text from the Gemini response
        content_text = ""
        try:
            # Try the .text property first
            if hasattr(resp, 'text') and resp.text:
                content_text = resp.text
                logger.debug(f"Got text from .text property: {content_text[:100]}...")
            else:
                # Fallback: extract from candidates
                if resp.candidates and len(resp.candidates) > 0:
                    candidate = resp.candidates[0]
                    if hasattr(candidate, 'content') and candidate.content:
                        if hasattr(candidate.content, 'parts') and candidate.content.parts:
                            for part in candidate.content.parts:
                                if hasattr(part, 'text') and part.text:
                                    content_text += part.text
                                    logger.debug(f"Got text from candidate part: {part.text[:100]}...")
        except Exception as e:
            logger.error(f"Error extracting text from Gemini response: {e}")
            raise e

        if not content_text:
            logger.error("Empty response from Gemini")
            return default_response

        content_text = content_text.strip()
        start_idx = content_text.find('{')
        end_idx = content_text.rfind('}') + 1
        if start_idx == -1 or end_idx == 0:
            logger.error(f"No JSON found in response: {content_text}")
            return default_response

        json_content = content_text[start_idx:end_idx]
        classification = json.loads(json_content)

        required_keys = [
            "overall_description",
            "is_product_advertisement",
            "is_screenshot_with_text",
            "category",
            "is_interesting",
            "is_professional",
        ]
        missing_keys = [key for key in required_keys if key not in classification]
        if missing_keys:
            logger.warning(f"Missing keys in classification: {missing_keys}")
            for key in missing_keys:
                classification[key] = default_response[key]

        return classification
    except json.JSONDecodeError as e:
        logger.error(f"JSON parsing error: {e}")
        return default_response
    except Exception as e:
        logger.error(f"Gemini classification error: {e}")
        return default_response


def classify_single_image(metadata_file: Path) -> bool:
    """Classify a single image and save the results"""
    try:
        # Load metadata
        with open(metadata_file, 'r', encoding='utf-8') as f:
            metadata = json.load(f)

        idx = metadata['idx']
        image_path = metadata['image_path']
        caption = metadata['caption']

        # Check if the image exists
        if not os.path.exists(image_path):
            logger.error(f"Image not found: {image_path}")
            return False

        # Classify with Gemini
        classification = classify_image_with_gemini(image_path, caption)

        # Add classification to metadata
        metadata['classification'] = classification
        metadata['stage2_complete'] = True

        # Save updated metadata
        new_metadata_file = metadata_file.with_name(f'meta_{idx}_stage2.json')
        with open(new_metadata_file, 'w', encoding='utf-8') as f:
            json.dump(metadata, f, indent=2, ensure_ascii=False)

        logger.info(f"Classified image {idx}")
        return True

    except Exception as e:
        logger.error(f"Error classifying {metadata_file}: {e}")
        return False


def classify_all_images(max_workers: int = 2):
    """Classify all downloaded images with parallel processing"""
    logger.info("Starting image classification...")

    # Get all metadata files
    metadata_dir = Path('./data/metadata')
    metadata_files = list(metadata_dir.glob('meta_*.json'))

    if not metadata_files:
        logger.error("No metadata files found. Run stage 1 first.")
        return

    successful = 0
    with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
        futures = [executor.submit(classify_single_image, meta_file) for meta_file in metadata_files]

        for future in concurrent.futures.as_completed(futures):
            if future.result():
                successful += 1

            # Rate limiting: sleep 1 second per completed result to avoid API rate limits
            time.sleep(1.0)

    logger.info(f"Successfully classified {successful}/{len(metadata_files)} images")

    # Save summary
    summary = {
        "total_images": len(metadata_files),
        "successful_classifications": successful,
        "stage": "classification_complete"
    }

    with open('./data/stage2_summary.json', 'w') as f:
        json.dump(summary, f, indent=2)


def main():
    """Main execution for Stage 2"""
    logger.info("Starting Stage 2: Gemini Vision Classification...")

    # Check API key
    if not check_api_key():
        return

    # Classify images
    classify_all_images(max_workers=64)

    logger.info("Stage 2 completed successfully!")


if __name__ == "__main__":
    main()
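The later stages read ./task_config.json, which is not included in this upload. The sketch below shows a plausible shape for that file, inferred only from the fields the other scripts read (key, name, type, labels, description) and from the classification keys produced by the Gemini prompt above; the real file may use different names, descriptions, or label ordering.

# Hypothetical task_config.json contents, written as a Python dict for illustration only.
# Binary label order ["no", "yes"] matches how example.py indexes prob_no/prob_yes and
# how stage_4.py maps "yes" to class 1; the actual project file may differ.
import json

task_config = {
    "tasks": [
        {"key": "is_product_advertisement", "name": "Product Advertisement", "type": "binary",
         "labels": ["no", "yes"], "description": "Product display / low-quality ad / e-commerce product"},
        {"key": "is_screenshot_with_text", "name": "Screenshot With Text", "type": "binary",
         "labels": ["no", "yes"], "description": "Computer screenshot with many text overlays"},
        {"key": "category", "name": "Category", "type": "multi_class",
         "labels": ["animals", "artifacts", "people", "outdoor_scenes", "illustrations", "vehicles",
                    "food_and_beverage", "arts", "abstract", "produce_and_plants", "indoor_scenes"],
         "description": "Overall image category"},
        {"key": "is_interesting", "name": "Interesting", "type": "binary",
         "labels": ["no", "yes"], "description": "Whether the image is interesting"},
        {"key": "is_professional", "name": "Professional", "type": "binary",
         "labels": ["no", "yes"], "description": "Whether the image looks professionally made"},
    ]
}

with open("task_config.json", "w") as f:
    json.dump(task_config, f, indent=2)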
stage_4.py
ADDED
@@ -0,0 +1,506 @@
#!/usr/bin/env python3
"""
Stage 4: SigLIP v2 Multi-Head Classifier Training
Trains a SigLIP v2-based multi-head classifier on pseudo-labeled data
"""

import os
import json
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from transformers import SiglipModel, AutoProcessor
import numpy as np
from PIL import Image
from pathlib import Path
import logging
from typing import Dict, List, Any
import pickle
import matplotlib.pyplot as plt
from torch.optim.lr_scheduler import LambdaLR

# Set up logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

CKPT = "google/siglip-base-patch16-256"


def load_task_config(config_path: str = './task_config.json'):
    """Load task configuration from a JSON file"""
    if not os.path.exists(config_path):
        raise FileNotFoundError(f"Task configuration not found: {config_path}")

    with open(config_path, 'r') as f:
        config = json.load(f)

    logger.info(f"Loaded task configuration with {len(config['tasks'])} tasks")
    return config


class MultiHeadDataset(Dataset):
    """Dataset for multi-head classification with configurable tasks"""

    def __init__(self, data_dir: str, processor, task_config: Dict):
        self.data_dir = Path(data_dir)
        self.processor = processor
        self.task_config = task_config

        # Load all metadata files from stage 2 (with _stage2 suffix)
        metadata_dir = self.data_dir / 'metadata'
        if not metadata_dir.exists():
            raise FileNotFoundError("Metadata directory not found. Run stages 1 and 2 first.")

        metadata_files = list(metadata_dir.glob('meta_*_stage2.json'))
        if not metadata_files:
            raise FileNotFoundError("No stage 2 metadata files found. Run stage 2 first.")

        # Load all samples
        self.samples = []
        skipped_incomplete = 0

        for meta_file in metadata_files:
            try:
                with open(meta_file, 'r') as f:
                    metadata = json.load(f)

                # Check if classification is complete
                if not metadata.get('stage2_complete', False):
                    logger.warning(f"Skipping {meta_file} - classification not complete")
                    skipped_incomplete += 1
                    continue

                # Check if classification contains incomplete data (empty or "..." values)
                classification = metadata.get('classification', {})
                if not classification or self._is_incomplete_classification(classification):
                    logger.warning(f"Skipping {meta_file} - incomplete classification data")
                    skipped_incomplete += 1
                    continue

                # Check if the image exists
                image_path = metadata['image_path']
                if not os.path.exists(image_path):
                    logger.warning(f"Image not found: {image_path}")
                    skipped_incomplete += 1
                    continue

                self.samples.append(metadata)

            except Exception as e:
                logger.error(f"Error loading {meta_file}: {e}")
                skipped_incomplete += 1

        # Create label mappings from task config
        self.label_mappings = {}
        for task in self.task_config['tasks']:
            if task['type'] == 'multi_class':
                self.label_mappings[task['key']] = {
                    label: idx for idx, label in enumerate(task['labels'])
                }

        if skipped_incomplete > 0:
            logger.warning(f"Skipped {skipped_incomplete} incomplete samples")
        logger.info(f"Loaded {len(self.samples)} valid samples for training")

    def _is_incomplete_classification(self, classification: Dict) -> bool:
        """Check if classification contains incomplete data (empty or '...' values)"""
        required_tasks = [task['key'] for task in self.task_config['tasks']]

        for task_key in required_tasks:
            if task_key not in classification:
                return True

            value = classification[task_key]
            # Check for incomplete markers
            if not value or value == "..." or value == "" or value is None:
                return True

        return False

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        sample = self.samples[idx]

        # Load image
        image = Image.open(sample['image_path']).convert('RGB')

        # Process image only
        inputs = self.processor(
            images=image,
            return_tensors="pt"
        )

        # Convert classifications to labels based on task config
        classification = sample['classification']
        labels = {}

        for task in self.task_config['tasks']:
            task_key = task['key']
            if task['type'] == 'binary':
                # Binary tasks: convert yes/no to 1/0
                labels[task_key] = 1 if classification[task_key] == 'yes' else 0
            elif task['type'] == 'multi_class':
                # Multi-class tasks: convert to index (default to the first class)
                label_str = classification[task_key]
                labels[task_key] = self.label_mappings[task_key].get(label_str, 0)

        return {
            'pixel_values': inputs['pixel_values'].squeeze(0),
            'labels': labels,
            'metadata': {
                'idx': sample['idx'],
                'caption': sample['caption'],
                'image_path': sample['image_path']
            }
        }


class MultiHeadSiglipClassifier(nn.Module):
    """SigLIP-based multi-head classifier with configurable tasks"""

    def __init__(self, task_config: Dict, model_name: str = CKPT):
        super().__init__()

        self.task_config = task_config
        self.siglip = SiglipModel.from_pretrained(model_name)

        # Freeze SigLIP parameters initially
        for param in self.siglip.parameters():
            param.requires_grad = False

        # Create classification heads dynamically based on task config
        hidden_size = self.siglip.config.vision_config.hidden_size
        self.classification_heads = nn.ModuleDict()

        for task in task_config['tasks']:
            task_key = task['key']
            num_classes = len(task['labels'])

            # Create a linear layer for this task
            head = nn.Linear(hidden_size, num_classes)

            # Initialize with zeros
            head.weight.data.zero_()
            head.bias.data.zero_()

            self.classification_heads[task_key] = head

        logger.info(f"Created {len(self.classification_heads)} classification heads")

    def forward(self, pixel_values):
        # Get SigLIP image embeddings only
        combined_embeds = self.siglip.get_image_features(pixel_values=pixel_values)

        # Apply all classification heads
        outputs = {}
        for task_key, head in self.classification_heads.items():
            outputs[task_key] = head(combined_embeds)

        return outputs


def calculate_accuracy(predictions, labels):
    """Calculate accuracy for binary/multi-class predictions"""
    pred_classes = torch.argmax(predictions, dim=1)
    correct = (pred_classes == labels).float()
    return correct.mean().item()


def plot_validation_accuracies(history, task_config, save_path='./checkpoints/validation_accuracies.png'):
    """Create and save validation accuracy plots"""
    tasks = [task['key'] for task in task_config['tasks']]
    task_names = [task['name'] for task in task_config['tasks']]

    # Calculate grid size
    n_tasks = len(tasks)
    n_cols = 3
    n_rows = (n_tasks + n_cols - 1) // n_cols  # Ceiling division

    fig, axes = plt.subplots(n_rows, n_cols, figsize=(18, 6 * n_rows))
    fig.suptitle('Training Progress Dashboard', fontsize=16, fontweight='bold')

    # Flatten axes for easier indexing
    if n_rows == 1:
        axes = [axes] if n_cols == 1 else axes
    else:
        axes = axes.flatten()

    epochs = range(1, len(history['val_accuracy'][tasks[0]]) + 1)
    colors = plt.cm.Set1(np.linspace(0, 1, n_tasks))

    # Plot individual validation accuracies
    for i, (task_key, task_name, color) in enumerate(zip(tasks, task_names, colors)):
        if i < len(axes):
            axes[i].plot(epochs, history['val_accuracy'][task_key],
                         label=task_name, marker='o', color=color, linewidth=2, markersize=4)
            axes[i].set_xlabel('Epoch')
            axes[i].set_ylabel('Validation Accuracy')
            axes[i].set_title(f'{task_name} Validation Accuracy')
            axes[i].grid(True, alpha=0.3)
            axes[i].set_ylim(0, 1)

    # Hide unused subplots
    for i in range(n_tasks, len(axes)):
        axes[i].set_visible(False)

    plt.tight_layout()
    plt.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.close()

    logger.info(f"Validation accuracy plots saved to {save_path}")

    # Calculate summary statistics
    best_accs = [max(history['val_accuracy'][task]) for task in tasks]
    final_accs = [history['val_accuracy'][task][-1] for task in tasks]

    return best_accs, final_accs


def train_multi_head_classifier(data_dir: str, task_config_path: str = './task_config.json',
                                epochs: int = 30, batch_size: int = 4):
    """Train the multi-head SigLIP v2 classifier"""
    logger.info("Starting multi-head classifier training...")

    # Load task configuration
    task_config = load_task_config(task_config_path)

    # Create checkpoints directory
    checkpoint_dir = Path('./checkpoints')
    checkpoint_dir.mkdir(exist_ok=True)
    logger.info(f"Checkpoints will be saved to: {checkpoint_dir}")

    # Save task config to checkpoints for inference
    with open(checkpoint_dir / 'task_config.json', 'w') as f:
        json.dump(task_config, f, indent=2)

    # Load processor and model
    processor = AutoProcessor.from_pretrained(CKPT)
    model = MultiHeadSiglipClassifier(task_config, model_name=CKPT)

    # Dataset and dataloader
    dataset = MultiHeadDataset(data_dir, processor, task_config)
    if len(dataset) == 0:
        logger.error("No training data found!")
        return

    # Split dataset (simple train/val split)
    train_size = int(0.8 * len(dataset))
    val_size = len(dataset) - train_size
    train_dataset, val_dataset = torch.utils.data.random_split(dataset, [train_size, val_size])

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

    # Set up training
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    logger.info(f"Using device: {device}")
    model.to(device)

    # Optimizer and loss functions
    # Gather model parameters that require gradients (only the classification heads)
    params = []
    for name, param in model.named_parameters():
        if param.requires_grad:
            params.append(param)

    optimizer = optim.AdamW(params, lr=1e-2)

    # Linear cooldown LR scheduler
    def linear_cooldown(epoch):
        return max(0.1, 1.0 - (epoch / epochs))

    scheduler = LambdaLR(optimizer, lr_lambda=linear_cooldown)
    criterion = nn.CrossEntropyLoss()

    # Initialize training history
    history = {
        'train_loss': [],
        'val_loss': [],
        'learning_rates': [],
        'val_accuracy': {task['key']: [] for task in task_config['tasks']},
        'epoch_val_accuracy': []
    }

    # Training loop
    for epoch in range(epochs):
        # Training phase
        model.train()
        total_train_loss = 0

        for batch_idx, batch in enumerate(train_loader):
            optimizer.zero_grad()

            # Move to device
            pixel_values = batch['pixel_values'].to(device)

            # Forward pass
            outputs = model(pixel_values)

            # Calculate losses for each task
            losses = []
            for task in task_config['tasks']:
                task_key = task['key']
                labels = batch['labels'][task_key].to(device)
                loss = criterion(outputs[task_key], labels)
                losses.append(loss)

            # Total loss
            total_batch_loss = sum(losses)
            total_batch_loss.backward()
            optimizer.step()

            total_train_loss += total_batch_loss.item()

            if batch_idx % 10 == 0:
                logger.info(f"Epoch {epoch+1}/{epochs}, Batch {batch_idx}/{len(train_loader)}, Loss: {total_batch_loss.item():.4f}")

        avg_train_loss = total_train_loss / len(train_loader)
        history['train_loss'].append(avg_train_loss)

        # Record learning rate
        current_lr = optimizer.param_groups[0]['lr']
        history['learning_rates'].append(current_lr)

        # Validation phase
        model.eval()
        total_val_loss = 0
        val_accuracies = {task['key']: [] for task in task_config['tasks']}

        with torch.no_grad():
            for batch in val_loader:
                pixel_values = batch['pixel_values'].to(device)

                outputs = model(pixel_values)

                # Calculate validation losses and accuracies
                losses = []
                for task in task_config['tasks']:
                    task_key = task['key']
                    labels = batch['labels'][task_key].to(device)
                    loss = criterion(outputs[task_key], labels)
                    losses.append(loss)

                    # Calculate accuracy
                    acc = calculate_accuracy(outputs[task_key], labels)
                    val_accuracies[task_key].append(acc)

                total_val_loss += sum(losses).item()

        avg_val_loss = total_val_loss / len(val_loader)
        history['val_loss'].append(avg_val_loss)

        # Calculate average accuracies
        epoch_accuracies = {}
        for task in task_config['tasks']:
            task_key = task['key']
            avg_acc = np.mean(val_accuracies[task_key])
            epoch_accuracies[task_key] = avg_acc
            history['val_accuracy'][task_key].append(avg_acc)

        history['epoch_val_accuracy'].append(epoch_accuracies.copy())

        logger.info(f"Epoch {epoch+1}/{epochs}")
        logger.info(f"  Train Loss: {avg_train_loss:.4f}, Val Loss: {avg_val_loss:.4f}")
        logger.info(f"  Learning Rate: {current_lr:.6f}")
        logger.info(f"  Val Accuracies: {epoch_accuracies}")

        # Step the learning rate scheduler
        scheduler.step()

    # Create comprehensive checkpoint
    checkpoint = {
        'epoch': epochs,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'scheduler_state_dict': scheduler.state_dict(),
        'history': history,
        'final_accuracies': epoch_accuracies,
        'task_config': task_config
    }

    # Save the trained model and checkpoint
    torch.save(model.state_dict(), checkpoint_dir / 'multi_head_siglip2_classifier.pth')
    torch.save(checkpoint, checkpoint_dir / 'training_checkpoint.pth')
    logger.info(f"Model saved to {checkpoint_dir / 'multi_head_siglip2_classifier.pth'}")
    logger.info(f"Full checkpoint saved to {checkpoint_dir / 'training_checkpoint.pth'}")

    # Save processor for inference
    processor.save_pretrained(checkpoint_dir / 'siglip2_processor')
    logger.info(f"Processor saved to {checkpoint_dir / 'siglip2_processor'}")

    # Save training history as JSON
    with open(checkpoint_dir / 'training_history.json', 'w') as f:
        json_history = {}
        for key, value in history.items():
            if key == 'val_accuracy':
                json_history[key] = {task: [float(acc) for acc in accs] for task, accs in value.items()}
            elif key == 'epoch_val_accuracy':
                json_history[key] = [{task: float(acc) for task, acc in epoch.items()} for epoch in value]
            else:
                json_history[key] = [float(x) for x in value]
        json.dump(json_history, f, indent=2)
    logger.info(f"Training history saved to {checkpoint_dir / 'training_history.json'}")

    # Generate and save validation accuracy plots
    best_accs, final_accs = plot_validation_accuracies(history, task_config, checkpoint_dir / 'validation_accuracies.png')

    # Save detailed validation accuracy summary
    val_summary = {
        'best_accuracies': {
            task['key']: float(max(history['val_accuracy'][task['key']]))
            for task in task_config['tasks']
        },
        'final_accuracies': {task: float(acc) for task, acc in epoch_accuracies.items()},
        'average_best_accuracy': float(np.mean(best_accs)),
        'average_final_accuracy': float(np.mean(final_accs)),
        'improvement_per_task': {
            task['key']: float(history['val_accuracy'][task['key']][-1] - history['val_accuracy'][task['key']][0])
            for task in task_config['tasks']
        }
    }

    with open(checkpoint_dir / 'validation_summary.json', 'w') as f:
        json.dump(val_summary, f, indent=2)
    logger.info(f"Validation summary saved to {checkpoint_dir / 'validation_summary.json'}")

    # Save final training summary
    final_summary = {
        "model_type": "SigLIP2 Multi-Head Classifier",
        "training_samples": len(train_dataset),
        "validation_samples": len(val_dataset),
        "epochs": epochs,
        "final_train_loss": avg_train_loss,
        "final_val_loss": avg_val_loss,
        "final_accuracies": epoch_accuracies,
        "task_config": task_config,
        "classification_heads": {
            task['key']: f"{task['type']} - {task['description']}"
            for task in task_config['tasks']
        }
    }

    with open(checkpoint_dir / 'stage4_summary.json', 'w') as f:
        json.dump(final_summary, f, indent=2)
    logger.info(f"Stage 4 summary saved to {checkpoint_dir / 'stage4_summary.json'}")

    # Log summary of saved artifacts
    logger.info("=" * 60)
    logger.info("TRAINING COMPLETE - ARTIFACTS SAVED:")
    logger.info(f"📁 Checkpoint Directory: {checkpoint_dir}")
    logger.info("🤖 Model Weights: multi_head_siglip2_classifier.pth")
    logger.info("💾 Full Checkpoint: training_checkpoint.pth")
    logger.info("🔧 Processor: siglip2_processor/")
    logger.info("⚙️ Task Config: task_config.json")
    logger.info("📊 Training History: training_history.json")
    logger.info("📈 Validation Plots: validation_accuracies.png")
    logger.info("📋 Validation Summary: validation_summary.json")
    logger.info("📄 Stage Summary: stage4_summary.json")
    logger.info("=" * 60)


def main():
    """Main execution for Stage 4"""
    logger.info("Starting Stage 4: SigLIP v2 Multi-Head Training...")

    # Train classifier
    train_multi_head_classifier('./data', epochs=10, batch_size=2)

    logger.info("Stage 4 completed successfully!")
    logger.info("🎉 Complete pipeline finished! Check ./checkpoints/ for all training artifacts.")


if __name__ == "__main__":
    main()
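Because the full training state is bundled into training_checkpoint.pth, the run can be inspected afterwards without retraining. The snippet below is a minimal sketch assuming the default paths used above; on recent PyTorch versions, weights_only=False is needed because the checkpoint stores plain Python objects alongside tensors.

# Inspect the checkpoint written by stage_4.py (illustrative sketch; default paths assumed)
import torch

ckpt = torch.load("./checkpoints/training_checkpoint.pth", map_location="cpu", weights_only=False)

print("Epochs trained:", ckpt["epoch"])
print("Final val accuracies:", ckpt["final_accuracies"])
for task_key, accs in ckpt["history"]["val_accuracy"].items():
    print(f"{task_key}: best={max(accs):.3f}, final={accs[-1]:.3f}")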
streamlit_evaluation_app.py
ADDED
@@ -0,0 +1,695 @@
#!/usr/bin/env python3
"""
Streamlit Data Viewer and Model Evaluation System
Interactive dashboard for exploring validation results with threshold filtering
"""

import streamlit as st
import json
import pandas as pd
import numpy as np
from PIL import Image
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import os
from pathlib import Path
import subprocess
import sys
from rapidocr import RapidOCR
from matplotlib import pyplot as plt

# Page config
st.set_page_config(
    page_title="Pseudoable Classifier Evaluation Dashboard",
    page_icon="🔍",
    layout="wide",
    initial_sidebar_state="expanded"
)

# Custom CSS for better styling
st.markdown("""
<style>
    .main-header {
        font-size: 3rem;
        font-weight: bold;
        color: #1f77b4;
        text-align: center;
        margin-bottom: 2rem;
        padding: 1rem;
        background: linear-gradient(90deg, #f0f8ff, #e6f3ff);
        border-radius: 10px;
    }
    .metric-card {
        background-color: #f8f9fa;
        padding: 1rem;
        border-radius: 8px;
        border-left: 4px solid #1f77b4;
        margin: 0.5rem 0;
    }
    .filter-section {
        background-color: #ffffff;
        padding: 1.5rem;
        border-radius: 10px;
        box-shadow: 0 2px 4px rgba(0,0,0,0.1);
        margin-bottom: 2rem;
    }
    .image-container {
        border: 2px solid #e6e6e6;
        border-radius: 8px;
        padding: 10px;
        margin: 10px 0;
        background-color: #fafafa;
    }
    .prediction-badge {
        display: inline-block;
        padding: 0.25rem 0.5rem;
        border-radius: 15px;
        font-size: 0.8rem;
        font-weight: bold;
        margin: 0.2rem;
    }
    .correct-prediction {
        background-color: #d4edda;
        color: #155724;
    }
    .incorrect-prediction {
        background-color: #f8d7da;
        color: #721c24;
    }
</style>
""", unsafe_allow_html=True)


@st.cache_data
def load_task_config(config_path: str = './task_config.json'):
    """Load task configuration from JSON file"""
    if not os.path.exists(config_path):
        # Try to load from the checkpoints directory
        checkpoint_config = './checkpoints/task_config.json'
        if os.path.exists(checkpoint_config):
            config_path = checkpoint_config
        else:
            return None

    with open(config_path, 'r') as f:
        config = json.load(f)
    return config


@st.cache_data
def load_validation_results(file_path: str = './validation_results.json'):
    """Load validation results from JSON file"""
    if not os.path.exists(file_path):
        return None

    with open(file_path, 'r') as f:
        data = json.load(f)
    return data


@st.cache_resource
def get_ocr_engine():
    """Initialize and cache the OCR engine"""
    return RapidOCR()


@st.cache_data
def extract_text_from_image(image_path: str):
    """Extract text from an image using OCR"""
    try:
        engine = get_ocr_engine()
        result = engine(image_path)

        # Handle the new RapidOCR output format
        if result and hasattr(result, 'txts') and result.txts:
            texts = result.txts
123 |
+
return {
|
124 |
+
'text': ' '.join(texts) if texts else '',
|
125 |
+
'num_text_blocks': len(texts),
|
126 |
+
'has_text': len(texts) > 0
|
127 |
+
}
|
128 |
+
elif result and isinstance(result, (list, tuple)) and len(result) > 0:
|
129 |
+
# Fallback for older format
|
130 |
+
texts = []
|
131 |
+
for item in result:
|
132 |
+
if len(item) >= 2:
|
133 |
+
texts.append(item[1])
|
134 |
+
|
135 |
+
return {
|
136 |
+
'text': ' '.join(texts) if texts else '',
|
137 |
+
'num_text_blocks': len(texts),
|
138 |
+
'has_text': len(texts) > 0
|
139 |
+
}
|
140 |
+
else:
|
141 |
+
return {
|
142 |
+
'text': '',
|
143 |
+
'num_text_blocks': 0,
|
144 |
+
'has_text': False
|
145 |
+
}
|
146 |
+
except Exception as e:
|
147 |
+
return {
|
148 |
+
'text': f'OCR Error: {str(e)}',
|
149 |
+
'num_text_blocks': 0,
|
150 |
+
'has_text': False
|
151 |
+
}
|
152 |
+
|
153 |
+
@st.cache_data
|
154 |
+
def process_validation_data(validation_data, task_config):
|
155 |
+
"""Process validation data into DataFrame for easier filtering"""
|
156 |
+
if not validation_data or not task_config:
|
157 |
+
return None
|
158 |
+
|
159 |
+
rows = []
|
160 |
+
tasks = {task['key']: task for task in task_config['tasks']}
|
161 |
+
|
162 |
+
for result in validation_data['results']:
|
163 |
+
row = {
|
164 |
+
'idx': result['idx'],
|
165 |
+
'caption': result['caption'],
|
166 |
+
'image_path': result['image_path'],
|
167 |
+
'url': result['url'],
|
168 |
+
'hash': result['hash']
|
169 |
+
}
|
170 |
+
|
171 |
+
# Ground truth and predictions
|
172 |
+
gt = result['ground_truth']
|
173 |
+
pred = result['predictions']
|
174 |
+
|
175 |
+
# Process each task dynamically
|
176 |
+
for task_key, task_info in tasks.items():
|
177 |
+
# Ground truth
|
178 |
+
row[f'gt_{task_key}'] = gt[task_key]
|
179 |
+
|
180 |
+
# Predictions
|
181 |
+
row[f'pred_{task_key}'] = pred[f'{task_key}_prediction']
|
182 |
+
row[f'conf_{task_key}'] = pred[f'{task_key}_confidence']
|
183 |
+
|
184 |
+
# For binary tasks, also include probability for 'yes'
|
185 |
+
if task_info['type'] == 'binary':
|
186 |
+
row[f'prob_{task_key}_yes'] = pred.get(f'{task_key}_prob_yes', 0.5)
|
187 |
+
|
188 |
+
# Correctness
|
189 |
+
row[f'correct_{task_key}'] = gt[task_key] == pred[f'{task_key}_prediction']
|
190 |
+
|
191 |
+
rows.append(row)
|
192 |
+
|
193 |
+
return pd.DataFrame(rows)
|
194 |
+
|
195 |
+
def run_validation_if_needed():
|
196 |
+
"""Run validation if results don't exist"""
|
197 |
+
if not os.path.exists('./validation_results.json'):
|
198 |
+
st.warning("Validation results not found. Running validation...")
|
199 |
+
|
200 |
+
# Check if model exists
|
201 |
+
if not os.path.exists('./checkpoints/multi_head_siglip2_classifier.pth'):
|
202 |
+
st.error("❌ Trained model not found! Please run the training pipeline first.")
|
203 |
+
st.code("python stage_4.py")
|
204 |
+
return False
|
205 |
+
|
206 |
+
# Run validation
|
207 |
+
with st.spinner("Running model on validation set... This may take a few minutes."):
|
208 |
+
try:
|
209 |
+
result = subprocess.run([sys.executable, 'validation_runner.py'],
|
210 |
+
capture_output=True, text=True)
|
211 |
+
if result.returncode == 0:
|
212 |
+
st.success("✅ Validation completed successfully!")
|
213 |
+
st.rerun()
|
214 |
+
else:
|
215 |
+
st.error(f"❌ Validation failed: {result.stderr}")
|
216 |
+
return False
|
217 |
+
except Exception as e:
|
218 |
+
st.error(f"❌ Error running validation: {e}")
|
219 |
+
return False
|
220 |
+
|
221 |
+
return True
|
222 |
+
|
223 |
+
def create_overview_metrics(df, validation_data, task_config):
|
224 |
+
"""Create overview metrics section"""
|
225 |
+
st.markdown("## 📊 Overview Metrics")
|
226 |
+
|
227 |
+
tasks = [task['key'] for task in task_config['tasks']]
|
228 |
+
|
229 |
+
# Basic stats
|
230 |
+
col1, col2, col3, col4 = st.columns(4)
|
231 |
+
|
232 |
+
with col1:
|
233 |
+
st.metric("Total Samples", len(df))
|
234 |
+
|
235 |
+
with col2:
|
236 |
+
avg_confidence = np.mean([df[f'conf_{task}'].mean() for task in tasks])
|
237 |
+
st.metric("Avg Confidence", f"{avg_confidence:.3f}")
|
238 |
+
|
239 |
+
with col3:
|
240 |
+
overall_accuracy = np.mean([df[f'correct_{task}'].mean() for task in tasks])
|
241 |
+
st.metric("Overall Accuracy", f"{overall_accuracy:.3f}")
|
242 |
+
|
243 |
+
with col4:
|
244 |
+
if validation_data and 'metadata' in validation_data:
|
245 |
+
model_accuracies = validation_data['metadata']['validation_accuracies']
|
246 |
+
model_avg = np.mean(list(model_accuracies.values()))
|
247 |
+
st.metric("Model Accuracy", f"{model_avg:.3f}")
|
248 |
+
|
249 |
+
# Detailed accuracies
|
250 |
+
st.markdown("### 🎯 Accuracy per Classification Task")
|
251 |
+
|
252 |
+
# Create dynamic columns based on number of tasks
|
253 |
+
n_tasks = len(tasks)
|
254 |
+
n_cols = min(5, n_tasks) # Max 5 columns
|
255 |
+
acc_cols = st.columns(n_cols)
|
256 |
+
|
257 |
+
for i, task in enumerate(tasks):
|
258 |
+
task_info = next(t for t in task_config['tasks'] if t['key'] == task)
|
259 |
+
with acc_cols[i % n_cols]:
|
260 |
+
accuracy = df[f'correct_{task}'].mean()
|
261 |
+
st.metric(task_info['name'], f"{accuracy:.3f}")
|
262 |
+
|
263 |
+
def create_confidence_distribution_plot(df, task_config):
|
264 |
+
"""Create confidence distribution plots"""
|
265 |
+
tasks = [task['key'] for task in task_config['tasks']]
|
266 |
+
task_names = [task['name'] for task in task_config['tasks']]
|
267 |
+
|
268 |
+
n_tasks = len(tasks)
|
269 |
+
n_cols = 3
|
270 |
+
n_rows = (n_tasks + n_cols - 1) // n_cols
|
271 |
+
|
272 |
+
fig = make_subplots(
|
273 |
+
rows=n_rows, cols=n_cols,
|
274 |
+
subplot_titles=task_names,
|
275 |
+
specs=[[{"secondary_y": False} for _ in range(n_cols)] for _ in range(n_rows)]
|
276 |
+
)
|
277 |
+
|
278 |
+
colors = plt.cm.Set1(np.linspace(0, 1, n_tasks))
|
279 |
+
|
280 |
+
for i, (task_key, color) in enumerate(zip(tasks, colors)):
|
281 |
+
row = (i // n_cols) + 1
|
282 |
+
col = (i % n_cols) + 1
|
283 |
+
|
284 |
+
fig.add_trace(
|
285 |
+
go.Histogram(
|
286 |
+
x=df[f'conf_{task_key}'],
|
287 |
+
nbinsx=20,
|
288 |
+
name=f'{task_key}',
|
289 |
+
marker_color=f'rgba({color[0]*255:.0f},{color[1]*255:.0f},{color[2]*255:.0f},0.7)',
|
290 |
+
opacity=0.7
|
291 |
+
),
|
292 |
+
row=row, col=col
|
293 |
+
)
|
294 |
+
|
295 |
+
fig.update_layout(
|
296 |
+
title="Confidence Score Distributions",
|
297 |
+
showlegend=False,
|
298 |
+
height=200 * n_rows + 100
|
299 |
+
)
|
300 |
+
|
301 |
+
return fig
|
302 |
+
|
303 |
+
def apply_filters(df, task_config):
|
304 |
+
"""Apply user-defined filters to the dataframe"""
|
305 |
+
st.markdown("## 🔍 Filter Data")
|
306 |
+
|
307 |
+
tasks = {task['key']: task for task in task_config['tasks']}
|
308 |
+
|
309 |
+
# Create filter sidebar
|
310 |
+
with st.sidebar:
|
311 |
+
st.markdown("### Task Confidence Filters")
|
312 |
+
|
313 |
+
# Confidence thresholds for each task
|
314 |
+
confidence_filters = {}
|
315 |
+
for task_key, task_info in tasks.items():
|
316 |
+
if task_info['type'] == 'multi_class':
|
317 |
+
# Only show confidence threshold for multi-class tasks
|
318 |
+
confidence_filters[task_key] = st.slider(
|
319 |
+
f"{task_info['name']} Confidence",
|
320 |
+
0.0, 1.0, 0.5, 0.01,
|
321 |
+
key=f"conf_{task_key}"
|
322 |
+
)
|
323 |
+
|
324 |
+
st.markdown("### Content Filters")
|
325 |
+
|
326 |
+
# Category filters (for multi-class tasks)
|
327 |
+
category_filters = {}
|
328 |
+
for task_key, task_info in tasks.items():
|
329 |
+
if task_info['type'] == 'multi_class':
|
330 |
+
available_values = df[f'gt_{task_key}'].unique().tolist()
|
331 |
+
selected_values = st.multiselect(
|
332 |
+
f"Ground Truth {task_info['name']}",
|
333 |
+
available_values,
|
334 |
+
default=available_values,
|
335 |
+
key=f"gt_{task_key}_filter"
|
336 |
+
)
|
337 |
+
category_filters[task_key] = selected_values
|
338 |
+
|
339 |
+
# Binary prediction filters
|
340 |
+
st.markdown("**Filter by Predictions:**")
|
341 |
+
prediction_filters = {}
|
342 |
+
for task_key, task_info in tasks.items():
|
343 |
+
if task_info['type'] == 'binary':
|
344 |
+
filter_value = st.selectbox(
|
345 |
+
f"{task_info['name']}:",
|
346 |
+
["All", "Yes only", "No only"],
|
347 |
+
key=f"pred_{task_key}_filter"
|
348 |
+
)
|
349 |
+
prediction_filters[task_key] = filter_value
|
350 |
+
|
351 |
+
# Correctness filter
|
352 |
+
st.markdown("**Filter by Correctness:**")
|
353 |
+
correctness_filter = st.selectbox(
|
354 |
+
"Show only:",
|
355 |
+
["All predictions", "Correct predictions", "Incorrect predictions"]
|
356 |
+
)
|
357 |
+
|
358 |
+
# OCR filters (if screenshot task exists)
|
359 |
+
has_screenshot_task = any(task['key'] == 'is_screenshot_with_text' for task in task_config['tasks'])
|
360 |
+
if has_screenshot_task:
|
361 |
+
st.markdown("**Filter by Text Content:**")
|
362 |
+
ocr_filter = st.selectbox(
|
363 |
+
"Text Content:",
|
364 |
+
["All images", "Images with text", "Images without text"],
|
365 |
+
key="ocr_filter"
|
366 |
+
)
|
367 |
+
enable_ocr = st.checkbox("Enable OCR text extraction", value=True)
|
368 |
+
else:
|
369 |
+
ocr_filter = "All images"
|
370 |
+
enable_ocr = False
|
371 |
+
|
372 |
+
# Apply filters
|
373 |
+
filtered_df = df.copy()
|
374 |
+
|
375 |
+
# Confidence filters
|
376 |
+
for task_key, threshold in confidence_filters.items():
|
377 |
+
filtered_df = filtered_df[filtered_df[f'conf_{task_key}'] >= threshold]
|
378 |
+
|
379 |
+
# Category filters
|
380 |
+
for task_key, selected_values in category_filters.items():
|
381 |
+
filtered_df = filtered_df[filtered_df[f'gt_{task_key}'].isin(selected_values)]
|
382 |
+
|
383 |
+
# Binary prediction filters
|
384 |
+
for task_key, filter_value in prediction_filters.items():
|
385 |
+
if filter_value == "Yes only":
|
386 |
+
filtered_df = filtered_df[filtered_df[f'pred_{task_key}'] == 'yes']
|
387 |
+
elif filter_value == "No only":
|
388 |
+
filtered_df = filtered_df[filtered_df[f'pred_{task_key}'] == 'no']
|
389 |
+
|
390 |
+
# Correctness filter
|
391 |
+
if correctness_filter == "Correct predictions":
|
392 |
+
correct_mask = True
|
393 |
+
for task_key in tasks.keys():
|
394 |
+
correct_mask = correct_mask & filtered_df[f'correct_{task_key}']
|
395 |
+
filtered_df = filtered_df[correct_mask]
|
396 |
+
elif correctness_filter == "Incorrect predictions":
|
397 |
+
correct_mask = True
|
398 |
+
for task_key in tasks.keys():
|
399 |
+
correct_mask = correct_mask & filtered_df[f'correct_{task_key}']
|
400 |
+
filtered_df = filtered_df[~correct_mask]
|
401 |
+
|
402 |
+
# Show filter results
|
403 |
+
st.info(f"Filtered to {len(filtered_df)} samples (from {len(df)} total)")
|
404 |
+
|
405 |
+
return filtered_df, ocr_filter, enable_ocr
|
406 |
+
|
407 |
+
def display_sample_images(df, task_config, ocr_filter="All images", enable_ocr=True):
|
408 |
+
"""Display sample images with predictions and ground truth"""
|
409 |
+
st.markdown("## 🖼️ Sample Images")
|
410 |
+
|
411 |
+
if len(df) == 0:
|
412 |
+
st.warning("No images match the current filters.")
|
413 |
+
return
|
414 |
+
|
415 |
+
tasks = {task['key']: task for task in task_config['tasks']}
|
416 |
+
|
417 |
+
# Add controls for image display
|
418 |
+
col1, col2, col3 = st.columns([2, 1, 1])
|
419 |
+
|
420 |
+
with col1:
|
421 |
+
max_images = st.slider(
|
422 |
+
"Number of images to display",
|
423 |
+
min_value=10,
|
424 |
+
max_value=min(200, len(df)),
|
425 |
+
value=min(50, len(df)),
|
426 |
+
step=10
|
427 |
+
)
|
428 |
+
|
429 |
+
with col2:
|
430 |
+
sort_by = st.selectbox(
|
431 |
+
"Sort by:",
|
432 |
+
["Original order", "Confidence (low to high)", "Confidence (high to low)"]
|
433 |
+
)
|
434 |
+
|
435 |
+
with col3:
|
436 |
+
cols_per_row = st.selectbox("Images per row:", [2, 3, 4], index=1)
|
437 |
+
|
438 |
+
# Sort dataframe if requested
|
439 |
+
task_keys = list(tasks.keys())
|
440 |
+
if sort_by == "Confidence (low to high)":
|
441 |
+
avg_conf = sum(df[f'conf_{task}'] for task in task_keys) / len(task_keys)
|
442 |
+
display_df = df.iloc[avg_conf.argsort()].head(max_images)
|
443 |
+
elif sort_by == "Confidence (high to low)":
|
444 |
+
avg_conf = sum(df[f'conf_{task}'] for task in task_keys) / len(task_keys)
|
445 |
+
display_df = df.iloc[avg_conf.argsort()[::-1]].head(max_images)
|
446 |
+
else:
|
447 |
+
display_df = df.head(max_images)
|
448 |
+
|
449 |
+
# Apply OCR filtering if needed
|
450 |
+
if enable_ocr and ocr_filter != "All images":
|
451 |
+
st.info("🔍 Applying OCR filtering... This may take a moment for many images.")
|
452 |
+
|
453 |
+
ocr_results = []
|
454 |
+
progress_bar = st.progress(0)
|
455 |
+
|
456 |
+
for idx, (_, row) in enumerate(display_df.iterrows()):
|
457 |
+
if os.path.exists(row['image_path']):
|
458 |
+
ocr_result = extract_text_from_image(row['image_path'])
|
459 |
+
ocr_results.append(ocr_result['has_text'])
|
460 |
+
else:
|
461 |
+
ocr_results.append(False)
|
462 |
+
|
463 |
+
progress_bar.progress((idx + 1) / len(display_df))
|
464 |
+
|
465 |
+
progress_bar.empty()
|
466 |
+
|
467 |
+
# Filter based on OCR results
|
468 |
+
if ocr_filter == "Images with text":
|
469 |
+
mask = ocr_results
|
470 |
+
else: # "Images without text"
|
471 |
+
mask = [not has_text for has_text in ocr_results]
|
472 |
+
|
473 |
+
display_df = display_df[mask].reset_index(drop=True)
|
474 |
+
st.success(f"OCR filtering complete. Found {len(display_df)} images matching criteria.")
|
475 |
+
|
476 |
+
# Display images
|
477 |
+
for i in range(0, len(display_df), cols_per_row):
|
478 |
+
cols = st.columns(cols_per_row)
|
479 |
+
|
480 |
+
for j in range(cols_per_row):
|
481 |
+
if i + j < len(display_df):
|
482 |
+
row = display_df.iloc[i + j]
|
483 |
+
|
484 |
+
with cols[j]:
|
485 |
+
# Load and display image
|
486 |
+
try:
|
487 |
+
if os.path.exists(row['image_path']):
|
488 |
+
img = Image.open(row['image_path'])
|
489 |
+
st.image(img, caption=f"Sample {row['idx']}", use_column_width=True)
|
490 |
+
else:
|
491 |
+
st.error(f"Image not found: {row['image_path']}")
|
492 |
+
continue
|
493 |
+
except Exception as e:
|
494 |
+
st.error(f"Error loading image: {e}")
|
495 |
+
continue
|
496 |
+
|
497 |
+
# Caption
|
498 |
+
st.markdown(f"**Caption:** {row['caption'][:100]}...")
|
499 |
+
|
500 |
+
# OCR Text Extraction
|
501 |
+
if enable_ocr and 'is_screenshot_with_text' in tasks:
|
502 |
+
with st.expander("🔍 Extracted Text (OCR)", expanded=False):
|
503 |
+
ocr_result = extract_text_from_image(row['image_path'])
|
504 |
+
if ocr_result['has_text']:
|
505 |
+
st.markdown(f"**Text Blocks Found:** {ocr_result['num_text_blocks']}")
|
506 |
+
st.text_area(
|
507 |
+
"Extracted Text:",
|
508 |
+
value=ocr_result['text'],
|
509 |
+
height=100,
|
510 |
+
key=f"ocr_text_{row['idx']}",
|
511 |
+
help="Text extracted from the image using OCR"
|
512 |
+
)
|
513 |
+
|
514 |
+
text_length = len(ocr_result['text'])
|
515 |
+
word_count = len(ocr_result['text'].split())
|
516 |
+
st.caption(f"📊 Text stats: {text_length} chars, {word_count} words")
|
517 |
+
|
518 |
+
if row['pred_is_screenshot_with_text'] == 'yes':
|
519 |
+
st.success("✅ Screenshot prediction correlates with text presence")
|
520 |
+
elif ocr_result['num_text_blocks'] > 5:
|
521 |
+
st.warning("⚠️ High text content but not predicted as screenshot")
|
522 |
+
else:
|
523 |
+
st.info("No text detected in this image")
|
524 |
+
if row['pred_is_screenshot_with_text'] == 'yes':
|
525 |
+
st.warning("⚠️ Predicted as screenshot but no text found")
|
526 |
+
|
527 |
+
# Predictions vs Ground Truth
|
528 |
+
st.markdown("**Predictions vs Ground Truth:**")
|
529 |
+
|
530 |
+
# Display all tasks dynamically
|
531 |
+
for task_key, task_info in tasks.items():
|
532 |
+
pred_val = row[f'pred_{task_key}']
|
533 |
+
gt_val = row[f'gt_{task_key}']
|
534 |
+
conf_val = row[f'conf_{task_key}']
|
535 |
+
correct = pred_val == gt_val
|
536 |
+
|
537 |
+
badge_class = "correct-prediction" if correct else "incorrect-prediction"
|
538 |
+
st.markdown(f"""
|
539 |
+
<div class="prediction-badge {badge_class}">
|
540 |
+
{task_info['name']}: {pred_val} | GT: {gt_val} | Conf: {conf_val:.3f}
|
541 |
+
</div>
|
542 |
+
""", unsafe_allow_html=True)
|
543 |
+
|
544 |
+
st.markdown("---")
|
545 |
+
|
546 |
+
if len(df) > max_images:
|
547 |
+
st.info(f"Showing {max_images} of {len(df)} filtered images. Use the slider above to show more images.")
|
548 |
+
|
549 |
+
def create_confusion_matrices(df, task_config):
|
550 |
+
"""Create confusion matrices for each classification task"""
|
551 |
+
st.markdown("## 📊 Model Performance Analysis")
|
552 |
+
|
553 |
+
tasks = {task['key']: task for task in task_config['tasks']}
|
554 |
+
binary_tasks = [t for t in tasks.values() if t['type'] == 'binary']
|
555 |
+
multi_class_tasks = [t for t in tasks.values() if t['type'] == 'multi_class']
|
556 |
+
|
557 |
+
tab1, tab2, tab3 = st.tabs(["Confusion Matrices", "Confidence Analysis", "Task Performance"])
|
558 |
+
|
559 |
+
with tab1:
|
560 |
+
# Binary classification confusion matrices
|
561 |
+
if binary_tasks:
|
562 |
+
st.markdown("### Binary Classification Tasks")
|
563 |
+
n_binary = len(binary_tasks)
|
564 |
+
n_cols = min(2, n_binary)
|
565 |
+
|
566 |
+
for i in range(0, n_binary, n_cols):
|
567 |
+
cols = st.columns(n_cols)
|
568 |
+
for j in range(n_cols):
|
569 |
+
if i + j < n_binary:
|
570 |
+
task = binary_tasks[i + j]
|
571 |
+
task_key = task['key']
|
572 |
+
|
573 |
+
with cols[j]:
|
574 |
+
confusion_data = pd.crosstab(
|
575 |
+
df[f'gt_{task_key}'],
|
576 |
+
df[f'pred_{task_key}'],
|
577 |
+
margins=True
|
578 |
+
)
|
579 |
+
st.markdown(f"**{task['name']} Confusion Matrix**")
|
580 |
+
st.dataframe(confusion_data, use_container_width=True)
|
581 |
+
|
582 |
+
# Multi-class confusion matrices
|
583 |
+
if multi_class_tasks:
|
584 |
+
st.markdown("### Multi-class Classification Tasks")
|
585 |
+
for task in multi_class_tasks:
|
586 |
+
task_key = task['key']
|
587 |
+
st.markdown(f"**{task['name']} Confusion Matrix**")
|
588 |
+
confusion_data = pd.crosstab(
|
589 |
+
df[f'gt_{task_key}'],
|
590 |
+
df[f'pred_{task_key}'],
|
591 |
+
margins=True
|
592 |
+
)
|
593 |
+
st.dataframe(confusion_data, use_container_width=True)
|
594 |
+
|
595 |
+
with tab2:
|
596 |
+
# Confidence analysis plots
|
597 |
+
fig1 = create_confidence_distribution_plot(df, task_config)
|
598 |
+
st.plotly_chart(fig1, use_container_width=True)
|
599 |
+
|
600 |
+
with tab3:
|
601 |
+
# Task-wise performance
|
602 |
+
st.markdown("**Performance by Task**")
|
603 |
+
|
604 |
+
performance_data = []
|
605 |
+
for task_key, task_info in tasks.items():
|
606 |
+
accuracy = df[f'correct_{task_key}'].mean()
|
607 |
+
confidence = df[f'conf_{task_key}'].mean()
|
608 |
+
|
609 |
+
performance_data.append({
|
610 |
+
'Task': task_info['name'],
|
611 |
+
'Type': task_info['type'],
|
612 |
+
'Accuracy': accuracy,
|
613 |
+
'Avg Confidence': confidence
|
614 |
+
})
|
615 |
+
|
616 |
+
performance_df = pd.DataFrame(performance_data)
|
617 |
+
st.dataframe(performance_df, use_container_width=True)
|
618 |
+
|
619 |
+
# Performance visualization
|
620 |
+
fig = px.scatter(performance_df, x='Avg Confidence', y='Accuracy',
|
621 |
+
color='Type', text='Task',
|
622 |
+
title="Task Performance: Accuracy vs Confidence")
|
623 |
+
fig.update_traces(textposition="top center")
|
624 |
+
st.plotly_chart(fig, use_container_width=True)
|
625 |
+
|
626 |
+
def main():
|
627 |
+
"""Main Streamlit application"""
|
628 |
+
# Header
|
629 |
+
st.markdown('<div class="main-header">🔍 Pseudoable Classifier Evaluation Dashboard</div>',
|
630 |
+
unsafe_allow_html=True)
|
631 |
+
|
632 |
+
# Load task configuration
|
633 |
+
task_config = load_task_config()
|
634 |
+
if not task_config:
|
635 |
+
st.error("❌ Could not load task configuration. Please ensure task_config.json exists.")
|
636 |
+
st.info("Expected location: ./task_config.json or ./checkpoints/task_config.json")
|
637 |
+
return
|
638 |
+
|
639 |
+
st.success(f"✅ Loaded task configuration with {len(task_config['tasks'])} tasks")
|
640 |
+
|
641 |
+
# Display task information
|
642 |
+
with st.expander("📋 Task Configuration", expanded=False):
|
643 |
+
for task in task_config['tasks']:
|
644 |
+
st.markdown(f"**{task['name']}** ({task['type']})")
|
645 |
+
st.markdown(f"- *Description:* {task['description']}")
|
646 |
+
st.markdown(f"- *Labels:* {', '.join(task['labels'])}")
|
647 |
+
st.markdown("---")
|
648 |
+
|
649 |
+
# Check and run validation if needed
|
650 |
+
if not run_validation_if_needed():
|
651 |
+
return
|
652 |
+
|
653 |
+
# Load validation results
|
654 |
+
validation_data = load_validation_results()
|
655 |
+
if not validation_data:
|
656 |
+
st.error("❌ Could not load validation results. Please check if validation_results.json exists.")
|
657 |
+
return
|
658 |
+
|
659 |
+
# Process data
|
660 |
+
df = process_validation_data(validation_data, task_config)
|
661 |
+
if df is None or len(df) == 0:
|
662 |
+
st.error("❌ No validation data found or data processing failed.")
|
663 |
+
return
|
664 |
+
|
665 |
+
# Show basic info
|
666 |
+
st.success(f"✅ Loaded {len(df)} validation samples successfully!")
|
667 |
+
|
668 |
+
# Overview metrics
|
669 |
+
create_overview_metrics(df, validation_data, task_config)
|
670 |
+
|
671 |
+
# Apply filters
|
672 |
+
filtered_df, ocr_filter, enable_ocr = apply_filters(df, task_config)
|
673 |
+
|
674 |
+
# Display results
|
675 |
+
if len(filtered_df) > 0:
|
676 |
+
# Performance analysis
|
677 |
+
create_confusion_matrices(filtered_df, task_config)
|
678 |
+
|
679 |
+
# Sample images
|
680 |
+
display_sample_images(filtered_df, task_config, ocr_filter, enable_ocr)
|
681 |
+
else:
|
682 |
+
st.warning("⚠️ No samples match the current filter criteria. Please adjust your filters.")
|
683 |
+
|
684 |
+
# Footer
|
685 |
+
st.markdown("---")
|
686 |
+
st.markdown("**📝 Instructions:**")
|
687 |
+
st.markdown("1. Use the sidebar to filter by task confidence and prediction classes")
|
688 |
+
st.markdown("2. Filter images by text content using OCR (if screenshot detection task is configured)")
|
689 |
+
st.markdown("3. Adjust the number of images to display and sorting order")
|
690 |
+
st.markdown("4. View model performance metrics and confusion matrices")
|
691 |
+
st.markdown("5. Browse sample images with predictions vs ground truth")
|
692 |
+
st.markdown("6. Green badges indicate correct predictions, red badges indicate incorrect predictions")
|
693 |
+
|
694 |
+
if __name__ == "__main__":
|
695 |
+
main()
|
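For reference, the dashboard (and the other scripts in this upload) read a task_config.json whose entries expose the fields key, name, type, labels, and description. A minimal illustrative example is sketched below as a Python literal; the field names match what the code reads, while the second task and all label sets other than the yes/no pair are invented for illustration:

# Hypothetical task_config.json contents, shown as a Python dict.
task_config = {
    "tasks": [
        {
            "key": "is_screenshot_with_text",   # key referenced by the OCR filter above
            "name": "Screenshot with Text",
            "type": "binary",
            "labels": ["no", "yes"],            # index 1 is treated as "yes" by validation_runner.py
            "description": "Whether the image is a screenshot containing readable text",
        },
        {
            "key": "image_category",
            "name": "Image Category",
            "type": "multi_class",
            "labels": ["photo", "illustration", "diagram"],
            "description": "Coarse visual category of the image",
        },
    ]
}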
upload_to_hf.py
ADDED
@@ -0,0 +1,112 @@
#!/usr/bin/env python3
"""
Upload trained multi-head SigLIP2 classifier to Hugging Face Hub (private).

This script will create/update the repo `fal/multihead_cls` and push:
- model weights: checkpoints/multi_head_siglip2_classifier.pth
- full training checkpoint: checkpoints/training_checkpoint.pth (optional)
- processor folder: checkpoints/siglip2_processor/
- README.md with usage

Auth: Set HUGGINGFACE_TOKEN environment variable or run `huggingface-cli login`.
"""

import os
import shutil
from pathlib import Path
from typing import Optional

from huggingface_hub import HfApi, HfFolder, create_repo, upload_folder, upload_file


REPO_ID = "fal/multihead_cls"


def ensure_logged_in() -> Optional[str]:
    token = os.getenv("HUGGINGFACE_TOKEN") or HfFolder.get_token()
    if not token:
        raise RuntimeError(
            "No Hugging Face token found. Set HUGGINGFACE_TOKEN or run `huggingface-cli login`."
        )
    return token


def prepare_staging_dir() -> Path:
    root = Path(__file__).parent
    ckpt_dir = root / "checkpoints"
    if not ckpt_dir.exists():
        raise FileNotFoundError("checkpoints/ directory not found. Train the model first.")

    required = [
        ckpt_dir / "multi_head_siglip2_classifier.pth",
        ckpt_dir / "siglip2_processor",
    ]
    for path in required:
        if not path.exists():
            raise FileNotFoundError(f"Missing required artifact: {path}")

    # Check for task_config.json in checkpoints or root directory
    task_config_path = ckpt_dir / "task_config.json"
    if not task_config_path.exists():
        task_config_path = root / "task_config.json"
        if not task_config_path.exists():
            raise FileNotFoundError("Missing required artifact: task_config.json (checked both checkpoints/ and root directory)")

    staging = root / "hf_export"
    if staging.exists():
        shutil.rmtree(staging)
    staging.mkdir(parents=True)

    # Copy artifacts
    shutil.copy2(ckpt_dir / "multi_head_siglip2_classifier.pth", staging / "model.pth")
    shutil.copy2(task_config_path, staging / "task_config.json")

    # Optional: training checkpoint and other metadata
    train_ckpt = ckpt_dir / "training_checkpoint.pth"
    if train_ckpt.exists():
        shutil.copy2(train_ckpt, staging / "training_checkpoint.pth")

    # Optional: training history and validation summary
    for optional_file in ["training_history.json", "validation_summary.json", "stage4_summary.json"]:
        optional_path = ckpt_dir / optional_file
        if optional_path.exists():
            shutil.copy2(optional_path, staging / optional_file)

    # Processor
    shutil.copytree(ckpt_dir / "siglip2_processor", staging / "processor")

    # Add example and README if present
    readme_src = root / "README.md"
    if readme_src.exists():
        shutil.copy2(readme_src, staging / "README.md")
    example_src = root / "example.py"
    if example_src.exists():
        shutil.copy2(example_src, staging / "example.py")

    return staging


def upload_to_hub(private: bool = True) -> None:
    token = ensure_logged_in()
    api = HfApi(token=token)

    create_repo(REPO_ID, private=private, repo_type="model", exist_ok=True, token=token)

    staging = prepare_staging_dir()

    # Upload all files in staging
    upload_folder(
        folder_path=str(staging),
        repo_id=REPO_ID,
        repo_type="model",
        commit_message="Upload multi-head SigLIP2 classifier with dynamic task configuration",
        token=token,
    )

    print(f"Uploaded to https://huggingface.co/{REPO_ID} (private={private})")


if __name__ == "__main__":
    upload_to_hub(private=True)
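Once pushed, the staged artifacts can be pulled back down with huggingface_hub. A minimal retrieval sketch, assuming the model.pth / task_config.json / processor/ layout produced by prepare_staging_dir() and a token with access to the private repo:

# Minimal retrieval sketch; filenames follow the staging layout above.
import json
from huggingface_hub import hf_hub_download

weights_path = hf_hub_download(repo_id="fal/multihead_cls", filename="model.pth")
config_path = hf_hub_download(repo_id="fal/multihead_cls", filename="task_config.json")
with open(config_path) as f:
    task_config = json.load(f)
# weights_path can then be loaded with torch.load() into
# MultiHeadSiglipClassifier(task_config), mirroring what validation_runner.py does locally.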
validation_runner.py
ADDED
@@ -0,0 +1,286 @@
#!/usr/bin/env python3
"""
Validation Runner: Runs trained model on validation set and saves predictions
"""

import os
import json
import torch
import numpy as np
from PIL import Image
from pathlib import Path
import logging
from transformers import AutoProcessor
from stage_4 import MultiHeadSiglipClassifier, CKPT, load_task_config
import pandas as pd
from tqdm import tqdm

# Set up logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

def _is_incomplete_classification(classification: dict, task_config: dict) -> bool:
    """Check if classification contains incomplete data (empty or '...' values)"""
    if not task_config or 'tasks' not in task_config:
        return True

    required_tasks = [task['key'] for task in task_config['tasks']]

    for task_key in required_tasks:
        if task_key not in classification:
            return True

        value = classification[task_key]
        # Check for incomplete markers
        if not value or value == "..." or value == "" or value is None:
            return True

    return False

def load_trained_model(checkpoint_dir: str = './checkpoints'):
    """Load the trained model and processor"""
    checkpoint_path = Path(checkpoint_dir)

    # Load task configuration
    task_config_path = checkpoint_path / 'task_config.json'
    if not task_config_path.exists():
        # Fallback to root directory
        task_config_path = './task_config.json'

    task_config = load_task_config(str(task_config_path))

    # Load processor
    processor = AutoProcessor.from_pretrained(CKPT)

    # Load model with task config
    model = MultiHeadSiglipClassifier(task_config)
    model_state = torch.load(checkpoint_path / 'multi_head_siglip2_classifier.pth', map_location='cpu')
    model.load_state_dict(model_state)

    # Set to evaluation mode
    model.eval()

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model.to(device)

    logger.info(f"Model loaded successfully on device: {device}")
    return model, processor, device, task_config

def load_validation_data(data_dir: str = './data', task_config: dict = None):
    """Load validation samples from stage 2 metadata files"""
    data_path = Path(data_dir)

    # Load from stage 2 metadata files
    metadata_dir = data_path / 'metadata'
    if not metadata_dir.exists():
        logger.error("Metadata directory not found. Run stages 1 and 2 first.")
        return []

    metadata_files = list(metadata_dir.glob('meta_*_stage2.json'))
    if not metadata_files:
        logger.error("No stage 2 metadata files found. Run stage 2 first.")
        return []

    samples = []
    skipped_incomplete = 0

    for meta_file in tqdm(metadata_files, desc="Loading validation data"):
        try:
            with open(meta_file, 'r') as f:
                metadata = json.load(f)

            # Check if classification is complete
            if not metadata.get('stage2_complete', False):
                logger.warning(f"Skipping {meta_file} - classification not complete")
                skipped_incomplete += 1
                continue

            # Check if classification contains incomplete data
            classification = metadata.get('classification', {})
            if not classification or _is_incomplete_classification(classification, task_config):
                logger.warning(f"Skipping {meta_file} - incomplete classification data")
                skipped_incomplete += 1
                continue

            # Check if image exists
            image_path = metadata['image_path']
            if not os.path.exists(image_path):
                logger.warning(f"Image not found: {image_path}")
                skipped_incomplete += 1
                continue

            samples.append({
                'idx': metadata['idx'],
                'image_path': metadata['image_path'],
                'caption': metadata['caption'],
                'url': metadata['url'],
                'hash': metadata['hash'],
                'ground_truth': metadata['classification']
            })

        except Exception as e:
            logger.warning(f"Error loading {meta_file}: {e}")
            skipped_incomplete += 1

    if skipped_incomplete > 0:
        logger.warning(f"Skipped {skipped_incomplete} incomplete samples")
    logger.info(f"Loaded {len(samples)} valid validation samples")
    return samples

def predict_batch(model, processor, images, device, task_config, batch_size=8):
    """Run predictions on a batch of images"""
    predictions = []
    tasks = {task['key']: task for task in task_config['tasks']}

    for i in range(0, len(images), batch_size):
        batch_images = images[i:i+batch_size]

        # Process images
        inputs = processor(images=batch_images, return_tensors="pt")
        pixel_values = inputs['pixel_values'].to(device)

        with torch.no_grad():
            outputs = model(pixel_values)

        # Convert outputs to probabilities and predictions
        batch_preds = []
        for j in range(len(batch_images)):
            pred = {}

            # Process each task dynamically
            for task_key, task_info in tasks.items():
                logits = outputs[task_key][j]
                probs = torch.softmax(logits, dim=0)
                pred_class = torch.argmax(logits).item()
                confidence = probs[pred_class].item()

                if task_info['type'] == 'binary':
                    # Binary classification
                    pred[f'{task_key}_prediction'] = 'yes' if pred_class == 1 else 'no'
                    pred[f'{task_key}_confidence'] = confidence
                    pred[f'{task_key}_prob_yes'] = probs[1].item()
                    pred[f'{task_key}_prob_no'] = probs[0].item()

                elif task_info['type'] == 'multi_class':
                    # Multi-class classification
                    pred_label = task_info['labels'][pred_class]
                    pred[f'{task_key}_prediction'] = pred_label
                    pred[f'{task_key}_confidence'] = confidence

                    # Add probabilities for all classes
                    for idx, label in enumerate(task_info['labels']):
                        pred[f'{task_key}_prob_{label}'] = probs[idx].item()

            batch_preds.append(pred)

        predictions.extend(batch_preds)

    return predictions

def calculate_accuracies(predictions, ground_truths, task_config):
    """Calculate accuracies for each task"""
    accuracies = {}
    tasks = {task['key']: task for task in task_config['tasks']}

    for task_key, task_info in tasks.items():
        pred_key = f'{task_key}_prediction'

        correct = sum(1 for pred, gt in zip(predictions, ground_truths)
                      if pred[pred_key] == gt[task_key])
        total = len(predictions)
        accuracies[f'{task_key}_accuracy'] = correct / total if total > 0 else 0

    return accuracies

def run_validation(data_dir: str = './data', checkpoint_dir: str = './checkpoints',
                   output_file: str = './validation_results.json'):
    """Run complete validation and save results"""
    logger.info("Starting validation run...")

    # Load model and data
    model, processor, device, task_config = load_trained_model(checkpoint_dir)
    samples = load_validation_data(data_dir, task_config)

    if not samples:
        logger.error("No validation samples found!")
        return

    # Prepare images for batch processing
    images = []
    for sample in tqdm(samples, desc="Loading images"):
        try:
            img = Image.open(sample['image_path']).convert('RGB')
            images.append(img)
        except Exception as e:
            logger.error(f"Error loading image {sample['image_path']}: {e}")
            images.append(None)

    # Filter out failed images
    valid_samples = []
    valid_images = []
    for sample, img in zip(samples, images):
        if img is not None:
            valid_samples.append(sample)
            valid_images.append(img)

    logger.info(f"Running predictions on {len(valid_samples)} valid samples...")

    # Run predictions
    predictions = predict_batch(model, processor, valid_images, device, task_config)

    # Calculate accuracies
    ground_truths = [sample['ground_truth'] for sample in valid_samples]
    accuracies = calculate_accuracies(predictions, ground_truths, task_config)

    # Combine results
    validation_results = []
    for sample, prediction in zip(valid_samples, predictions):
        result = {
            **sample,
            'predictions': prediction
        }
        validation_results.append(result)

    # Create final output
    output_data = {
        'metadata': {
            'total_samples': len(validation_results),
            'model_checkpoint': checkpoint_dir,
            'validation_accuracies': accuracies,
            'task_config': task_config,
            'timestamp': pd.Timestamp.now().isoformat()
        },
        'results': validation_results
    }

    # Save results
    output_path = Path(output_file)
    with open(output_path, 'w') as f:
        json.dump(output_data, f, indent=2)

    logger.info(f"Validation results saved to {output_path}")
    logger.info("Validation Accuracies:")
    for key, value in accuracies.items():
        logger.info(f"  {key}: {value:.4f}")

    return output_data

def main():
    """Main execution"""
    logger.info("Starting validation runner...")

    # Check if model exists
    if not Path('./checkpoints/multi_head_siglip2_classifier.pth').exists():
        logger.error("Trained model not found! Run stage 4 first.")
        return

    # Run validation
    results = run_validation()

    if results:
        logger.info("Validation completed successfully!")
    else:
        logger.error("Validation failed!")

if __name__ == "__main__":
    main()
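For reference, run_validation() writes validation_results.json in the shape sketched below, which is exactly what streamlit_evaluation_app.py consumes. The keys mirror the code above; the concrete values, task key, and paths are illustrative only:

# Illustrative shape of validation_results.json; values are made up.
example_validation_results = {
    "metadata": {
        "total_samples": 2,
        "model_checkpoint": "./checkpoints",
        "validation_accuracies": {"is_screenshot_with_text_accuracy": 0.95},
        "task_config": {"tasks": ["..."]},
        "timestamp": "2025-01-01T00:00:00",
    },
    "results": [
        {
            "idx": 0,
            "image_path": "./data/images/sample_0.png",
            "caption": "an example caption",
            "url": "https://example.com/image.png",
            "hash": "abc123",
            "ground_truth": {"is_screenshot_with_text": "yes"},
            "predictions": {
                "is_screenshot_with_text_prediction": "yes",
                "is_screenshot_with_text_confidence": 0.97,
                "is_screenshot_with_text_prob_yes": 0.97,
                "is_screenshot_with_text_prob_no": 0.03,
            },
        },
    ],
}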