Spaces:
Configuration error
Configuration error
import glob | |
import os | |
import random | |
import re | |
import shutil | |
import subprocess | |
import time | |
from datetime import datetime | |
from pathlib import Path | |
from typing import Optional | |
from litellm_proxy_extras._logging import logger | |
def str_to_bool(value: Optional[str]) -> bool:
    """Interpret an optional string flag (e.g. an environment variable) as a bool.

    Args:
        value: Raw string to interpret, or None when the flag is unset.

    Returns:
        True when ``value`` case-insensitively equals one of
        "true", "1", "t", "y" or "yes"; False for None or any other string.
        Surrounding whitespace is not stripped.
    """
    truthy_tokens = ("true", "1", "t", "y", "yes")
    return value is not None and value.lower() in truthy_tokens
class ProxyExtrasDBManager:
    """Static helpers that drive the Prisma CLI to set up / migrate the database.

    The class holds no state: every method is a ``@staticmethod`` and is
    normally invoked as ``ProxyExtrasDBManager.<method>(...)``.
    """

    @staticmethod
    def _get_prisma_dir() -> str:
        """
        Get the path to the migrations directory

        Defaults to the directory that contains this module.

        Set os.environ["LITELLM_MIGRATION_DIR"] to a custom migrations directory,
        to support baselining db in read-only fs. When it is set, the contents of
        the package directory are copied into the custom directory (which is
        created if missing) and the custom directory is returned instead.
        """
        custom_migrations_dir = os.getenv("LITELLM_MIGRATION_DIR")
        pkg_migrations_dir = os.path.dirname(__file__)

        if not custom_migrations_dir:
            return pkg_migrations_dir

        if not os.path.exists(custom_migrations_dir):
            # Directory doesn't exist yet: create it and copy everything.
            shutil.copytree(pkg_migrations_dir, custom_migrations_dir)
            return custom_migrations_dir

        # Directory already exists: copy the *contents* of the package
        # directory into it rather than nesting the directory itself.
        for item in os.listdir(pkg_migrations_dir):
            src_path = os.path.join(pkg_migrations_dir, item)
            dst_path = os.path.join(custom_migrations_dir, item)
            if os.path.isdir(src_path):
                shutil.copytree(src_path, dst_path, dirs_exist_ok=True)
            else:
                shutil.copy2(src_path, dst_path)
        return custom_migrations_dir

    @staticmethod
    def _create_baseline_migration(schema_path: str) -> bool:
        """Create a baseline migration for an existing database.

        Writes ``migrations/0_init/migration.sql`` (under the prisma directory)
        from ``prisma migrate diff --from-empty --to-url $DATABASE_URL --script``
        and then marks ``0_init`` as applied.

        Args:
            schema_path: Not used by this method; kept so existing callers
                keep working.

        Returns:
            True when both prisma commands succeed, False when one times out.

        Raises:
            subprocess.CalledProcessError: If a prisma command exits non-zero.
        """
        prisma_dir = ProxyExtrasDBManager._get_prisma_dir()
        init_dir = Path(prisma_dir) / "migrations" / "0_init"

        # Create migrations/0_init directory
        init_dir.mkdir(parents=True, exist_ok=True)

        # NOTE(review): if DATABASE_URL is unset this is None and
        # subprocess.run raises TypeError - confirm callers guarantee it.
        database_url = os.getenv("DATABASE_URL")

        try:
            # 1. Generate migration SQL file by comparing empty state to current db state
            logger.info("Generating baseline migration...")
            migration_file = init_dir / "migration.sql"
            # Use a context manager so the output file is flushed and closed
            # even when prisma fails or times out.
            with open(migration_file, "w") as sql_out:
                subprocess.run(
                    [
                        "prisma",
                        "migrate",
                        "diff",
                        "--from-empty",
                        "--to-url",
                        database_url,
                        "--script",
                    ],
                    stdout=sql_out,
                    check=True,
                    timeout=30,
                )

            # 2. Mark the migration as applied since it represents current state
            logger.info("Marking baseline migration as applied...")
            subprocess.run(
                [
                    "prisma",
                    "migrate",
                    "resolve",
                    "--applied",
                    "0_init",
                ],
                check=True,
                timeout=30,
            )
            return True
        except subprocess.TimeoutExpired:
            logger.warning(
                "Migration timed out - the database might be under heavy load."
            )
            return False
        except subprocess.CalledProcessError as e:
            logger.warning(
                f"Error creating baseline migration: {e}, {e.stderr}, {e.stdout}"
            )
            raise

    @staticmethod
    def _get_migration_names(migrations_dir: str) -> list:
        """Get all migration directory names from the migrations folder.

        A migration is any sub-directory of ``<migrations_dir>/migrations``
        that contains a ``migration.sql`` file. Order follows ``glob.glob``.
        """
        migration_paths = glob.glob(f"{migrations_dir}/migrations/*/migration.sql")
        logger.info(f"Found {len(migration_paths)} migrations at {migrations_dir}")
        return [Path(p).parent.name for p in migration_paths]

    @staticmethod
    def _roll_back_migration(migration_name: str):
        """Mark a specific migration as rolled back.

        Raises:
            subprocess.CalledProcessError: If prisma exits non-zero.
            subprocess.TimeoutExpired: If prisma runs longer than 60 seconds.
        """
        subprocess.run(
            ["prisma", "migrate", "resolve", "--rolled-back", migration_name],
            timeout=60,
            check=True,
            capture_output=True,
        )

    @staticmethod
    def _resolve_specific_migration(migration_name: str):
        """Mark a specific migration as applied.

        Raises:
            subprocess.CalledProcessError: If prisma exits non-zero.
            subprocess.TimeoutExpired: If prisma runs longer than 60 seconds.
        """
        subprocess.run(
            ["prisma", "migrate", "resolve", "--applied", migration_name],
            timeout=60,
            check=True,
            capture_output=True,
        )

    @staticmethod
    def _resolve_all_migrations(migrations_dir: str, schema_path: str):
        """
        1. Compare the current database state to schema.prisma and generate a
           migration for the diff (written to a new
           ``migrations/<timestamp>_baseline_diff/migration.sql``).
        2. Run prisma db execute to apply that diff.
        3. Mark all existing migrations as applied.

        Failures in steps 1 and 2 are logged as warnings and do not stop the
        following steps. If the diff directory cannot be created because of a
        permission error, a warning is logged and nothing else is done.

        Args:
            migrations_dir: Directory that contains the ``migrations`` folder.
            schema_path: Path to the ``schema.prisma`` file.
        """
        database_url = os.getenv("DATABASE_URL")
        diff_dir = (
            Path(migrations_dir)
            / "migrations"
            / f"{datetime.now().strftime('%Y%m%d%H%M%S')}_baseline_diff"
        )
        try:
            diff_dir.mkdir(parents=True, exist_ok=True)
        except Exception as e:
            if "Permission denied" in str(e):
                logger.warning(
                    f"Permission denied - {e}\nunable to baseline db. Set LITELLM_MIGRATION_DIR environment variable to a writable directory to enable migrations."
                )
                return
            raise
        diff_sql_path = diff_dir / "migration.sql"

        # 1. Generate migration SQL for the diff between DB and schema
        try:
            logger.info("Generating migration diff between DB and schema.prisma...")
            with open(diff_sql_path, "w") as f:
                subprocess.run(
                    [
                        "prisma",
                        "migrate",
                        "diff",
                        "--from-url",
                        database_url,
                        "--to-schema-datamodel",
                        schema_path,
                        "--script",
                    ],
                    check=True,
                    timeout=60,
                    stdout=f,
                    # Capture stderr as text so the warning below reports the
                    # real prisma error instead of "None".
                    stderr=subprocess.PIPE,
                    text=True,
                )
        except subprocess.CalledProcessError as e:
            logger.warning(f"Failed to generate migration diff: {e.stderr}")
        except subprocess.TimeoutExpired:
            logger.warning("Migration diff generation timed out.")

        # check if the migration was created
        # NOTE(review): open(..., "w") above always creates the file, so this
        # guard cannot currently fire even when the diff command fails -
        # confirm whether a failed diff should abort here.
        if not diff_sql_path.exists():
            logger.warning("Migration diff was not created")
            return
        logger.info(f"Migration diff created at {diff_sql_path}")

        # 2. Run prisma db execute to apply the migration
        try:
            logger.info("Running prisma db execute to apply the migration diff...")
            result = subprocess.run(
                [
                    "prisma",
                    "db",
                    "execute",
                    "--file",
                    str(diff_sql_path),
                    "--schema",
                    schema_path,
                ],
                timeout=60,
                check=True,
                capture_output=True,
                text=True,
            )
            logger.info(f"prisma db execute stdout: {result.stdout}")
            logger.info("✅ Migration diff applied successfully")
        except subprocess.CalledProcessError as e:
            logger.warning(f"Failed to apply migration diff: {e.stderr}")
        except subprocess.TimeoutExpired:
            logger.warning("Migration diff application timed out.")

        # 3. Mark all migrations as applied
        migration_names = ProxyExtrasDBManager._get_migration_names(migrations_dir)
        logger.info(f"Resolving {len(migration_names)} migrations")
        for migration_name in migration_names:
            try:
                logger.info(f"Resolving migration: {migration_name}")
                subprocess.run(
                    ["prisma", "migrate", "resolve", "--applied", migration_name],
                    timeout=60,
                    check=True,
                    capture_output=True,
                    text=True,
                )
                logger.debug(f"Resolved migration: {migration_name}")
            except subprocess.CalledProcessError as e:
                # "Already applied" is the expected outcome for migrations
                # that were recorded earlier, so only warn about other errors.
                if "is already recorded as applied in the database." not in e.stderr:
                    logger.warning(
                        f"Failed to resolve migration {migration_name}: {e.stderr}"
                    )

    @staticmethod
    def _handle_migrate_deploy_error(
        stderr: str, migrations_dir: str, schema_path: str
    ) -> bool:
        """React to a failed ``prisma migrate deploy`` based on its stderr.

        Recognised prisma error codes:
            * P3009 - the failed migration named in the message is marked as
              rolled back so the next deploy attempt can retry it.
            * P3005 with "database schema is not empty" - a baseline migration
              is created and all migrations are resolved.
            * P3018 - the migration named in the message is marked as rolled
              back and then as applied.

        Args:
            stderr: Captured stderr text of the failed deploy command.
            migrations_dir: Directory that contains the ``migrations`` folder.
            schema_path: Path to the ``schema.prisma`` file.

        Returns:
            True when the database was baselined (P3005 path) and setup is
            complete; False when the caller should retry the deploy.

        Raises:
            subprocess.CalledProcessError, subprocess.TimeoutExpired: From the
                prisma commands run while recovering.
        """
        if "P3009" in stderr:
            # Extract the failed migration name from the error message
            migration_match = re.search(r"`(\d+_.*)` migration", stderr)
            if migration_match:
                failed_migration = migration_match.group(1)
                logger.info(
                    f"Found failed migration: {failed_migration}, marking as rolled back"
                )
                # Mark the failed migration as rolled back
                subprocess.run(
                    [
                        "prisma",
                        "migrate",
                        "resolve",
                        "--rolled-back",
                        failed_migration,
                    ],
                    timeout=60,
                    check=True,
                    capture_output=True,
                    text=True,
                )
                logger.info(
                    f"✅ Migration {failed_migration} marked as rolled back... retrying"
                )
            return False

        if "P3005" in stderr and "database schema is not empty" in stderr:
            logger.info("Database schema is not empty, creating baseline migration")
            # NOTE(review): the boolean result (False on timeout) is ignored
            # here, as in the original flow.
            ProxyExtrasDBManager._create_baseline_migration(schema_path)
            logger.info("Baseline migration created, resolving all migrations")
            ProxyExtrasDBManager._resolve_all_migrations(migrations_dir, schema_path)
            logger.info("✅ All migrations resolved.")
            return True

        if "P3018" in stderr:
            logger.info("Migration already exists, resolving specific migration")
            # Extract the migration name from the error message
            migration_match = re.search(r"Migration name: (\d+_.*)", stderr)
            if migration_match:
                migration_name = migration_match.group(1)
                logger.info(f"Rolling back migration {migration_name}")
                ProxyExtrasDBManager._roll_back_migration(migration_name)
                logger.info(
                    f"Resolving migration {migration_name} that failed due to existing columns"
                )
                ProxyExtrasDBManager._resolve_specific_migration(migration_name)
                logger.info("✅ Migration resolved.")
        return False

    @staticmethod
    def setup_database(use_migrate: bool = False) -> bool:
        """
        Set up the database using either prisma migrate or prisma db push

        Uses migrations from litellm-proxy-extras package. The prisma commands
        are run from the prisma directory; the previous working directory is
        restored after every attempt. Up to 4 attempts are made.

        Args:
            use_migrate (bool): Whether to use prisma migrate instead of db push.
                Setting the USE_PRISMA_MIGRATE environment variable to a truthy
                value also enables it.

        Returns:
            bool: True if setup was successful, False otherwise
        """
        max_attempts = 4
        # Resolve the directory once: when LITELLM_MIGRATION_DIR is set the
        # lookup copies the package files, which need not be repeated per attempt.
        migrations_dir = ProxyExtrasDBManager._get_prisma_dir()
        schema_path = migrations_dir + "/schema.prisma"
        use_migrate = str_to_bool(os.getenv("USE_PRISMA_MIGRATE")) or use_migrate

        for attempt in range(max_attempts):
            attempts_left = max_attempts - 1 - attempt
            original_dir = os.getcwd()
            os.chdir(migrations_dir)
            try:
                if use_migrate:
                    logger.info("Running prisma migrate deploy")
                    try:
                        result = subprocess.run(
                            ["prisma", "migrate", "deploy"],
                            timeout=60,
                            check=True,
                            capture_output=True,
                            text=True,
                        )
                        logger.info(f"prisma migrate deploy stdout: {result.stdout}")
                        logger.info("prisma migrate deploy completed")
                        return True
                    except subprocess.CalledProcessError as e:
                        logger.info(f"prisma db error: {e.stderr}, e: {e.stdout}")
                        if ProxyExtrasDBManager._handle_migrate_deploy_error(
                            e.stderr, migrations_dir, schema_path
                        ):
                            return True
                        # Otherwise fall through and retry the deploy.
                else:
                    # Use prisma db push with increased timeout
                    subprocess.run(
                        ["prisma", "db", "push", "--accept-data-loss"],
                        timeout=60,
                        check=True,
                    )
                    return True
            except subprocess.TimeoutExpired:
                logger.info(f"Attempt {attempt + 1} timed out")
                # Back off only when another attempt will actually follow.
                if attempts_left > 0:
                    time.sleep(random.randrange(5, 15))
            except subprocess.CalledProcessError as e:
                retry_msg = (
                    f" Retrying... ({attempts_left} attempts left)"
                    if attempts_left > 0
                    else ""
                )
                logger.info(f"The process failed to execute. Details: {e}.{retry_msg}")
                if attempts_left > 0:
                    time.sleep(random.randrange(5, 15))
            finally:
                os.chdir(original_dir)
        return False