# medrax.org / experiments /chexbench_gpt4.py
# oldcai's picture
# Upload folder using huggingface_hub
# d7a7846 verified
import base64
import json
import logging
import os
import re
import time
from datetime import datetime
from pathlib import Path
from typing import Dict, List, Optional, Union, Any

import openai
from tqdm import tqdm
# ---------------------------------------------------------------------------
# Run configuration
# ---------------------------------------------------------------------------
DEBUG_MODE = False  # True selects DEBUG-level logging and extra debug output
OUTPUT_DIR = "results"  # directory that receives the result JSON files
MODEL_NAME = "gpt-4o-2024-05-13"  # model name sent with every completion request
TEMPERATURE = 0.2  # sampling temperature sent with every completion request
SUBSET = "Visual Question Answering"  # top-level key of the dataset JSON to evaluate

# ---------------------------------------------------------------------------
# Logging
# ---------------------------------------------------------------------------
if DEBUG_MODE:
    logging_level = logging.DEBUG
else:
    logging_level = logging.INFO

logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(message)s",
    level=logging_level,
)
logger = logging.getLogger(__name__)
def get_mime_type(file_path: str) -> str:
    """
    Map a file path to a MIME type using only its extension.

    The extension is compared case-insensitively. Anything other than
    .png, .jpg/.jpeg or .gif falls back to a generic binary type.

    Args:
        file_path (str): Path to the file

    Returns:
        str: MIME type string for the file
    """
    _, suffix = os.path.splitext(file_path)
    suffix = suffix.lower()

    if suffix == ".png":
        return "image/png"
    if suffix in (".jpg", ".jpeg"):
        return "image/jpeg"
    if suffix == ".gif":
        return "image/gif"
    return "application/octet-stream"
# Largest raw file size (in bytes) that encode_image accepts: 20 MB.
_MAX_IMAGE_BYTES = 20 * 1024 * 1024
# Pillow format names that encode_image accepts.
_SUPPORTED_IMAGE_FORMATS = ("PNG", "JPEG", "GIF")
# Base64 renderings of the JPEG, PNG and GIF magic bytes, in that order.
_BASE64_IMAGE_PREFIXES = ("/9j/", "iVBOR", "R0lGOD")


def encode_image(image_path: str) -> str:
    """
    Validate an image file and return its contents as a base64 string.

    The file must exist, be non-empty and no larger than 20 MB, and must be
    recognised by Pillow as PNG, JPEG or GIF. The raw bytes are then read and
    base64-encoded. A warning (not an error) is logged when the encoded data
    does not begin with the signature of one of the accepted formats.

    Args:
        image_path (str): Path to the image file

    Returns:
        str: Base64 encoded image string

    Raises:
        FileNotFoundError: If image file does not exist
        ValueError: If image file is empty, too large, or in an unsupported format
        Exception: For other image processing errors (logged, then re-raised)
    """
    logger.debug(f"Attempting to read image from: {image_path}")
    if not os.path.exists(image_path):
        raise FileNotFoundError(f"Image file not found: {image_path}")

    # Size checks run before any decoding so oversized/empty files fail fast.
    file_size = os.path.getsize(image_path)
    if file_size > _MAX_IMAGE_BYTES:
        raise ValueError("Image file size exceeds 20MB limit")
    if file_size == 0:
        raise ValueError("Image file is empty")
    logger.debug(f"Image file size: {file_size / 1024:.2f} KB")

    try:
        # Imported lazily, as in the original, so Pillow is only needed when
        # an image is actually encoded.
        from PIL import Image

        # Let Pillow identify the file; this checks the real content rather
        # than trusting the file extension.
        with Image.open(image_path) as img:
            width, height = img.size
            image_format = img.format
            mode = img.mode
            logger.debug(
                f"Image verification - Format: {image_format}, Size: {width}x{height}, Mode: {mode}"
            )
            if image_format not in _SUPPORTED_IMAGE_FORMATS:
                raise ValueError(f"Unsupported image format: {image_format}")

        with open(image_path, "rb") as image_file:
            encoded = base64.b64encode(image_file.read()).decode("utf-8")

        encoded_length = len(encoded)
        logger.debug(f"Base64 encoded length: {encoded_length} characters")

        if encoded_length == 0:
            raise ValueError("Base64 encoding produced empty string")
        # Sanity check only: GIF is an accepted format, so its signature must
        # be recognised here too or every valid GIF would trigger the warning.
        if not encoded.startswith(_BASE64_IMAGE_PREFIXES):
            logger.warning("Base64 string doesn't start with expected JPEG, PNG or GIF header")
        return encoded
    except Exception as e:
        logger.error(f"Error reading/encoding image: {str(e)}")
        raise
