"""Evaluate an OpenAI vision model on a two-option VQA subset of CheXbench.

The script loads the benchmark JSON, sends every example (image + question +
two options) to the chat completions API, scores the single-letter answers and
writes cumulative results to timestamped JSON files after each batch.
"""

import base64
import json
import logging
import os
import re
import time
from datetime import datetime
from pathlib import Path
from typing import Any, Dict, List, Optional, Union

import openai
from tqdm import tqdm

# Configuration constants
DEBUG_MODE = False
OUTPUT_DIR = "results"
MODEL_NAME = "gpt-4o-2024-05-13"
TEMPERATURE = 0.2
SUBSET = "Visual Question Answering"

# Image validation limits used by encode_image.
MAX_IMAGE_BYTES = 20 * 1024 * 1024  # 20MB limit
SUPPORTED_IMAGE_FORMATS = ("PNG", "JPEG", "GIF")

# Base64 prefixes of the JPEG, PNG and GIF magic numbers, in that order.
_BASE64_SIGNATURES = ("/9j/", "iVBOR", "R0lGOD")

_MIME_TYPES = {
    ".png": "image/png",
    ".jpg": "image/jpeg",
    ".jpeg": "image/jpeg",
    ".gif": "image/gif",
}

# A stand-alone "A" or "B" token (input is upper-cased before matching), so
# that "ANSWER: B" is not mistaken for an answer of "A".
_ANSWER_LETTER_RE = re.compile(r"\b([AB])\b")

# Set up logging configuration
logging_level = logging.DEBUG if DEBUG_MODE else logging.INFO
logging.basicConfig(level=logging_level, format="%(asctime)s - %(levelname)s - %(message)s")
logger = logging.getLogger(__name__)


def get_mime_type(file_path: str) -> str:
    """
    Determine MIME type based on file extension.

    Args:
        file_path (str): Path to the file

    Returns:
        str: MIME type string for the file; "application/octet-stream" when
            the extension is not one of .png/.jpg/.jpeg/.gif
    """
    extension = os.path.splitext(file_path)[1].lower()
    return _MIME_TYPES.get(extension, "application/octet-stream")


def encode_image(image_path: str) -> str:
    """
    Validate an image file and return its contents base64-encoded.

    Args:
        image_path (str): Path to the image file

    Returns:
        str: Base64 encoded image string

    Raises:
        FileNotFoundError: If image file does not exist
        ValueError: If image file is empty, too large, or not PNG/JPEG/GIF
        Exception: For other image processing errors (logged, then re-raised)
    """
    logger.debug("Attempting to read image from: %s", image_path)

    if not os.path.exists(image_path):
        raise FileNotFoundError(f"Image file not found: {image_path}")

    file_size = os.path.getsize(image_path)
    if file_size > MAX_IMAGE_BYTES:
        raise ValueError("Image file size exceeds 20MB limit")
    if file_size == 0:
        raise ValueError("Image file is empty")

    logger.debug("Image file size: %.2f KB", file_size / 1024)

    try:
        # Imported lazily so Pillow is only required when an image is encoded.
        from PIL import Image

        # Let Pillow parse the header to confirm the file really is an image
        # in one of the supported formats.
        with Image.open(image_path) as img:
            width, height = img.size
            image_format = img.format
            logger.debug(
                "Image verification - Format: %s, Size: %sx%s, Mode: %s",
                image_format,
                width,
                height,
                img.mode,
            )
            if image_format not in SUPPORTED_IMAGE_FORMATS:
                raise ValueError(f"Unsupported image format: {image_format}")

        with open(image_path, "rb") as image_file:
            encoded = base64.b64encode(image_file.read()).decode("utf-8")

        logger.debug("Base64 encoded length: %d characters", len(encoded))

        if not encoded:
            raise ValueError("Base64 encoding produced empty string")

        # Sanity check only: a mismatch is logged, not treated as fatal.
        if not encoded.startswith(_BASE64_SIGNATURES):
            logger.warning(
                "Base64 string doesn't start with expected JPEG, PNG or GIF header"
            )

        return encoded

    except Exception as e:
        logger.error("Error reading/encoding image: %s", e)
        raise


