import json
import openai
import os
from datetime import datetime
import base64
import logging
from pathlib import Path
import time
from tqdm import tqdm
from typing import Dict, List, Optional, Union, Any

# Configuration constants
DEBUG_MODE = False
OUTPUT_DIR = "results"
MODEL_NAME = "gpt-4o-2024-05-13"
TEMPERATURE = 0.2
SUBSET = "Visual Question Answering"

# Set up logging configuration
logging_level = logging.DEBUG if DEBUG_MODE else logging.INFO
logging.basicConfig(level=logging_level, format="%(asctime)s - %(levelname)s - %(message)s")
logger = logging.getLogger(__name__)


def get_mime_type(file_path: str) -> str:
    """
    Determine the MIME type based on the file extension.

    Args:
        file_path (str): Path to the file

    Returns:
        str: MIME type string for the file
    """
    extension = os.path.splitext(file_path)[1].lower()
    mime_types = {
        ".png": "image/png",
        ".jpg": "image/jpeg",
        ".jpeg": "image/jpeg",
        ".gif": "image/gif",
    }
    return mime_types.get(extension, "application/octet-stream")
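
# Illustrative usage (hypothetical file names, not part of the dataset):
#   get_mime_type("frontal_view.png") -> "image/png"
#   get_mime_type("notes.txt")        -> "application/octet-stream"  (fallback)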


def encode_image(image_path: str) -> str:
    """
    Encode an image to base64 with extensive error checking.

    Args:
        image_path (str): Path to the image file

    Returns:
        str: Base64-encoded image string

    Raises:
        FileNotFoundError: If the image file does not exist
        ValueError: If the image file is empty, too large, or has an unsupported format
        Exception: For other image processing errors
    """
    logger.debug(f"Attempting to read image from: {image_path}")
    if not os.path.exists(image_path):
        raise FileNotFoundError(f"Image file not found: {image_path}")

    # Check the file size before reading
    file_size = os.path.getsize(image_path)
    if file_size > 20 * 1024 * 1024:  # 20MB limit
        raise ValueError("Image file size exceeds 20MB limit")
    if file_size == 0:
        raise ValueError("Image file is empty")
    logger.debug(f"Image file size: {file_size / 1024:.2f} KB")

    try:
        from PIL import Image

        # Open and verify the image before encoding
        with Image.open(image_path) as img:
            width, height = img.size
            img_format = img.format  # avoid shadowing the built-in `format`
            mode = img.mode
            logger.debug(
                f"Image verification - Format: {img_format}, Size: {width}x{height}, Mode: {mode}"
            )
            if img_format not in ["PNG", "JPEG", "GIF"]:
                raise ValueError(f"Unsupported image format: {img_format}")

        # Read and base64-encode the raw file contents
        with open(image_path, "rb") as image_file:
            encoded = base64.b64encode(image_file.read()).decode("utf-8")

        encoded_length = len(encoded)
        logger.debug(f"Base64 encoded length: {encoded_length} characters")

        # Verify the encoded string is not empty and starts with a known signature
        if encoded_length == 0:
            raise ValueError("Base64 encoding produced empty string")
        if not encoded.startswith("/9j/") and not encoded.startswith("iVBOR"):
            logger.warning("Base64 string doesn't start with expected JPEG or PNG header")
        return encoded
    except Exception as e:
        logger.error(f"Error reading/encoding image: {str(e)}")
        raise


def create_single_request(
    image_path: str, question: str, options: Dict[str, str]
) -> List[Dict[str, Any]]:
    """
    Create a single API request with an image and a question.

    Args:
        image_path (str): Path to the image file
        question (str): Question text
        options (Dict[str, str]): Dictionary containing options with keys 'option_0' and 'option_1'

    Returns:
        List[Dict[str, Any]]: List of message dictionaries for the API request

    Raises:
        Exception: For errors in request creation
    """
    if DEBUG_MODE:
        logger.debug("Creating API request...")

    prompt = f"""Please answer the following multiple choice question from a medical examination:

Question: {question}

Options:
A) {options['option_0']}
B) {options['option_1']}

Base your answer only on the provided image and select either A or B."""

    try:
        encoded_image = encode_image(image_path)
        mime_type = get_mime_type(image_path)
        if DEBUG_MODE:
            logger.debug(f"Image encoded with MIME type: {mime_type}")

        messages = [
            {
                "role": "system",
                "content": "You are taking a medical exam. Answer ONLY with the letter (A/B) corresponding to your answer.",
            },
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": prompt},
                    {
                        "type": "image_url",
                        "image_url": {"url": f"data:{mime_type};base64,{encoded_image}"},
                    },
                ],
            },
        ]

        if DEBUG_MODE:
            # Log the payload with the base64 image truncated to keep logs readable
            log_messages = json.loads(json.dumps(messages))
            log_messages[1]["content"][1]["image_url"][
                "url"
            ] = f"data:{mime_type};base64,[BASE64_IMAGE_TRUNCATED]"
            logger.debug(f"Complete API request payload:\n{json.dumps(log_messages, indent=2)}")
        return messages
    except Exception as e:
        logger.error(f"Error creating request: {str(e)}")
        raise


def check_answer(model_answer: str, correct_answer: int) -> bool:
    """
    Check whether the model's answer matches the correct answer.

    Args:
        model_answer (str): The model's answer (A or B)
        correct_answer (int): The correct answer index (0 for A, 1 for B)

    Returns:
        bool: True if the answer is correct, False otherwise
    """
    if not isinstance(model_answer, str):
        return False

    # Reduce the model answer to its leading letter
    model_letter = model_answer.strip().upper()
    if model_letter.startswith("A"):
        model_index = 0
    elif model_letter.startswith("B"):
        model_index = 1
    else:
        return False
    return model_index == correct_answer
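
# Illustrative behaviour of check_answer (hypothetical model outputs):
#   check_answer("A", 0)            -> True
#   check_answer("b) pneumonia", 1) -> True   (case-insensitive prefix match)
#   check_answer("Unsure", 0)       -> False  (no A/B prefix)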


def save_results_to_json(results: List[Dict[str, Any]], output_dir: str) -> str:
    """
    Save results to a timestamped JSON file.

    Args:
        results (List[Dict[str, Any]]): List of result dictionaries
        output_dir (str): Directory to save results

    Returns:
        str: Path to the saved file
    """
    Path(output_dir).mkdir(parents=True, exist_ok=True)
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    output_file = os.path.join(output_dir, f"batch_results_{timestamp}.json")
    with open(output_file, "w") as f:
        json.dump(results, f, indent=2)
    logger.info(f"Batch results saved to {output_file}")
    return output_file


def calculate_accuracy(results: List[Dict[str, Any]]) -> tuple[float, int, int]:
    """
    Calculate accuracy from results, handling error cases.

    Examples that failed with an error (no "output" key) count toward the total
    but never toward the correct count, so they are scored as incorrect.

    Args:
        results (List[Dict[str, Any]]): List of result dictionaries

    Returns:
        tuple[float, int, int]: Tuple containing (accuracy percentage, number correct, total)
    """
    if not results:
        return 0.0, 0, 0

    total = len(results)
    valid_results = [r for r in results if "output" in r]
    correct = sum(
        1 for result in valid_results if result.get("output", {}).get("is_correct", False)
    )
    accuracy = (correct / total * 100) if total > 0 else 0
    return accuracy, correct, total


def calculate_batch_accuracy(results: List[Dict[str, Any]]) -> float:
    """
    Calculate accuracy for the current batch.

    Unlike calculate_accuracy, errored examples (no "output" key) are excluded
    from the denominator here.

    Args:
        results (List[Dict[str, Any]]): List of result dictionaries

    Returns:
        float: Accuracy percentage for the batch
    """
    valid_results = [r for r in results if "output" in r]
    if not valid_results:
        return 0.0
    return sum(1 for r in valid_results if r["output"]["is_correct"]) / len(valid_results) * 100


def process_batch(
    data: List[Dict[str, Any]], client: openai.OpenAI, start_idx: int = 0, batch_size: int = 50
) -> List[Dict[str, Any]]:
    """
    Process a batch of examples and return results.

    Args:
        data (List[Dict[str, Any]]): List of data items to process
        client (openai.OpenAI): OpenAI client instance
        start_idx (int, optional): Starting index for batch. Defaults to 0
        batch_size (int, optional): Size of batch to process. Defaults to 50

    Returns:
        List[Dict[str, Any]]: List of processed results
    """
    batch_results = []
    end_idx = min(start_idx + batch_size, len(data))
    pbar = tqdm(
        range(start_idx, end_idx),
        desc=f"Processing batch {start_idx // batch_size + 1}",
        unit="example",
    )

    for index in pbar:
        vqa_item = data[index]
        options = {"option_0": vqa_item["option_0"], "option_1": vqa_item["option_1"]}
        try:
            messages = create_single_request(
                image_path=vqa_item["image_path"], question=vqa_item["question"], options=options
            )
            response = client.chat.completions.create(
                model=MODEL_NAME, messages=messages, max_tokens=50, temperature=TEMPERATURE
            )
            model_answer = response.choices[0].message.content.strip()
            is_correct = check_answer(model_answer, vqa_item["answer"])

            result = {
                "timestamp": datetime.now().isoformat(),
                "example_index": index,
                "input": {
                    "question": vqa_item["question"],
                    "options": {"A": vqa_item["option_0"], "B": vqa_item["option_1"]},
                    "image_path": vqa_item["image_path"],
                },
                "output": {
                    "model_answer": model_answer,
                    "correct_answer": "A" if vqa_item["answer"] == 0 else "B",
                    "is_correct": is_correct,
                    "usage": {
                        "prompt_tokens": response.usage.prompt_tokens,
                        "completion_tokens": response.usage.completion_tokens,
                        "total_tokens": response.usage.total_tokens,
                    },
                },
            }
            batch_results.append(result)

            # Update progress bar with current accuracy
            current_accuracy = calculate_batch_accuracy(batch_results)
            pbar.set_description(
                f"Batch {start_idx // batch_size + 1} - Accuracy: {current_accuracy:.2f}% "
                f"({len(batch_results)}/{index - start_idx + 1} examples)"
            )
        except Exception as e:
            error_result = {
                "timestamp": datetime.now().isoformat(),
                "example_index": index,
                "error": str(e),
                "input": {
                    "question": vqa_item["question"],
                    "options": {"A": vqa_item["option_0"], "B": vqa_item["option_1"]},
                    "image_path": vqa_item["image_path"],
                },
            }
            batch_results.append(error_result)
            if DEBUG_MODE:
                pbar.write(f"Error processing example {index}: {str(e)}")

        time.sleep(1)  # Rate limiting between requests

    return batch_results


def main() -> None:
    """
    Main function to process the entire dataset.

    Raises:
        ValueError: If OPENAI_API_KEY is not set
        Exception: For other processing errors
    """
    logger.info("Starting full dataset processing...")
    json_path = "../data/chexbench_updated.json"

    try:
        api_key = os.getenv("OPENAI_API_KEY")
        if not api_key:
            raise ValueError("OPENAI_API_KEY environment variable is not set.")
        client = openai.OpenAI(api_key=api_key)

        with open(json_path, "r") as f:
            data = json.load(f)
        subset_data = data[SUBSET]
        total_examples = len(subset_data)
        logger.info(f"Found {total_examples} examples in {SUBSET} subset")
        all_results = []
        batch_size = 50  # Process in batches of 50 examples

        # Process all examples in batches
        for start_idx in range(0, total_examples, batch_size):
            batch_results = process_batch(subset_data, client, start_idx, batch_size)
            all_results.extend(batch_results)

            # Save intermediate results after each batch
            output_file = save_results_to_json(all_results, OUTPUT_DIR)

            # Calculate and log overall progress
            overall_accuracy, correct, total = calculate_accuracy(all_results)
            logger.info(f"Overall Progress: {len(all_results)}/{total_examples} examples processed")
            logger.info(f"Current Accuracy: {overall_accuracy:.2f}% ({correct}/{total} correct)")

        logger.info("Processing completed!")
        logger.info(f"Final results saved to: {output_file}")
    except Exception as e:
        logger.error(f"Fatal error: {str(e)}")
        raise


if __name__ == "__main__":
    main()
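
# How to run (a minimal sketch; the script filename is an assumption, and the dataset
# is expected at json_path relative to this file):
#   export OPENAI_API_KEY=...         # required; read via os.getenv above
#   python run_chexbench_vqa.py       # replace with this file's actual name
# After each batch, all results so far are written to a new timestamped JSON file
# under OUTPUT_DIR (results/batch_results_<timestamp>.json).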