Spaces:
Sleeping
Sleeping
import sys | |
from contextlib import asynccontextmanager | |
from fastapi import FastAPI, Depends, HTTPException | |
from fastapi.routing import APIRoute | |
from fastapi.middleware.cors import CORSMiddleware | |
import logging | |
from fastapi import FastAPI, Request, Depends, status, Response | |
from fastapi.responses import JSONResponse | |
from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint | |
from starlette.responses import StreamingResponse | |
import json | |
import time | |
import requests | |
from pydantic import BaseModel, ConfigDict | |
from typing import Optional, List | |
from utils.utils import get_verified_user, get_current_user, get_admin_user | |
from config import SRC_LOG_LEVELS, ENV | |
from constants import MESSAGES | |
import os | |
# Module-level logger; its verbosity is taken from the "LITELLM" entry of the
# per-source log-level configuration.
log = logging.getLogger(name=__name__)
log.setLevel(level=SRC_LOG_LEVELS["LITELLM"])
from config import ( | |
ENABLE_LITELLM, | |
ENABLE_MODEL_FILTER, | |
MODEL_FILTER_LIST, | |
DATA_DIR, | |
LITELLM_PROXY_PORT, | |
LITELLM_PROXY_HOST, | |
) | |
import warnings | |
# Install an "ignore" filter for every warning category. simplefilter edits the
# interpreter-wide warnings filter list, so this hides warnings from all
# modules, not only from the litellm import that follows.
# NOTE(review): consider scoping this with warnings.catch_warnings() around the
# litellm import instead.
warnings.simplefilter(action="ignore")
from litellm.utils import get_llm_provider | |
import asyncio | |
import subprocess | |
import yaml | |
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan: launch LiteLLM on startup and stop it on shutdown.

    Args:
        app: The application whose lifespan is being managed. The handle of
            the background start-up task is stored on ``app.state.litellm_task``.
    """
    log.info("startup_event")
    # TODO: Check config.yaml file and create one
    # The event loop only keeps weak references to tasks; hold on to this one
    # so it cannot be garbage-collected while it is still running.
    app.state.litellm_task = asyncio.create_task(start_litellm_background())
    try:
        yield
    finally:
        # Stop the child process explicitly instead of relying on interpreter
        # teardown to clean it up.
        try:
            await shutdown_litellm_background()
        except Exception:
            log.exception("Failed to shut down litellm subprocess")
app = FastAPI(lifespan=lifespan)

# CORS policy: every origin, method and header is accepted, and credentials
# are allowed.
# NOTE(review): wildcard origins together with allow_credentials=True may let
# any site issue credentialed cross-origin requests; confirm this is intended.
origins = ["*"]
app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_methods=["*"],
    allow_headers=["*"],
    allow_credentials=True,
)
# Despite the "_DIR" suffix, this is the path of the LiteLLM YAML config file;
# it is also passed to the LiteLLM CLI via --config.
LITELLM_CONFIG_DIR = f"{DATA_DIR}/litellm/config.yaml"

with open(LITELLM_CONFIG_DIR, "r", encoding="utf-8") as file:
    # An empty YAML document loads as None; treat it as an empty config so
    # app.state.CONFIG is always a mapping.
    litellm_config = yaml.safe_load(file) or {}

# The handlers in this module index CONFIG["model_list"]; make sure it is a
# list even when the file omits it or sets it to null.
if litellm_config.get("model_list") is None:
    litellm_config["model_list"] = []

app.state.ENABLE_MODEL_FILTER = ENABLE_MODEL_FILTER.value
app.state.MODEL_FILTER_LIST = MODEL_FILTER_LIST.value
app.state.ENABLE = ENABLE_LITELLM
app.state.CONFIG = litellm_config
# Handle of the LiteLLM subprocess; set by run_background_process and reset to
# None by shutdown_litellm_background.
background_process = None

# Environment variables removed from the child's environment before LiteLLM
# is spawned, because they collide with settings of the hosting process.
CONFLICT_ENV_VARS = [
    "PORT",  # Uvicorn uses PORT, so LiteLLM might use it as well
    "DATABASE_URL",  # LiteLLM uses DATABASE_URL for Prisma connections
]
async def run_background_process(command):
    """Spawn *command* as a subprocess and log its output until it exits.

    The process handle is stored in the module-level ``background_process``
    global so other coroutines can wait for or terminate it. Variables named
    in ``CONFLICT_ENV_VARS`` are removed from the child's environment.

    Args:
        command: Argument vector; ``command[0]`` is the executable.

    Raises:
        Exception: any error while spawning or reading the subprocess is
            logged and re-raised.
    """
    global background_process
    log.info("run_background_process")
    try:
        # Log the command to be executed
        log.info(f"Executing command: {command}")
        # Filter environment variables known to conflict with litellm
        env = {k: v for k, v in os.environ.items() if k not in CONFLICT_ENV_VARS}
        # Execute the command and create a subprocess
        process = await asyncio.create_subprocess_exec(
            *command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, env=env
        )
        background_process = process
        log.info("Subprocess started successfully.")

        async def _log_stream(stream, prefix):
            # Log one line at a time. Decode leniently so a stray non-UTF-8
            # byte cannot abort the reader.
            async for line in stream:
                text = line.decode(errors="replace").strip()
                if text:
                    log.info(f"{prefix}{text}")

        # Drain STDOUT and STDERR concurrently. Reading one pipe to EOF before
        # touching the other can deadlock once the child fills the unread
        # pipe's OS buffer, and it would also hold back STDERR until exit.
        await asyncio.gather(
            _log_stream(process.stdout, ""),
            _log_stream(process.stderr, "Subprocess STDERR: "),
        )

        # Wait for the process to finish
        returncode = await process.wait()
        log.info(f"Subprocess exited with return code {returncode}")
    except Exception as e:
        log.error(f"Failed to start subprocess: {e}")
        raise  # Optionally re-raise the exception if you want it to propagate
async def start_litellm_background():
    """Run the LiteLLM proxy CLI as a subprocess and return once it exits."""
    log.info("start_litellm_background")
    # Flag/value pairs appended after the executable, in this exact order.
    cli_options = {
        "--port": str(LITELLM_PROXY_PORT),
        "--host": LITELLM_PROXY_HOST,
        "--telemetry": "False",
        "--config": LITELLM_CONFIG_DIR,
    }
    argv = ["litellm"]
    for flag, value in cli_options.items():
        argv.extend((flag, value))
    await run_background_process(argv)
async def shutdown_litellm_background(timeout=10.0):
    """Stop the LiteLLM subprocess recorded in ``background_process``.

    A still-running process is sent ``terminate()`` (SIGTERM on POSIX). If it
    has not exited after *timeout* seconds, it is sent ``kill()``. A process
    that has already exited is not signalled. In every case where a handle
    was recorded, ``background_process`` is reset to ``None``.

    Args:
        timeout: Seconds to wait after ``terminate()`` before falling back to
            ``kill()``; ``None`` waits indefinitely.
    """
    global background_process
    log.info("shutdown_litellm_background")
    process = background_process
    if process is not None:
        # Signalling a process that already exited (e.g. LiteLLM crashed)
        # raises ProcessLookupError, which would otherwise make every restart
        # attempt fail.
        if process.returncode is None:
            try:
                process.terminate()
                try:
                    # Ensure the process has terminated
                    await asyncio.wait_for(process.wait(), timeout)
                except asyncio.TimeoutError:
                    log.warning("Subprocess did not exit after terminate(); killing it")
                    process.kill()
                    await process.wait()
            except ProcessLookupError:
                # The process exited between the returncode check and the signal.
                pass
        log.info("Subprocess terminated")
        background_process = None
async def get_status():
    """Return a constant status payload; the subprocess is not inspected."""
    return dict(status=True)
async def restart_litellm():
    """
    Restart the litellm background service.

    Shuts down the current subprocess (if any), then schedules a new
    ``start_litellm_background`` task. The new process is started
    asynchronously, so it may not be running yet when this returns.

    Returns:
        A ``{"status": "success", "message": ...}`` dict.

    Raises:
        HTTPException: 500 if shutting down or scheduling the restart fails.
    """
    log.info("Requested restart of litellm service.")
    try:
        # Shut down the existing process if it is running
        await shutdown_litellm_background()
        log.info("litellm service shutdown complete.")
        # Restart the background service. Keep a reference to the task: the
        # event loop only holds tasks weakly.
        app.state.litellm_task = asyncio.create_task(start_litellm_background())
        log.info("litellm service restart complete.")
        return {
            "status": "success",
            "message": "litellm service restarted successfully.",
        }
    except Exception as e:
        log.exception(f"Error restarting litellm service: {e}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(e)
        ) from e
async def restart_litellm_handler(user=Depends(get_admin_user)):
    """Restart LiteLLM and return the result of ``restart_litellm``.

    ``user`` is not read here; it is resolved through ``get_admin_user``,
    presumably as an access check.
    """
    outcome = await restart_litellm()
    return outcome
async def get_config(user=Depends(get_admin_user)):
    """Return the in-memory LiteLLM configuration (``app.state.CONFIG``)."""
    current_config = app.state.CONFIG
    return current_config
class LiteLLMConfigForm(BaseModel):
    """Request body that replaces the LiteLLM config; every section is optional."""

    # Disable pydantic's protected "model_" namespace so the `model_list`
    # field is accepted without a warning.
    model_config = ConfigDict(protected_namespaces=())

    general_settings: Optional[dict] = None
    litellm_settings: Optional[dict] = None
    model_list: Optional[List[dict]] = None
    router_settings: Optional[dict] = None
async def update_config(form_data: LiteLLMConfigForm, user=Depends(get_admin_user)):
    """Replace the LiteLLM config with *form_data*, persist it and restart.

    Sections left as ``None`` in the form are omitted from the new config.

    Returns:
        The new config dict, which is also stored in ``app.state.CONFIG``.
    """
    new_config = form_data.model_dump(exclude_none=True)
    # Publish the new config in memory only after it has been written; if the
    # write raises, app.state.CONFIG keeps its previous value.
    with open(LITELLM_CONFIG_DIR, "w", encoding="utf-8") as file:
        yaml.dump(new_config, file)
    app.state.CONFIG = new_config
    await restart_litellm()
    return app.state.CONFIG
async def get_models(user=Depends(get_current_user)):
    """Return the OpenAI-style model list exposed by the LiteLLM proxy.

    When LiteLLM is disabled (``app.state.ENABLE`` falsy), an empty list is
    returned. Otherwise ``GET /v1/models`` is requested from the local proxy.
    When model filtering is enabled and the caller's role is ``"user"``, only
    models whose id is in ``app.state.MODEL_FILTER_LIST`` are kept. If the
    proxy request fails, a list built from the ``model_list`` entries of
    ``app.state.CONFIG`` is returned instead.
    """
    if not app.state.ENABLE:
        return {
            "data": [],
            "object": "list",
        }

    loop = asyncio.get_running_loop()

    # Wait, with a bound, for the background task to spawn the subprocess.
    # If spawning failed, `background_process` is never set and an unbounded
    # wait would hang this request forever. After the deadline, the request
    # below fails and falls back to the config-based list.
    deadline = loop.time() + 10.0
    while not background_process and loop.time() < deadline:
        await asyncio.sleep(0.1)

    url = f"http://localhost:{LITELLM_PROXY_PORT}/v1"
    r = None
    try:
        # `requests` is synchronous: run it in the default executor so the
        # event loop is not blocked, and bound it with a timeout.
        r = await loop.run_in_executor(
            None,
            lambda: requests.request(method="GET", url=f"{url}/models", timeout=10),
        )
        r.raise_for_status()
        data = r.json()

        if app.state.ENABLE_MODEL_FILTER and user and user.role == "user":
            # Build the allow-list once for O(1) membership tests.
            allowed_ids = set(app.state.MODEL_FILTER_LIST)
            data["data"] = [
                model for model in data["data"] if model["id"] in allowed_ids
            ]

        return data
    except Exception as e:
        log.exception(e)
        error_detail = "Open WebUI: Server Connection Error"
        if r is not None:
            try:
                res = r.json()
                if "error" in res:
                    error_detail = f"External: {res['error']}"
            except Exception:
                error_detail = f"External: {e}"
        # Record why the config-based fallback is being served.
        log.warning(f"{error_detail}; falling back to models from config")

        created = int(time.time())
        return {
            "data": [
                {
                    "id": model["model_name"],
                    "object": "model",
                    "created": created,
                    "owned_by": "openai",
                }
                for model in (app.state.CONFIG.get("model_list") or [])
            ],
            "object": "list",
        }
async def get_model_list(user=Depends(get_admin_user)):
    """Return the raw ``model_list`` entries of the in-memory config.

    A config without a ``model_list`` key, which ``update_config`` can
    produce, or with a null value yields an empty list instead of a KeyError.
    """
    return {"data": app.state.CONFIG.get("model_list") or []}
class AddLiteLLMModelForm(BaseModel):
    """Request body for appending one entry to the LiteLLM ``model_list``."""

    # Allow the `model_name` field, which falls in pydantic's protected
    # "model_" namespace, without a warning.
    model_config = ConfigDict(protected_namespaces=())

    model_name: str  # public name of the model entry
    litellm_params: dict  # passed through into the config entry unchanged
async def add_model_to_config(
    form_data: AddLiteLLMModelForm, user=Depends(get_admin_user)
):
    """Append a model entry to the LiteLLM config, persist it and restart.

    Returns:
        ``{"message": MESSAGES.MODEL_ADDED(model_name)}``.

    Raises:
        HTTPException: 500 with the error text if any step (provider lookup,
            writing the config file, restart) raises.
    """
    try:
        # NOTE(review): this checks the entry's `model_name`, not
        # `litellm_params["model"]`; confirm which one is meant to be validated.
        get_llm_provider(model=form_data.model_name)

        new_config = dict(app.state.CONFIG)
        new_config["model_list"] = [
            *(app.state.CONFIG.get("model_list") or []),
            form_data.model_dump(),
        ]
        # Publish in memory only after the file write succeeds, so a failed
        # write leaves app.state.CONFIG unchanged.
        with open(LITELLM_CONFIG_DIR, "w", encoding="utf-8") as file:
            yaml.dump(new_config, file)
        app.state.CONFIG = new_config

        await restart_litellm()
        return {"message": MESSAGES.MODEL_ADDED(form_data.model_name)}
    except Exception as e:
        log.exception(e)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(e)
        ) from e
class DeleteLiteLLMModelForm(BaseModel):
    """Request body naming the model entry to remove from the LiteLLM config."""

    id: str  # compared against each model_list entry's "model_name"
async def delete_model_from_config(
    form_data: DeleteLiteLLMModelForm, user=Depends(get_admin_user)
):
    """Remove every config entry whose ``model_name`` equals ``form_data.id``.

    The updated config is written to disk, then stored in
    ``app.state.CONFIG``, and LiteLLM is restarted. The success message is
    returned even when no entry matched.
    """
    new_config = dict(app.state.CONFIG)
    # Tolerate a missing/null model_list and entries lacking "model_name"
    # instead of failing with KeyError/TypeError.
    new_config["model_list"] = [
        model
        for model in (app.state.CONFIG.get("model_list") or [])
        if model.get("model_name") != form_data.id
    ]
    # Publish in memory only after the file write succeeds.
    with open(LITELLM_CONFIG_DIR, "w", encoding="utf-8") as file:
        yaml.dump(new_config, file)
    app.state.CONFIG = new_config

    await restart_litellm()
    return {"message": MESSAGES.MODEL_DELETED(form_data.id)}
async def proxy(path: str, request: Request, user=Depends(get_verified_user)):
    """Forward *request* to the local LiteLLM proxy at ``/{path}``.

    The request body is sent unchanged with ``Content-Type: application/json``.
    ``text/event-stream`` responses are streamed back to the client; any other
    successful response is parsed as JSON and returned.

    Raises:
        HTTPException: the upstream status code when LiteLLM answered with
            4xx/5xx, otherwise 500. ``detail`` carries the upstream error
            message when one can be extracted.
    """
    from starlette.background import BackgroundTask

    body = await request.body()
    url = f"http://localhost:{LITELLM_PROXY_PORT}"
    target_url = f"{url}/{path}"

    headers = {}
    # headers["Authorization"] = f"Bearer {key}"
    headers["Content-Type"] = "application/json"

    r = None
    try:
        # `requests` is synchronous: run it off the event loop so a slow
        # upstream call does not stall every other request.
        r = await asyncio.get_running_loop().run_in_executor(
            None,
            lambda: requests.request(
                method=request.method,
                url=target_url,
                data=body,
                headers=headers,
                stream=True,
            ),
        )
        r.raise_for_status()

        # Check if response is SSE
        if "text/event-stream" in r.headers.get("Content-Type", ""):
            return StreamingResponse(
                r.iter_content(chunk_size=8192),
                status_code=r.status_code,
                headers=dict(r.headers),
                # Release the upstream (stream=True) connection once the
                # response has been sent.
                background=BackgroundTask(r.close),
            )

        return r.json()
    except Exception as e:
        log.exception(e)
        error_detail = "Open WebUI: Server Connection Error"
        if r is not None:
            try:
                res = r.json()
                if "error" in res:
                    err = res["error"]
                    error_detail = f"External: {err['message'] if 'message' in err else err}"
            except Exception:
                error_detail = f"External: {e}"
            finally:
                r.close()

        # Compare against None explicitly. A requests.Response is falsy for
        # 4xx/5xx (__bool__ returns .ok), so a truthiness test would turn
        # every upstream error into a 500.
        raise HTTPException(
            status_code=r.status_code if r is not None and not r.ok else 500,
            detail=error_detail,
        )