def create_single_request(
    image_path: str, question: str, options: Dict[str, str]
) -> List[Dict[str, Any]]:
    """
    Create a single API request with image and question.

    Args:
        image_path (str): Path to the image file
        question (str): Question text
        options (Dict[str, str]): Dictionary containing options with keys
            'option_0' and 'option_1'

    Returns:
        List[Dict[str, Any]]: List of message dictionaries for the API request

    Raises:
        KeyError: If 'option_0' or 'option_1' is missing from options
        Exception: For errors in image encoding (logged, then re-raised)
    """
    if DEBUG_MODE:
        logger.debug("Creating API request...")

    prompt = (
        "Given the following medical examination question:\n"
        "Please answer this multiple choice question:\n\n"
        f"Question: {question}\n\n"
        "Options:\n"
        f"A) {options['option_0']}\n"
        f"B) {options['option_1']}\n\n"
        "Base your answer only on the provided image and select either A or B."
    )

    try:
        encoded_image = encode_image(image_path)
        mime_type = get_mime_type(image_path)

        if DEBUG_MODE:
            logger.debug("Image encoded with MIME type: %s", mime_type)

        messages = [
            {
                "role": "system",
                "content": "You are taking a medical exam. Answer ONLY with the letter (A/B) corresponding to your answer.",
            },
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": prompt},
                    {
                        "type": "image_url",
                        "image_url": {"url": f"data:{mime_type};base64,{encoded_image}"},
                    },
                ],
            },
        ]

        if DEBUG_MODE:
            # Deep-copy via JSON so the real payload is untouched, then hide
            # the image bytes to keep the log readable.
            log_messages = json.loads(json.dumps(messages))
            log_messages[1]["content"][1]["image_url"][
                "url"
            ] = f"data:{mime_type};base64,[BASE64_IMAGE_TRUNCATED]"
            logger.debug(
                "Complete API request payload:\n%s", json.dumps(log_messages, indent=2)
            )

        return messages

    except Exception as e:
        logger.error("Error creating request: %s", e)
        raise


def check_answer(model_answer: str, correct_answer: int) -> bool:
    """
    Check if the model's answer matches the correct answer.

    The first stand-alone letter A or B (case-insensitive) in the reply is
    taken as the model's choice, so "B", "b)", "(B)" and "Answer: B" all count
    as B. A prefix test is deliberately not used because it would read
    "Answer: B" as A.

    Args:
        model_answer (str): The model's answer (expected to contain A or B)
        correct_answer (int): The correct answer index (0 for A, 1 for B)

    Returns:
        bool: True if answer is correct, False otherwise (including when the
            reply is not a string or contains no stand-alone A/B)
    """
    if not isinstance(model_answer, str):
        return False

    match = _ANSWER_LETTER_RE.search(model_answer.upper())
    if match is None:
        return False

    model_index = 0 if match.group(1) == "A" else 1
    return model_index == correct_answer


def save_results_to_json(results: List[Dict[str, Any]], output_dir: str) -> str:
    """
    Save results to a JSON file with timestamp.

    Args:
        results (List[Dict[str, Any]]): List of result dictionaries
        output_dir (str): Directory to save results (created if missing)

    Returns:
        str: Path to the saved file
    """
    Path(output_dir).mkdir(parents=True, exist_ok=True)
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    output_file = os.path.join(output_dir, f"batch_results_{timestamp}.json")

    with open(output_file, "w", encoding="utf-8") as f:
        json.dump(results, f, indent=2)

    logger.info("Batch results saved to %s", output_file)
    return output_file


def calculate_accuracy(results: List[Dict[str, Any]]) -> tuple[float, int, int]:
    """
    Calculate accuracy from results, handling error cases.

    Entries without an "output" key (failed requests) count towards the total
    but never as correct.

    Args:
        results (List[Dict[str, Any]]): List of result dictionaries

    Returns:
        tuple[float, int, int]: Tuple containing (accuracy percentage, number
            correct, total); (0.0, 0, 0) for empty input
    """
    if not results:
        return 0.0, 0, 0

    total = len(results)
    correct = sum(
        1 for result in results if result.get("output", {}).get("is_correct", False)
    )
    return correct / total * 100, correct, total


def calculate_batch_accuracy(results: List[Dict[str, Any]]) -> float:
    """
    Calculate accuracy for the current batch.

    Unlike calculate_accuracy, entries without an "output" key are excluded
    from the denominator.

    Args:
        results (List[Dict[str, Any]]): List of result dictionaries

    Returns:
        float: Accuracy percentage over entries that have an "output";
            0.0 when there are none
    """
    valid_results = [r for r in results if "output" in r]
    if not valid_results:
        return 0.0
    correct = sum(1 for r in valid_results if r["output"]["is_correct"])
    return correct / len(valid_results) * 100


def _build_input_record(vqa_item: Dict[str, Any]) -> Dict[str, Any]:
    """Return the "input" section stored with both successful and failed results."""
    return {
        "question": vqa_item["question"],
        "options": {"A": vqa_item["option_0"], "B": vqa_item["option_1"]},
        "image_path": vqa_item["image_path"],
    }