def create_single_request(
    image_path: str, question: str, options: Dict[str, str]
) -> List[Dict[str, Any]]:
    """
    Build the chat messages for one two-option image question.

    The result holds a system message that restricts the reply to a letter
    and a user message that carries the prompt text plus the image embedded
    as a base64 data URL.

    Args:
        image_path (str): Path to the image file
        question (str): Question text
        options (Dict[str, str]): Dictionary containing options with keys 'option_0' and 'option_1'

    Returns:
        List[Dict[str, Any]]: List of message dictionaries for the API request

    Raises:
        Exception: For errors in request creation (logged, then re-raised)
    """
    if DEBUG_MODE:
        logger.debug("Creating API request...")

    # The prompt is assembled before the try block, so a missing option key
    # propagates without going through the error log below.
    prompt = f"""Given the following medical examination question:
Please answer this multiple choice question:
Question: {question}
Options:
A) {options['option_0']}
B) {options['option_1']}
Base your answer only on the provided image and select either A or B."""

    try:
        image_b64 = encode_image(image_path)
        mime_type = get_mime_type(image_path)
        if DEBUG_MODE:
            logger.debug(f"Image encoded with MIME type: {mime_type}")

        system_message = {
            "role": "system",
            "content": "You are taking a medical exam. Answer ONLY with the letter (A/B) corresponding to your answer.",
        }
        image_part = {
            "type": "image_url",
            "image_url": {"url": f"data:{mime_type};base64,{image_b64}"},
        }
        user_message = {
            "role": "user",
            "content": [{"type": "text", "text": prompt}, image_part],
        }
        messages = [system_message, user_message]

        if DEBUG_MODE:
            # Log a deep copy (JSON round trip) with the image data replaced
            # by a placeholder so the payload stays readable.
            redacted = json.loads(json.dumps(messages))
            redacted_image = redacted[1]["content"][1]["image_url"]
            redacted_image["url"] = f"data:{mime_type};base64,[BASE64_IMAGE_TRUNCATED]"
            logger.debug(f"Complete API request payload:\n{json.dumps(redacted, indent=2)}")

        return messages
    except Exception as e:
        logger.error(f"Error creating request: {str(e)}")
        raise
# A reply that leads with the choice letter: optional punctuation/whitespace,
# then A or B, which must not be the first letter of a longer word.
_LEADING_CHOICE_RE = re.compile(r"^[\W_]*([AB])(?!\w)")
# A reply that states the choice, e.g. "ANSWER: B" or "THE ANSWER IS (A)".
_STATED_CHOICE_RE = re.compile(r"\bANSWER\b(?:\s+IS)?[\W_]*([AB])(?!\w)")


def check_answer(model_answer: str, correct_answer: int) -> bool:
    """
    Check if the model's answer matches the correct answer.

    The reply is upper-cased and the chosen letter is extracted in two steps:
    first a standalone A/B at the very start of the reply ("A", "b)", "(A)",
    "B. Pneumonia"), then, failing that, a letter introduced by the word
    "answer" ("Answer: B", "The answer is A"). A plain prefix test is not
    enough because words such as "Answer" or "Both" begin with A or B.

    Args:
        model_answer (str): The model's answer (A or B)
        correct_answer (int): The correct answer index (0 for A, 1 for B)

    Returns:
        bool: True if answer is correct, False otherwise (including when no
            choice letter can be extracted or the answer is not a string)
    """
    if not isinstance(model_answer, str):
        return False

    normalized = model_answer.strip().upper()
    match = _LEADING_CHOICE_RE.match(normalized) or _STATED_CHOICE_RE.search(normalized)
    if match is None:
        return False

    model_index = 0 if match.group(1) == "A" else 1
    return model_index == correct_answer
