# Jae-Won Chung
# Make one model selectable by user (#23)
# 7ed0b8b unverified
import os
import json
import uvicorn
from pydantic import BaseSettings
from fastapi import FastAPI, Depends
from fastapi.responses import StreamingResponse
from fastapi.exceptions import HTTPException
from text_generation.errors import OverloadedError, UnknownError, ValidationError
from spitfight.log import get_logger, init_queued_root_logger, shutdown_queued_root_loggers
from spitfight.colosseum.common import (
COLOSSEUM_MODELS_ROUTE,
COLOSSEUM_PROMPT_ROUTE,
COLOSSEUM_RESP_VOTE_ROUTE,
COLOSSEUM_ENERGY_VOTE_ROUTE,
COLOSSEUM_HEALTH_ROUTE,
ModelsResponse,
PromptRequest,
ResponseVoteRequest,
ResponseVoteResponse,
EnergyVoteRequest,
EnergyVoteResponse,
)
from spitfight.colosseum.controller.controller import (
Controller,
init_global_controller,
get_global_controller,
)
from spitfight.utils import prepend_generator
class ControllerConfig(BaseSettings):
    """Controller settings automatically loaded from environment variables.

    Every field carries a default, so the class can be instantiated without
    arguments. Fields are grouped by the part of the service they configure.
    """

    # ---- Controller ----
    # NOTE(review): the units of the interval/expiration values are not
    # visible in this file (presumably seconds) -- confirm in Controller.
    background_task_interval: int = 300
    max_num_req_states: int = 10_000
    req_state_expiration_time: int = 600
    compose_files: list[str] = [
        "deployment/docker-compose-0.yaml",
        "deployment/docker-compose-1.yaml",
    ]

    # ---- Logging ----
    # Each *_log_file name is joined onto `log_dir` to form the full path.
    log_dir: str = "/logs"
    controller_log_file: str = "controller.log"
    request_log_file: str = "requests.log"
    uvicorn_log_file: str = "uvicorn.log"

    # ---- Generation ----
    # NOTE(review): presumably sampling parameters forwarded to the text
    # generation backend -- confirm where the controller consumes them.
    max_new_tokens: int = 512
    do_sample: bool = True
    temperature: float = 1.0
    repetition_penalty: float = 1.0
    top_k: int = 50
    top_p: float = 0.95
# FastAPI application object; the route and event decorators in this module
# register their handlers on it.
app = FastAPI()
# Module-wide configuration, instantiated once when the module is imported.
settings = ControllerConfig()
# Logger used by the request handlers in this module.
logger = get_logger("spitfight.colosseum.controller.router")
@app.on_event("startup")
async def startup_event():
    """Initialize the queued root loggers, then the global controller."""
    # (logger name, log file name) pairs; every file lives under `log_dir`.
    log_targets = (
        ("uvicorn", settings.uvicorn_log_file),
        ("spitfight.colosseum.controller", settings.controller_log_file),
        ("colosseum_requests", settings.request_log_file),
    )
    for logger_name, file_name in log_targets:
        log_path = os.path.join(settings.log_dir, file_name)
        init_queued_root_logger(logger_name, log_path)

    # The controller is created only after logging has been set up.
    init_global_controller(settings)
@app.on_event("shutdown")
async def shutdown_event():
    """Shut down the global controller first, then the queued root loggers."""
    controller = get_global_controller()
    controller.shutdown()
    shutdown_queued_root_loggers()
@app.get(COLOSSEUM_MODELS_ROUTE, response_model=ModelsResponse)
async def models(controller: Controller = Depends(get_global_controller)):
    """Return the models that the controller reports as available."""
    available = controller.get_available_models()
    return ModelsResponse(available_models=available)