def process_batch(
    data: List[Dict[str, Any]], client: openai.OpenAI, start_idx: int = 0, batch_size: int = 50
) -> List[Dict[str, Any]]:
    """
    Process a batch of examples and return results.

    Each example yields exactly one entry: a record with an "output" section
    on success, or a record with an "error" message when building the request
    or calling the API raised.

    Args:
        data (List[Dict[str, Any]]): List of data items to process
        client (openai.OpenAI): OpenAI client instance
        start_idx (int, optional): Starting index for batch. Defaults to 0
        batch_size (int, optional): Size of batch to process. Defaults to 50

    Returns:
        List[Dict[str, Any]]: List of processed results
    """
    batch_results: List[Dict[str, Any]] = []
    end_idx = min(start_idx + batch_size, len(data))
    batch_number = start_idx // batch_size + 1

    pbar = tqdm(
        range(start_idx, end_idx),
        desc=f"Processing batch {batch_number}",
        unit="example",
    )

    for index in pbar:
        vqa_item = data[index]
        options = {"option_0": vqa_item["option_0"], "option_1": vqa_item["option_1"]}
        input_record = _build_input_record(vqa_item)

        try:
            messages = create_single_request(
                image_path=vqa_item["image_path"], question=vqa_item["question"], options=options
            )

            response = client.chat.completions.create(
                model=MODEL_NAME, messages=messages, max_tokens=50, temperature=TEMPERATURE
            )

            # content can be None (e.g. a refusal); score that as an empty,
            # incorrect answer instead of failing on None.strip().
            model_answer = (response.choices[0].message.content or "").strip()
            is_correct = check_answer(model_answer, vqa_item["answer"])

            batch_results.append(
                {
                    "timestamp": datetime.now().isoformat(),
                    "example_index": index,
                    "input": input_record,
                    "output": {
                        "model_answer": model_answer,
                        "correct_answer": "A" if vqa_item["answer"] == 0 else "B",
                        "is_correct": is_correct,
                        "usage": {
                            "prompt_tokens": response.usage.prompt_tokens,
                            "completion_tokens": response.usage.completion_tokens,
                            "total_tokens": response.usage.total_tokens,
                        },
                    },
                }
            )

            # Update progress bar with current accuracy
            current_accuracy = calculate_batch_accuracy(batch_results)
            pbar.set_description(
                f"Batch {batch_number} - Accuracy: {current_accuracy:.2f}% "
                f"({len(batch_results)}/{index-start_idx+1} examples)"
            )

        except Exception as e:
            # Best effort: record the failure and continue with the next example.
            batch_results.append(
                {
                    "timestamp": datetime.now().isoformat(),
                    "example_index": index,
                    "error": str(e),
                    "input": input_record,
                }
            )
            if DEBUG_MODE:
                pbar.write(f"Error processing example {index}: {str(e)}")

        time.sleep(1)  # Rate limiting

    return batch_results


def main() -> None:
    """
    Main function to process the entire dataset.

    Raises:
        ValueError: If OPENAI_API_KEY is not set
        Exception: For other processing errors (logged, then re-raised)
    """
    logger.info("Starting full dataset processing...")

    json_path = "../data/chexbench_updated.json"

    try:
        api_key = os.getenv("OPENAI_API_KEY")
        if not api_key:
            raise ValueError("OPENAI_API_KEY environment variable is not set.")
        client = openai.OpenAI(api_key=api_key)

        with open(json_path, "r", encoding="utf-8") as f:
            data = json.load(f)

        subset_data = data[SUBSET]
        total_examples = len(subset_data)
        logger.info("Found %d examples in %s subset", total_examples, SUBSET)

        all_results: List[Dict[str, Any]] = []
        batch_size = 50  # Process in batches of 50 examples
        # Stays None when the subset is empty and no batch is ever saved.
        output_file: Optional[str] = None

        # Process all examples in batches
        for start_idx in range(0, total_examples, batch_size):
            batch_results = process_batch(subset_data, client, start_idx, batch_size)
            all_results.extend(batch_results)

            # Save intermediate (cumulative) results after each batch
            output_file = save_results_to_json(all_results, OUTPUT_DIR)

            # Calculate and log overall progress
            overall_accuracy, correct, total = calculate_accuracy(all_results)
            logger.info(
                "Overall Progress: %d/%d examples processed", len(all_results), total_examples
            )
            logger.info(
                "Current Accuracy: %.2f%% (%d/%d correct)", overall_accuracy, correct, total
            )

        logger.info("Processing completed!")
        if output_file is None:
            logger.warning("No examples found in %s subset; no results were saved", SUBSET)
        else:
            logger.info("Final results saved to: %s", output_file)

    except Exception as e:
        logger.error("Fatal error: %s", e)
        raise


if __name__ == "__main__":
    main()