# fix.py
import concurrent.futures
import functools
import json
import logging
import os
import re
import threading
import time
from datetime import datetime
from typing import Any, Dict, Optional
from dotenv import load_dotenv
from openai import AzureOpenAI
from supabase import Client, create_client
from tqdm import tqdm
# Set up logging with thread safety and custom formatting
class CustomFormatter(logging.Formatter):
"""Custom formatter with colors and better formatting"""
grey = "\x1b[38;21m"
blue = "\x1b[38;5;39m"
yellow = "\x1b[38;5;226m"
red = "\x1b[38;5;196m"
bold_red = "\x1b[31;1m"
reset = "\x1b[0m"
def __init__(self, fmt):
super().__init__()
self.fmt = fmt
self.FORMATS = {
logging.DEBUG: self.grey + self.fmt + self.reset,
logging.INFO: self.blue + self.fmt + self.reset,
logging.WARNING: self.yellow + self.fmt + self.reset,
logging.ERROR: self.red + self.fmt + self.reset,
logging.CRITICAL: self.bold_red + self.fmt + self.reset
}
def format(self, record):
log_fmt = self.FORMATS.get(record.levelno)
formatter = logging.Formatter(log_fmt)
return formatter.format(record)
# Set up logging configuration
logger = logging.getLogger('fix')
logger.setLevel(logging.INFO)
# File handler with simple formatting
file_handler = logging.FileHandler('fix.log')
file_handler.setFormatter(logging.Formatter('%(asctime)s - %(levelname)s - %(message)s'))
logger.addHandler(file_handler)
# Console handler with color formatting
console_handler = logging.StreamHandler()
console_handler.setFormatter(CustomFormatter('%(asctime)s - %(levelname)s - %(message)s'))
logger.addHandler(console_handler)
# Create a summary log file for each run
current_time = datetime.now().strftime('%Y%m%d_%H%M%S')
summary_file = f'fix_summary_{current_time}.log'
summary_handler = logging.FileHandler(summary_file)
summary_handler.setFormatter(logging.Formatter('%(message)s'))
summary_logger = logging.getLogger('summary')
summary_logger.addHandler(summary_handler)
summary_logger.setLevel(logging.INFO)
# Load environment variables from .env file (if present)
load_dotenv()
# Constants
MIN_PASSAGE_WORDS = 100 # Minimum number of words for reading_passage
VALID_CORRECT_ANSWERS = {'A', 'B', 'C', 'D'}
EXAM_TYPES = ["SAT", "IELTS", "TOEFL"]
DIFFICULTY_LEVELS = ["Easy", "Medium", "Hard"]
# Load environment variables
SUPABASE_URL = os.getenv("SUPABASE_DB_URL")
SUPABASE_API_KEY = os.getenv("SUPABASE_API_KEY")
AZURE_OPENAI_KEY = os.getenv("AZURE_OPENAI_KEY")
AZURE_OPENAI_ENDPOINT = os.getenv("AZURE_OPENAI_ENDPOINT")
AZURE_OPENAI_DEPLOYMENT_NAME = os.getenv("AZURE_OPENAI_DEPLOYMENT_NAME_FIX", "gpt-4o-mini") # Use specific deployment for fixing
AZURE_OPENAI_API_VERSION = os.getenv("AZURE_OPENAI_API_VERSION", "2023-05-15")
# Validate environment variables
missing_vars = []
if not SUPABASE_URL:
missing_vars.append("SUPABASE_DB_URL")
if not SUPABASE_API_KEY:
missing_vars.append("SUPABASE_API_KEY")
if not AZURE_OPENAI_KEY:
missing_vars.append("AZURE_OPENAI_KEY")
if not AZURE_OPENAI_ENDPOINT:
missing_vars.append("AZURE_OPENAI_ENDPOINT")
if missing_vars:
error_msg = f"Missing required environment variables: {', '.join(missing_vars)}"
logger.error(error_msg)
raise ValueError(error_msg)
# Initialize Supabase client
supabase: Client = create_client(SUPABASE_URL, SUPABASE_API_KEY)
# Initialize Azure OpenAI client
client = AzureOpenAI(
api_key=AZURE_OPENAI_KEY,
api_version=AZURE_OPENAI_API_VERSION,
azure_endpoint=AZURE_OPENAI_ENDPOINT
)
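# The Azure OpenAI client instance is shared by all worker threads; the
# openai SDK's synchronous client is built on httpx and is generally safe
# to share across threads for non-streaming calls.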
# Thread-safe counter for progress tracking
class AtomicCounter:
def __init__(self, initial=0):
self._value = initial
self._lock = threading.Lock()
def increment(self):
with self._lock:
self._value += 1
return self._value
def value(self):
with self._lock:
return self._value
class RateLimiter:
"""Rate limiter implementation using token bucket algorithm"""
def __init__(self, max_calls: int, period: float):
self.max_calls = max_calls
self.period = period
self.calls = []
self.lock = threading.Lock()
def __call__(self, func):
@functools.wraps(func)
def wrapped(*args, **kwargs):
with self.lock:
now = time.time()
# Remove old calls outside the window
self.calls = [call for call in self.calls if call > now - self.period]
if len(self.calls) >= self.max_calls:
sleep_time = self.calls[0] - (now - self.period)
if sleep_time > 0:
time.sleep(sleep_time)
# Recalculate after sleep
now = time.time()
self.calls = [call for call in self.calls if call > now - self.period]
self.calls.append(now)
return func(*args, **kwargs)
return wrapped
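# Note: the limiter holds its lock while sleeping, so worker threads queue
# behind it once the window is full; at 60 calls per 60 seconds this caps
# sustained throughput at roughly one API call per second.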
# Initialize Rate Limiter: 60 calls per minute
rate_limiter = RateLimiter(max_calls=60, period=60)
@rate_limiter
def generate_fixed_content(row: Dict[str, Any]) -> Optional[Dict[str, Any]]:
"""
Uses Azure OpenAI to generate fixed content for a row.
Returns a dictionary with fixed content or None if generation fails.
"""
try:
# Determine if this is a math question
        domain = (row.get('domain') or '').lower()
        is_math = any(term in domain for term in ['math', 'algebra', 'geometry', 'calculus', 'arithmetic'])
# Create system message with domain-specific instructions
system_message = """You are an expert in standardized English test content. You must return your response as a valid JSON object with the following structure:
{
"reading_passage": "formatted passage text",
"question_text": "formatted question",
"option_a": "option A text",
"option_b": "option B text",
"option_c": "option C text",
"option_d": "option D text",
"explanation": "explanation text"
}"""
if is_math:
system_message += """
IMPORTANT: For ALL mathematics questions:
- You MUST set reading_passage to an empty string (""). No exceptions.
- Move any context or problem setup from the reading passage into the question_text
- The question_text should contain all necessary mathematical information
- Format: reading_passage must be "", question_text contains everything
Example math question format:
{
"reading_passage": "",
"question_text": "In the given system of equations, y = -1.5 and y = x^2 + 8x + a, where a is a positive constant. The system has exactly one distinct real solution. What is the value of a?",
...
}"""
else:
system_message += """
For reading comprehension questions:
- Format the reading_passage professionally with proper paragraphing
- Ensure the question is answerable from the passage
- Make answer options clear and distinct
- Reference the passage in the explanation"""
# Create user message with the content to fix
user_message = f"""Please format and fix the following exam content, returning a JSON object with the specified structure:
Domain: {domain}
Reading Passage:
{row.get('reading_passage', '')}
Question:
{row.get('question_text', '')}
Options:
A) {row.get('option_a', '')}
B) {row.get('option_b', '')}
C) {row.get('option_c', '')}
D) {row.get('option_d', '')}
Explanation:
{row.get('explanation', '')}"""
# Call Azure OpenAI API with JSON mode
response = client.chat.completions.create(
model=AZURE_OPENAI_DEPLOYMENT_NAME,
messages=[
{"role": "system", "content": system_message},
{"role": "user", "content": user_message}
],
temperature=0.3,
top_p=0.95,
frequency_penalty=0,
presence_penalty=0,
response_format={"type": "json_object"}
)
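        # Note: response_format={"type": "json_object"} requires a deployment
        # and API version with JSON-mode support; older API versions (e.g. the
        # 2023-05-15 default above) may reject this parameter.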
# Extract the response content
if not response.choices:
logger.error("No response generated from OpenAI")
return None
content = response.choices[0].message.content
        # Estimate cost. Prefer exact token counts from the API response;
        # fall back to a rough 4-characters-per-token heuristic.
        if response.usage:
            input_tokens = response.usage.prompt_tokens
            output_tokens = response.usage.completion_tokens
        else:
            input_tokens = (len(system_message) + len(user_message)) / 4
            output_tokens = len(content) / 4
        # Rates assumed here: $0.300 per 1M input tokens and $1.200 per 1M
        # output tokens; adjust to the deployment's actual pricing.
        fix_cost = (input_tokens / 1_000_000 * 0.300) + (output_tokens / 1_000_000 * 1.200)
        logger.info(f"Estimated cost for fixing this question: ${fix_cost:.6f}")
try:
# Parse JSON response
fixed_data = json.loads(content)
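            # Sanity-check the model output against the schema requested in
            # the system message; a response missing any key is unusable
            required_keys = {'reading_passage', 'question_text', 'option_a',
                             'option_b', 'option_c', 'option_d', 'explanation'}
            missing_keys = required_keys - fixed_data.keys()
            if missing_keys:
                logger.error(f"Model response missing keys: {sorted(missing_keys)}")
                return None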
            # For math questions, force an empty reading passage and fold any
            # stray passage text into the question itself
            if is_math:
                current_passage = (fixed_data.get('reading_passage') or '').strip()
                if current_passage:
                    current_question = (fixed_data.get('question_text') or '').strip()
                    fixed_data['question_text'] = f"{current_passage} {current_question}".strip()
                fixed_data['reading_passage'] = ""
# Copy over unchanged fields
for key in row:
if key not in fixed_data and key != 'id':
fixed_data[key] = row[key]
# Add the fix cost to the data
fixed_data['fix_cost'] = fix_cost
return fixed_data
except json.JSONDecodeError as e:
logger.error(f"Failed to parse JSON response: {str(e)}")
return None
except Exception as e:
logger.error(f"Error processing response: {str(e)}")
return None
except Exception as e:
logger.error(f"Error generating fixed content: {str(e)}")
return None
def word_count(text: str) -> int:
"""Returns the number of words in a given text."""
return len(text.split())
def is_valid_correct_answer(answer: str) -> bool:
"""Checks if the correct_answer is one of A, B, C, D."""
return answer.upper() in VALID_CORRECT_ANSWERS
def clean_text(text: str) -> str:
"""Cleans the text by removing unwanted characters and extra spaces."""
text = re.sub(r'\s+', ' ', text) # Replace multiple spaces with single space
text = text.strip()
return text
def check_row_quality(row: Dict[str, Any]) -> Optional[bool]:
    """
    Check whether a row meets quality standards.
    Returns True if the row is good quality, False if it is fixable,
    and None if it should be deleted.
    """
# Skip if already marked as fixed
if row.get('is_fixed', False):
return True
    # Check for image-related questions that should be deleted
    question_text = (row.get('question_text') or '').lower()
    reading_passage = (row.get('reading_passage') or '').lower()
    # Keywords that indicate image-based questions. Broad terms such as
    # 'above', 'below', and 'shown' will also match some text-only questions,
    # so this filter deliberately errs on the side of deletion.
    image_keywords = [
        'image', 'picture', 'diagram', 'graph', 'figure', 'photo', 'illustration',
        'shown', 'depicted', 'displayed', 'above', 'below', 'following figure',
        'look at the', 'in this picture', 'as shown', 'pictured'
    ]
# Check if question or passage refers to images
if any(keyword in question_text for keyword in image_keywords) or \
any(keyword in reading_passage for keyword in image_keywords):
logger.info(f"Row {row.get('id')}: Marked for deletion - contains image references")
return None # Return None to indicate deletion
# Basic validation for required fields
if not row.get('question_text') or not row.get('explanation'):
logger.info(f"Row {row.get('id')}: Marked for deletion - missing required fields")
return None
if not all(row.get(f'option_{opt}') for opt in ['a', 'b', 'c', 'd']):
logger.info(f"Row {row.get('id')}: Marked for deletion - missing options")
return None
    if not is_valid_correct_answer(row.get('correct_answer') or ''):
logger.info(f"Row {row.get('id')}: Marked for deletion - invalid correct answer")
return None
# Option quality checks
options = [row.get(f'option_{opt}', '').strip() for opt in ['a', 'b', 'c', 'd']]
if any(len(opt) < 1 for opt in options):
logger.info(f"Row {row.get('id')}: Marked for deletion - empty options")
return None
# Check for duplicate options
if len(set(options)) != 4:
logger.info(f"Row {row.get('id')}: Marked for deletion - duplicate options")
return None
    # Basic explanation quality check
    explanation = (row.get('explanation') or '').strip()
    if len(explanation) < 50:
        logger.info(f"Row {row.get('id')}: Marked for deletion - insufficient explanation")
        return None
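    # Assumption: a non-empty passage shorter than MIN_PASSAGE_WORDS is
    # fixable rather than grounds for deletion; returning False routes the
    # row to generate_fixed_content instead of the delete path
    passage_raw = (row.get('reading_passage') or '').strip()
    if passage_raw and word_count(passage_raw) < MIN_PASSAGE_WORDS:
        logger.info(f"Row {row.get('id')}: Needs fixing - passage under {MIN_PASSAGE_WORDS} words")
        return False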
return True
def update_row_in_supabase(row_id: str, fixed_data: Dict[str, Any]) -> bool:
"""
Updates a row in Supabase with fixed data.
Returns True if successful, False otherwise.
"""
try:
response = supabase.table("exam_contents").update(fixed_data).eq("id", row_id).execute()
        # Check if data exists in the response
        if response.data:
            logger.info(f"Row {row_id}: Successfully updated.")
            return True
return True
else:
logger.error(f"Row {row_id}: Failed to update.")
return False
except Exception as e:
logger.error(f"Row {row_id}: Exception while updating - {str(e)}")
return False
def process_row(row: Dict[str, Any], progress_counter: AtomicCounter, total_rows: int, row_number: int) -> Dict[str, Any]:
    """
    Process a single row and return a result dict.
    progress_counter, total_rows and row_number are carried for bookkeeping;
    user-visible progress is reported by the tqdm bar in main().
    """
try:
row_id = row.get('id')
# Check quality first
quality_check = check_row_quality(row)
# If quality_check is None, delete the row
if quality_check is None:
try:
supabase.table("exam_contents").delete().eq("id", row_id).execute()
logger.info(f"Row {row_id}: Deleted due to quality issues.")
return {
'success': True,
'changes_made': ['deleted'],
'row_id': row_id,
'cost': 0.0
}
except Exception as e:
logger.error(f"Row {row_id}: Failed to delete - {str(e)}")
return {
'success': False,
'row_id': row_id,
'cost': 0.0
}
# If row passes quality check, no need to fix
if quality_check is True:
# Update is_fixed flag
try:
supabase.table("exam_contents").update({"is_fixed": True}).eq("id", row_id).execute()
logger.info(f"Row {row_id}: Already good quality. Marked as fixed.")
return {
'success': True,
'changes_made': ['marked_fixed'],
'row_id': row_id,
'cost': 0.0
}
except Exception as e:
logger.error(f"Row {row_id}: Failed to update fixed status - {str(e)}")
return {
'success': False,
'row_id': row_id,
'cost': 0.0
}
# Generate fixed content
fixed_data = generate_fixed_content(row)
if not fixed_data:
logger.error(f"Row {row_id}: Failed to generate fixed content.")
return {
'success': False,
'row_id': row_id,
'cost': 0.0
}
# Track what fields were modified
changes_made = []
for field in fixed_data:
if field in row and fixed_data[field] != row[field]:
changes_made.append(field)
        if changes_made:
            # fix_cost is per-run bookkeeping, assumed not to be a column on
            # exam_contents, so strip it from the update payload
            fix_cost = fixed_data.pop('fix_cost', 0.0)
            # Add is_fixed flag
            fixed_data['is_fixed'] = True
            # Update in database
            try:
                supabase.table("exam_contents").update(fixed_data).eq("id", row_id).execute()
                change_list = ', '.join(changes_made)
                logger.info(f"Row {row_id}: Fixed successfully. Modified: {change_list}")
                return {
                    'success': True,
                    'changes_made': changes_made,
                    'row_id': row_id,
                    'cost': fix_cost
                }
except Exception as e:
logger.error(f"Row {row_id}: Failed to update - {str(e)}")
return {
'success': False,
'row_id': row_id,
'cost': 0.0
}
else:
# No changes needed, just mark as fixed
try:
supabase.table("exam_contents").update({"is_fixed": True}).eq("id", row_id).execute()
logger.info(f"Row {row_id}: Fixed successfully. Modified: No changes needed")
return {
'success': True,
'changes_made': ['marked_fixed'],
'row_id': row_id,
'cost': fixed_data.get('fix_cost', 0.0) # Include the fix cost even if no changes
}
except Exception as e:
logger.error(f"Row {row_id}: Failed to update fixed status - {str(e)}")
return {
'success': False,
'row_id': row_id,
'cost': 0.0
}
except Exception as e:
logger.error(f"Error processing row {row.get('id', 'unknown')}: {str(e)}")
return {
'success': False,
'row_id': row.get('id', 'unknown'),
'cost': 0.0
}
def fetch_all_unfixed_rows(supabase_client: Client, batch_size: int = 1000):
"""
Fetches all unfixed rows from the exam_contents table in batches.
Args:
supabase_client (Client): The Supabase client instance.
batch_size (int): Number of rows to fetch per batch.
Yields:
List[Dict[str, Any]]: A batch of rows.
"""
# Initialize the starting range
start = 0
while True:
        # Fetch a batch of rows, ordered by id so pagination is deterministic
        response = supabase_client.table("exam_contents")\
            .select("*")\
            .eq("is_fixed", False)\
            .order("id")\
            .range(start, start + batch_size - 1)\
            .execute()
batch = response.data
if not batch:
break # No more rows to fetch
yield batch
start += batch_size
def main():
"""Main function to process and fix exam questions in Supabase using multithreading."""
start_time = time.time()
logger.info("Starting fix.py script")
summary_logger.info("\n=== Question Fix Summary ===\n")
try:
# Initialize counters
total_rows = 0
success_count = 0
failure_count = 0
total_cost = 0.0
changes_by_field = {
'reading_passage': 0,
'question_text': 0,
'option_a': 0,
'option_b': 0,
'option_c': 0,
'option_d': 0,
'explanation': 0
}
        # Create a thread pool (os.cpu_count() can return None, e.g. in some
        # containers, so fall back to a small default)
        max_workers = min(32, (os.cpu_count() or 4) * 2)
        logger.info(f"Initializing with {max_workers} threads")
        # Fetch every unfixed row before submitting any work so that
        # concurrent is_fixed updates cannot shift the pagination window
        all_rows = [row for batch in fetch_all_unfixed_rows(supabase) for row in batch]
        total_rows = len(all_rows)
        logger.info(f"Fetched {total_rows} unfixed rows")
        with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
            # Initialize progress tracking
            progress_counter = AtomicCounter()
            futures = [
                executor.submit(process_row, row, progress_counter, total_rows, i + 1)
                for i, row in enumerate(all_rows)
            ]
# Track progress with tqdm
with tqdm(total=total_rows, desc="Processing Rows", unit="row", dynamic_ncols=True) as pbar:
for future in concurrent.futures.as_completed(futures):
result = future.result()
if result['success']:
success_count += 1
# Update changes counter
for field in result['changes_made']:
changes_by_field[field] = changes_by_field.get(field, 0) + 1
# Add cost if available
if 'cost' in result:
total_cost += result['cost']
else:
failure_count += 1
pbar.update(1)
# Calculate execution time
execution_time = time.time() - start_time
# Log final statistics
summary = [
"\n=== Final Statistics ===",
f"Total questions processed: {total_rows}",
f"Successful updates: {success_count}",
f"Failed updates: {failure_count}",
f"Total cost: ${total_cost:.6f}",
f"Execution time: {execution_time:.2f} seconds",
"\nChanges by field:",
*[f"- {field}: {count}" for field, count in changes_by_field.items() if count > 0],
"\n=== End of Summary ===\n"
]
# Log to both console and summary file
for line in summary:
logger.info(line)
summary_logger.info(line)
except Exception as e:
error_msg = f"An unexpected error occurred: {str(e)}"
logger.error(error_msg)
summary_logger.error(error_msg)
if __name__ == "__main__":
main()