def save_results_to_json(results: List[Dict[str, Any]], output_dir: str) -> str:
    """
    Write the results to a new timestamped JSON file.

    The output directory (including parents) is created when missing. The
    file name has the form ``batch_results_<YYYYmmdd_HHMMSS>.json``.

    Args:
        results (List[Dict[str, Any]]): List of result dictionaries
        output_dir (str): Directory to save results

    Returns:
        str: Path to the saved file
    """
    target_dir = Path(output_dir)
    target_dir.mkdir(parents=True, exist_ok=True)

    stamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    file_name = f"batch_results_{stamp}.json"
    output_file = os.path.join(output_dir, file_name)

    with open(output_file, "w") as handle:
        json.dump(results, handle, indent=2)

    logger.info(f"Batch results saved to {output_file}")
    return output_file
def calculate_accuracy(results: List[Dict[str, Any]]) -> tuple[float, int, int]:
    """
    Compute overall accuracy across all results.

    Entries without an "output" key (failed examples) can never count as
    correct but still count towards the total, so they lower the accuracy.

    Args:
        results (List[Dict[str, Any]]): List of result dictionaries

    Returns:
        tuple[float, int, int]: Tuple containing (accuracy percentage, number correct, total)
    """
    if not results:
        return 0.0, 0, 0

    total = len(results)
    correct = 0
    for entry in results:
        if "output" not in entry:
            continue
        if entry["output"].get("is_correct", False):
            correct += 1

    accuracy = correct / total * 100
    return accuracy, correct, total
def calculate_batch_accuracy(results: List[Dict[str, Any]]) -> float:
    """
    Compute accuracy over the answered examples of a batch.

    Only entries that carry an "output" key are considered; entries without
    one are left out of both the numerator and the denominator.

    Args:
        results (List[Dict[str, Any]]): List of result dictionaries

    Returns:
        float: Accuracy percentage for the batch (0.0 when nothing was answered)
    """
    answered = [entry for entry in results if "output" in entry]
    if len(answered) == 0:
        return 0.0

    hits = 0
    for entry in answered:
        if entry["output"]["is_correct"]:
            hits += 1
    return hits / len(answered) * 100