@app.post(COLOSSEUM_PROMPT_ROUTE)
async def prompt(
    request: PromptRequest,
    controller: Controller = Depends(get_global_controller),
):
    """Stream a model's response to the prompt carried by ``request``.

    The first item of the controller's async generator is awaited eagerly,
    before any response is returned, so that errors raised up front can be
    reported as HTTP errors rather than surfacing mid-stream.

    Args:
        request: Request body; its ``request_id``, ``prompt``, ``model_index``
            and ``model_preference`` are passed on to the controller.
        controller: The global controller, injected by FastAPI.

    Returns:
        A ``StreamingResponse`` that yields the first item followed by the
        rest of the generator. If the generator is exhausted immediately, the
        stream is a single JSON-encoded placeholder message terminated by a
        NUL byte.

    Raises:
        HTTPException: 429 on ``OverloadedError``, 422 on ``ValidationError``,
            500 on ``UnknownError``.
    """
    generator = controller.prompt(
        request.request_id,
        request.prompt,
        request.model_index,
        request.model_preference,
    )

    # First try to get the first token in order to catch TGI errors.
    try:
        first_token = await generator.__anext__()
    except OverloadedError as e:
        # Look up the model's name only for the log message.
        name = controller.request_states[request.request_id].model_names[request.model_index]
        logger.warning("Model %s is overloaded. Failed request: %s", name, repr(request))
        raise HTTPException(status_code=429, detail="Model overloaded. Please try again later.") from e
    except ValidationError as e:
        logger.info("TGI returned validation error: %s. Failed request: %s", str(e), repr(request))
        raise HTTPException(status_code=422, detail=str(e)) from e
    except StopAsyncIteration:
        # The generator finished without yielding anything: answer with a
        # placeholder message instead of an empty stream.
        logger.info("TGI returned empty response. Failed request: %s", repr(request))
        return StreamingResponse(
            iter([json.dumps("*The model generated an empty response.*").encode() + b"\0"]),
        )
    except UnknownError as e:
        logger.error("TGI returned unknown error: %s. Failed request: %s", str(e), repr(request))
        raise HTTPException(status_code=500, detail=str(e)) from e

    # Put the already-consumed first item back in front of the stream.
    return StreamingResponse(prepend_generator(first_token, generator))
@app.post(COLOSSEUM_RESP_VOTE_ROUTE, response_model=ResponseVoteResponse)
async def response_vote(
    request: ResponseVoteRequest,
    controller: Controller = Depends(get_global_controller),
):
    """Forward a response vote to the controller and report the outcome.

    Returns the energy consumptions and model names from the state the
    controller hands back. Responds with HTTP 410 when the controller
    returns ``None`` for the given request ID.
    """
    state = controller.response_vote(request.request_id, request.victory_index)
    if state is None:
        raise HTTPException(status_code=410, detail="Colosseum battle session timeout expired.")

    consumptions = state.energy_consumptions
    names = state.model_names
    return ResponseVoteResponse(energy_consumptions=consumptions, model_names=names)
@app.post(COLOSSEUM_ENERGY_VOTE_ROUTE, response_model=EnergyVoteResponse)
async def energy_vote(
    request: EnergyVoteRequest,
    controller: Controller = Depends(get_global_controller),
):
    """Forward an energy vote to the controller and return the model names.

    Responds with HTTP 410 when the controller returns ``None`` for the
    given request ID.
    """
    state = controller.energy_vote(request.request_id, request.is_worth)
    if state is None:
        raise HTTPException(status_code=410, detail="Colosseum battle session timeout expired.")

    names = state.model_names
    return EnergyVoteResponse(model_names=names)
@app.get(COLOSSEUM_HEALTH_ROUTE)
async def health():
    """Health-check endpoint; always answers with the string "OK"."""
    status = "OK"
    return status
if __name__ == "__main__":
    # Serve the app directly when this file is executed as a script.
    # `log_config=None` is passed through to uvicorn unchanged; a "uvicorn"
    # logger is set up separately in `startup_event`.
    server_options = {"host": "0.0.0.0", "log_config": None}
    uvicorn.run(app, **server_options)