# fix.py
import concurrent.futures
import functools
import json
import logging
import os
import re
import threading
import time
from datetime import datetime
from typing import Any, Dict, Optional

from dotenv import load_dotenv
from openai import AzureOpenAI
from supabase import Client, create_client
from tqdm import tqdm

# Set up logging with thread safety and custom formatting
class CustomFormatter(logging.Formatter):
    """Custom formatter with colors and better formatting"""
    grey = "\x1b[38;21m"
    blue = "\x1b[38;5;39m"
    yellow = "\x1b[38;5;226m"
    red = "\x1b[38;5;196m"
    bold_red = "\x1b[31;1m"
    reset = "\x1b[0m"

    def __init__(self, fmt):
        super().__init__()
        self.fmt = fmt
        self.FORMATS = {
            logging.DEBUG: self.grey + self.fmt + self.reset,
            logging.INFO: self.blue + self.fmt + self.reset,
            logging.WARNING: self.yellow + self.fmt + self.reset,
            logging.ERROR: self.red + self.fmt + self.reset,
            logging.CRITICAL: self.bold_red + self.fmt + self.reset
        }

    def format(self, record):
        log_fmt = self.FORMATS.get(record.levelno)
        formatter = logging.Formatter(log_fmt)
        return formatter.format(record)

# Set up logging configuration
logger = logging.getLogger('fix')
logger.setLevel(logging.INFO)

# File handler with simple formatting
file_handler = logging.FileHandler('fix.log')
file_handler.setFormatter(logging.Formatter('%(asctime)s - %(levelname)s - %(message)s'))
logger.addHandler(file_handler)

# Console handler with color formatting
console_handler = logging.StreamHandler()
console_handler.setFormatter(CustomFormatter('%(asctime)s - %(levelname)s - %(message)s'))
logger.addHandler(console_handler)

# Create a summary log file for each run
current_time = datetime.now().strftime('%Y%m%d_%H%M%S')
summary_file = f'fix_summary_{current_time}.log'
summary_handler = logging.FileHandler(summary_file)
summary_handler.setFormatter(logging.Formatter('%(message)s'))
summary_logger = logging.getLogger('summary')
summary_logger.addHandler(summary_handler)
summary_logger.setLevel(logging.INFO)

# Load environment variables from .env file (if present)
load_dotenv()

# Constants
MIN_PASSAGE_WORDS = 100  # Minimum number of words for reading_passage
VALID_CORRECT_ANSWERS = {'A', 'B', 'C', 'D'}
EXAM_TYPES = ["SAT", "IELTS", "TOEFL"]
DIFFICULTY_LEVELS = ["Easy", "Medium", "Hard"]

# Load environment variables
SUPABASE_URL = os.getenv("SUPABASE_DB_URL")
SUPABASE_API_KEY = os.getenv("SUPABASE_API_KEY")
AZURE_OPENAI_KEY = os.getenv("AZURE_OPENAI_KEY")
AZURE_OPENAI_ENDPOINT = os.getenv("AZURE_OPENAI_ENDPOINT")
AZURE_OPENAI_DEPLOYMENT_NAME = os.getenv("AZURE_OPENAI_DEPLOYMENT_NAME_FIX", "gpt-4o-mini")  # Use specific deployment for fixing
# JSON mode (response_format={"type": "json_object"}) is not supported on the
# old 2023-05-15 API version, so default to a version that accepts it.
AZURE_OPENAI_API_VERSION = os.getenv("AZURE_OPENAI_API_VERSION", "2024-02-01")
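# Note (assumption about the deployment setup): create_client() below expects
# the project's REST URL (e.g. https://<project-ref>.supabase.co), not a
# Postgres connection string, despite the SUPABASE_DB_URL variable name.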

# Validate environment variables
missing_vars = []
if not SUPABASE_URL:
    missing_vars.append("SUPABASE_DB_URL")
if not SUPABASE_API_KEY:
    missing_vars.append("SUPABASE_API_KEY")
if not AZURE_OPENAI_KEY:
    missing_vars.append("AZURE_OPENAI_KEY")
if not AZURE_OPENAI_ENDPOINT:
    missing_vars.append("AZURE_OPENAI_ENDPOINT")
if missing_vars:
    error_msg = f"Missing required environment variables: {', '.join(missing_vars)}"
    logger.error(error_msg)
    raise ValueError(error_msg)

# Initialize Supabase client
supabase: Client = create_client(SUPABASE_URL, SUPABASE_API_KEY)

# Initialize Azure OpenAI client
client = AzureOpenAI(
    api_key=AZURE_OPENAI_KEY,
    api_version=AZURE_OPENAI_API_VERSION,
    azure_endpoint=AZURE_OPENAI_ENDPOINT
)

# Thread-safe counter for progress tracking
class AtomicCounter:
    def __init__(self, initial=0):
        self._value = initial
        self._lock = threading.Lock()

    def increment(self):
        with self._lock:
            self._value += 1
            return self._value

    def value(self):
        with self._lock:
            return self._value
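
# Minimal usage sketch: the lock makes increments safe across worker threads.
#   counter = AtomicCounter()
#   counter.increment()  # -> 1, even when called from many threads at once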

class RateLimiter:
    """Rate limiter implementation using a sliding-window algorithm"""
    def __init__(self, max_calls: int, period: float):
        self.max_calls = max_calls
        self.period = period
        self.calls = []
        self.lock = threading.Lock()

    def __call__(self, func):
        @functools.wraps(func)
        def wrapped(*args, **kwargs):
            with self.lock:
                now = time.time()
                # Remove old calls outside the window
                self.calls = [call for call in self.calls if call > now - self.period]
                if len(self.calls) >= self.max_calls:
                    sleep_time = self.calls[0] - (now - self.period)
                    if sleep_time > 0:
                        time.sleep(sleep_time)
                    # Recalculate after sleep
                    now = time.time()
                    self.calls = [call for call in self.calls if call > now - self.period]
                self.calls.append(now)
            # Release the lock before the wrapped call so workers can run concurrently
            return func(*args, **kwargs)
        return wrapped

# Initialize Rate Limiter: 60 calls per minute
rate_limiter = RateLimiter(max_calls=60, period=60)
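# The limiter is shared module-wide and applied as a decorator below, so all
# worker threads pass through one sliding window before hitting Azure OpenAI.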

@rate_limiter
def generate_fixed_content(row: Dict[str, Any]) -> Optional[Dict[str, Any]]:
    """
    Uses Azure OpenAI to generate fixed content for a row.
    Returns a dictionary with fixed content or None if generation fails.
    """
    try:
        # Determine if this is a math question
        domain = (row.get('domain') or '').lower()
        is_math = any(math_term in domain for math_term in ['math', 'algebra', 'geometry', 'calculus', 'arithmetic'])

        # Create system message with domain-specific instructions
        system_message = """You are an expert in standardized exam content. You must return your response as a valid JSON object with the following structure:
{
    "reading_passage": "formatted passage text",
    "question_text": "formatted question",
    "option_a": "option A text",
    "option_b": "option B text",
    "option_c": "option C text",
    "option_d": "option D text",
    "explanation": "explanation text"
}"""
        if is_math:
            system_message += """
IMPORTANT: For ALL mathematics questions:
- You MUST set reading_passage to an empty string (""). No exceptions.
- Move any context or problem setup from the reading passage into the question_text
- The question_text should contain all necessary mathematical information
- Format: reading_passage must be "", question_text contains everything

Example math question format:
{
    "reading_passage": "",
    "question_text": "In the given system of equations, y = -1.5 and y = x^2 + 8x + a, where a is a positive constant. The system has exactly one distinct real solution. What is the value of a?",
    ...
}"""
        else:
            system_message += """
For reading comprehension questions:
- Format the reading_passage professionally with proper paragraphing
- Ensure the question is answerable from the passage
- Make answer options clear and distinct
- Reference the passage in the explanation"""

        # Create user message with the content to fix
        user_message = f"""Please format and fix the following exam content, returning a JSON object with the specified structure:

Domain: {domain}

Reading Passage:
{row.get('reading_passage', '')}

Question:
{row.get('question_text', '')}

Options:
A) {row.get('option_a', '')}
B) {row.get('option_b', '')}
C) {row.get('option_c', '')}
D) {row.get('option_d', '')}

Explanation:
{row.get('explanation', '')}"""

        # Call Azure OpenAI API with JSON mode
        response = client.chat.completions.create(
            model=AZURE_OPENAI_DEPLOYMENT_NAME,
            messages=[
                {"role": "system", "content": system_message},
                {"role": "user", "content": user_message}
            ],
            temperature=0.3,
            top_p=0.95,
            frequency_penalty=0,
            presence_penalty=0,
            response_format={"type": "json_object"}
        )

        # Extract the response content
        if not response.choices:
            logger.error("No response generated from OpenAI")
            return None
        content = response.choices[0].message.content

        # Estimate cost: prefer exact token counts reported by the API and
        # fall back to a rough 4-characters-per-token heuristic.
        if response.usage:
            input_tokens = response.usage.prompt_tokens
            output_tokens = response.usage.completion_tokens
        else:
            input_tokens = (len(system_message) + len(user_message)) / 4
            output_tokens = len(content) / 4
        # Assumed gpt-4o-mini rates (verify against current Azure pricing):
        # Input: $0.300 per 1M tokens, Output: $1.200 per 1M tokens
        fix_cost = (input_tokens / 1_000_000 * 0.300) + (output_tokens / 1_000_000 * 1.200)
        logger.info(f"Estimated cost for fixing this question: ${fix_cost:.6f}")

        try:
            # Parse JSON response
            fixed_data = json.loads(content)

            # For math questions, force an empty reading passage and fold any
            # generated passage text into the question instead
            if is_math and fixed_data.get('reading_passage', '').strip():
                current_passage = fixed_data['reading_passage'].strip()
                current_question = fixed_data.get('question_text', '').strip()
                fixed_data['question_text'] = f"{current_passage} {current_question}"
                fixed_data['reading_passage'] = ""

            # Copy over unchanged fields
            for key in row:
                if key not in fixed_data and key != 'id':
                    fixed_data[key] = row[key]

            # Attach the fix cost (popped back out before the DB update)
            fixed_data['fix_cost'] = fix_cost
            return fixed_data
        except json.JSONDecodeError as e:
            logger.error(f"Failed to parse JSON response: {str(e)}")
            return None
        except Exception as e:
            logger.error(f"Error processing response: {str(e)}")
            return None
    except Exception as e:
        logger.error(f"Error generating fixed content: {str(e)}")
        return None

def word_count(text: str) -> int:
    """Returns the number of words in a given text."""
    return len(text.split())

def is_valid_correct_answer(answer: str) -> bool:
    """Checks if the correct_answer is one of A, B, C, D (None-safe)."""
    return bool(answer) and answer.strip().upper() in VALID_CORRECT_ANSWERS

def clean_text(text: str) -> str:
    """Cleans the text by collapsing runs of whitespace and trimming the ends."""
    text = re.sub(r'\s+', ' ', text)  # Replace runs of whitespace with a single space
    return text.strip()

def check_row_quality(row: Dict[str, Any]) -> Optional[bool]:
    """
    Check if a row meets quality standards.
    Returns True if the row is good quality, False if it needs fixing,
    and None if the row should be deleted outright.
    """
    # Skip if already marked as fixed
    if row.get('is_fixed', False):
        return True

    # Check for image-related questions that should be deleted
    question_text = (row.get('question_text') or '').lower()
    reading_passage = (row.get('reading_passage') or '').lower()

    # Keywords that indicate image-based questions
    image_keywords = [
        'image', 'picture', 'diagram', 'graph', 'figure', 'photo', 'illustration',
        'shown', 'depicted', 'displayed', 'above', 'below', 'following figure',
        'look at the', 'in this picture', 'as shown', 'pictured'
    ]

    # Check if question or passage refers to images
    if any(keyword in question_text for keyword in image_keywords) or \
       any(keyword in reading_passage for keyword in image_keywords):
        logger.info(f"Row {row.get('id')}: Marked for deletion - contains image references")
        return None  # Return None to indicate deletion

    # Basic validation for required fields
    if not row.get('question_text') or not row.get('explanation'):
        logger.info(f"Row {row.get('id')}: Marked for deletion - missing required fields")
        return None
    if not all(row.get(f'option_{opt}') for opt in ['a', 'b', 'c', 'd']):
        logger.info(f"Row {row.get('id')}: Marked for deletion - missing options")
        return None
    if not is_valid_correct_answer(row.get('correct_answer', '')):
        logger.info(f"Row {row.get('id')}: Marked for deletion - invalid correct answer")
        return None

    # Option quality checks
    options = [row.get(f'option_{opt}', '').strip() for opt in ['a', 'b', 'c', 'd']]
    if any(len(opt) < 1 for opt in options):
        logger.info(f"Row {row.get('id')}: Marked for deletion - empty options")
        return None

    # Check for duplicate options
    if len(set(options)) != 4:
        logger.info(f"Row {row.get('id')}: Marked for deletion - duplicate options")
        return None

    # Basic explanation quality check
    explanation = row.get('explanation', '')
    if len(explanation) < 50 or not explanation.strip():
        logger.info(f"Row {row.get('id')}: Marked for deletion - insufficient explanation")
        return None
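
    # Reconstructed "needs fixing" path (an assumption, based on the docstring's
    # documented False return and the otherwise-unused word_count/clean_text
    # helpers and MIN_PASSAGE_WORDS constant): flag salvageable formatting
    # problems for the LLM fixer instead of deleting the row. The passage is
    # excluded from the whitespace check so legitimate paragraph breaks survive.
    single_line_fields = ['question_text', 'explanation',
                          'option_a', 'option_b', 'option_c', 'option_d']
    if any(row.get(f) and clean_text(row[f]) != row[f] for f in single_line_fields):
        logger.info(f"Row {row.get('id')}: Needs fixing - irregular whitespace or spacing")
        return False
    passage = row.get('reading_passage') or ''
    if passage.strip() and word_count(passage) < MIN_PASSAGE_WORDS:
        logger.info(f"Row {row.get('id')}: Needs fixing - passage shorter than {MIN_PASSAGE_WORDS} words")
        return False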
    return True

def update_row_in_supabase(row_id: str, fixed_data: Dict[str, Any]) -> bool:
    """
    Updates a row in Supabase with fixed data.
    Returns True if successful, False otherwise.
    """
    try:
        response = supabase.table("exam_contents").update(fixed_data).eq("id", row_id).execute()
        # Check if data exists in the response
        if response.data:
            logger.info(f"Row {row_id}: Successfully updated.")
            return True
        else:
            logger.error(f"Row {row_id}: Failed to update.")
            return False
    except Exception as e:
        logger.error(f"Row {row_id}: Exception while updating - {str(e)}")
        return False

def process_row(row: Dict[str, Any], progress_counter: AtomicCounter) -> Dict[str, Any]:
    """Process a single row and return the result."""
    try:
        row_id = row.get('id')
        started = progress_counter.increment()
        logger.debug(f"Row {row_id}: Picked up by worker ({started} rows started)")

        # Check quality first
        quality_check = check_row_quality(row)

        # If quality_check is None, delete the row
        if quality_check is None:
            try:
                supabase.table("exam_contents").delete().eq("id", row_id).execute()
                logger.info(f"Row {row_id}: Deleted due to quality issues.")
                return {
                    'success': True,
                    'changes_made': ['deleted'],
                    'row_id': row_id,
                    'cost': 0.0
                }
            except Exception as e:
                logger.error(f"Row {row_id}: Failed to delete - {str(e)}")
                return {
                    'success': False,
                    'row_id': row_id,
                    'cost': 0.0
                }

        # If row passes quality check, no need to fix
        if quality_check is True:
            # Update is_fixed flag
            try:
                supabase.table("exam_contents").update({"is_fixed": True}).eq("id", row_id).execute()
                logger.info(f"Row {row_id}: Already good quality. Marked as fixed.")
                return {
                    'success': True,
                    'changes_made': ['marked_fixed'],
                    'row_id': row_id,
                    'cost': 0.0
                }
            except Exception as e:
                logger.error(f"Row {row_id}: Failed to update fixed status - {str(e)}")
                return {
                    'success': False,
                    'row_id': row_id,
                    'cost': 0.0
                }

        # quality_check is False: the row is salvageable, so generate a fix
        fixed_data = generate_fixed_content(row)
        if not fixed_data:
            logger.error(f"Row {row_id}: Failed to generate fixed content.")
            return {
                'success': False,
                'row_id': row_id,
                'cost': 0.0
            }

        # fix_cost is local bookkeeping (assumed not to be a column in
        # exam_contents), so keep it out of the database payload
        fix_cost = fixed_data.pop('fix_cost', 0.0)

        # Track what fields were modified
        changes_made = []
        for field in fixed_data:
            if field in row and fixed_data[field] != row[field]:
                changes_made.append(field)

        if changes_made:
            # Add is_fixed flag
            fixed_data['is_fixed'] = True
            # Update in database via the shared helper
            if update_row_in_supabase(row_id, fixed_data):
                change_list = ', '.join(changes_made)
                logger.info(f"Row {row_id}: Fixed successfully. Modified: {change_list}")
                return {
                    'success': True,
                    'changes_made': changes_made,
                    'row_id': row_id,
                    'cost': fix_cost
                }
            return {
                'success': False,
                'row_id': row_id,
                'cost': 0.0
            }
        else:
            # No changes needed, just mark as fixed
            try:
                supabase.table("exam_contents").update({"is_fixed": True}).eq("id", row_id).execute()
                logger.info(f"Row {row_id}: No changes needed. Marked as fixed.")
                return {
                    'success': True,
                    'changes_made': ['marked_fixed'],
                    'row_id': row_id,
                    'cost': fix_cost  # The generation call still cost money
                }
            except Exception as e:
                logger.error(f"Row {row_id}: Failed to update fixed status - {str(e)}")
                return {
                    'success': False,
                    'row_id': row_id,
                    'cost': 0.0
                }
    except Exception as e:
        logger.error(f"Error processing row {row.get('id', 'unknown')}: {str(e)}")
        return {
            'success': False,
            'row_id': row.get('id', 'unknown'),
            'cost': 0.0
        }

def fetch_all_unfixed_rows(supabase_client: Client, batch_size: int = 1000):
    """
    Fetches all unfixed rows from the exam_contents table in batches.

    Uses keyset pagination (ordered by id) rather than offset ranges: worker
    threads flip is_fixed and delete rows while fetching is still in progress,
    and offset pagination over a shrinking filter set silently skips rows.

    Args:
        supabase_client (Client): The Supabase client instance.
        batch_size (int): Number of rows to fetch per batch.
    Yields:
        List[Dict[str, Any]]: A batch of rows.
    """
    last_id = None
    while True:
        # Fetch the next batch of rows after the last seen id
        query = supabase_client.table("exam_contents")\
            .select("*")\
            .eq("is_fixed", False)\
            .order("id")\
            .limit(batch_size)
        if last_id is not None:
            query = query.gt("id", last_id)
        batch = query.execute().data
        if not batch:
            break  # No more rows to fetch
        yield batch
        last_id = batch[-1]["id"]

def main():
    """Main function to process and fix exam questions in Supabase using multithreading."""
    start_time = time.time()
    logger.info("Starting fix.py script")
    summary_logger.info("\n=== Question Fix Summary ===\n")
    try:
        # Initialize counters
        total_rows = 0
        success_count = 0
        failure_count = 0
        total_cost = 0.0
        changes_by_field = {
            'reading_passage': 0,
            'question_text': 0,
            'option_a': 0,
            'option_b': 0,
            'option_c': 0,
            'option_d': 0,
            'explanation': 0
        }

        # Create a thread pool; the work is I/O-bound, and os.cpu_count() can
        # return None, so fall back to a small default
        max_workers = min(32, (os.cpu_count() or 4) * 2)
        logger.info(f"Initializing with {max_workers} threads")

        with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
            # Initialize progress tracking
            progress_counter = AtomicCounter()
            futures = []

            # Submit rows to the pool in batches
            for batch in fetch_all_unfixed_rows(supabase):
                total_rows += len(batch)
                for row in batch:
                    futures.append(executor.submit(process_row, row, progress_counter))

            # Track progress with tqdm
            with tqdm(total=total_rows, desc="Processing Rows", unit="row", dynamic_ncols=True) as pbar:
                for future in concurrent.futures.as_completed(futures):
                    result = future.result()
                    if result['success']:
                        success_count += 1
                        # Update changes counter
                        for field in result['changes_made']:
                            changes_by_field[field] = changes_by_field.get(field, 0) + 1
                        # Add cost if available
                        if 'cost' in result:
                            total_cost += result['cost']
                    else:
                        failure_count += 1
                    pbar.update(1)

        # Calculate execution time
        execution_time = time.time() - start_time

        # Log final statistics
        summary = [
            "\n=== Final Statistics ===",
            f"Total questions processed: {total_rows}",
            f"Successful updates: {success_count}",
            f"Failed updates: {failure_count}",
            f"Total cost: ${total_cost:.6f}",
            f"Execution time: {execution_time:.2f} seconds",
            "\nChanges by field:",
            *[f"- {field}: {count}" for field, count in changes_by_field.items() if count > 0],
            "\n=== End of Summary ===\n"
        ]

        # Log to both console and summary file
        for line in summary:
            logger.info(line)
            summary_logger.info(line)
    except Exception as e:
        error_msg = f"An unexpected error occurred: {str(e)}"
        logger.error(error_msg)
        summary_logger.error(error_msg)

if __name__ == "__main__":
    main()