def process_batch(
    data: List[Dict[str, Any]], client: openai.OpenAI, start_idx: int = 0, batch_size: int = 50
) -> List[Dict[str, Any]]:
    """
    Query the model for one slice of the dataset and collect the outcomes.

    Examples ``start_idx`` up to (but excluding) ``start_idx + batch_size``,
    clipped to the dataset length, are processed in order. Each example
    yields either a result record with an "output" section or, when request
    creation or the API call raises, an error record with an "error" message.
    The function pauses one second after every example.

    Args:
        data (List[Dict[str, Any]]): List of data items to process
        client (openai.OpenAI): OpenAI client instance
        start_idx (int, optional): Starting index for batch. Defaults to 0
        batch_size (int, optional): Size of batch to process. Defaults to 50

    Returns:
        List[Dict[str, Any]]: List of processed results
    """
    stop_idx = min(start_idx + batch_size, len(data))
    batch_number = start_idx // batch_size + 1
    collected: List[Dict[str, Any]] = []

    progress = tqdm(
        range(start_idx, stop_idx),
        desc=f"Processing batch {batch_number}",
        unit="example",
    )

    for position in progress:
        item = data[position]
        options = {key: item[key] for key in ("option_0", "option_1")}
        # Shared by the success and the error record below.
        input_record = {
            "question": item["question"],
            "options": {"A": item["option_0"], "B": item["option_1"]},
            "image_path": item["image_path"],
        }

        try:
            messages = create_single_request(
                image_path=item["image_path"], question=item["question"], options=options
            )
            response = client.chat.completions.create(
                model=MODEL_NAME, messages=messages, max_tokens=50, temperature=TEMPERATURE
            )
            model_answer = response.choices[0].message.content.strip()
            expected_index = item["answer"]
            is_correct = check_answer(model_answer, expected_index)

            usage = response.usage
            collected.append(
                {
                    "timestamp": datetime.now().isoformat(),
                    "example_index": position,
                    "input": input_record,
                    "output": {
                        "model_answer": model_answer,
                        "correct_answer": "A" if expected_index == 0 else "B",
                        "is_correct": is_correct,
                        "usage": {
                            "prompt_tokens": usage.prompt_tokens,
                            "completion_tokens": usage.completion_tokens,
                            "total_tokens": usage.total_tokens,
                        },
                    },
                }
            )

            # Show the running accuracy of this batch on the progress bar.
            running_accuracy = calculate_batch_accuracy(collected)
            progress.set_description(
                f"Batch {batch_number} - Accuracy: {running_accuracy:.2f}% "
                f"({len(collected)}/{position-start_idx+1} examples)"
            )
        except Exception as e:
            # A failed example is recorded rather than aborting the batch.
            collected.append(
                {
                    "timestamp": datetime.now().isoformat(),
                    "example_index": position,
                    "error": str(e),
                    "input": input_record,
                }
            )
            if DEBUG_MODE:
                progress.write(f"Error processing example {position}: {str(e)}")

        time.sleep(1)  # Rate limiting

    return collected
def main() -> None:
    """
    Main function to process the entire dataset.

    Reads the dataset JSON, selects the SUBSET entries, and runs them through
    process_batch in batches of 50. After each batch the accumulated results
    are saved to a new timestamped file in OUTPUT_DIR and the overall accuracy
    is logged. An empty subset is reported with a warning and nothing is saved.

    Raises:
        ValueError: If OPENAI_API_KEY is not set
        Exception: For other processing errors (logged, then re-raised)
    """
    logger.info("Starting full dataset processing...")
    json_path = "../data/chexbench_updated.json"
    try:
        api_key = os.getenv("OPENAI_API_KEY")
        if not api_key:
            raise ValueError("OPENAI_API_KEY environment variable is not set.")
        client = openai.OpenAI(api_key=api_key)

        # Read as UTF-8 explicitly; open()'s default encoding depends on the
        # platform locale.
        with open(json_path, "r", encoding="utf-8") as f:
            data = json.load(f)

        subset_data = data[SUBSET]
        total_examples = len(subset_data)
        logger.info(f"Found {total_examples} examples in {SUBSET} subset")

        # With no examples the batch loop never runs and no output file is
        # written, so stop here instead of reporting a path that does not exist.
        if total_examples == 0:
            logger.warning(f"No examples to process in {SUBSET} subset; no results were saved")
            return

        all_results = []
        output_file = None
        batch_size = 50  # Process in batches of 50 examples

        # Process all examples in batches
        for start_idx in range(0, total_examples, batch_size):
            batch_results = process_batch(subset_data, client, start_idx, batch_size)
            all_results.extend(batch_results)

            # Save intermediate results after each batch
            output_file = save_results_to_json(all_results, OUTPUT_DIR)

            # Calculate and log overall progress
            overall_accuracy, correct, total = calculate_accuracy(all_results)
            logger.info(f"Overall Progress: {len(all_results)}/{total_examples} examples processed")
            logger.info(f"Current Accuracy: {overall_accuracy:.2f}% ({correct}/{total} correct)")

        logger.info("Processing completed!")
        logger.info(f"Final results saved to: {output_file}")
    except Exception as e:
        logger.error(f"Fatal error: {str(e)}")
        raise


if __name__ == "__main__":
    main()