import torch
from config import settings
from fastapi import FastAPI, Request, status
from fastapi.exceptions import RequestValidationError
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, RedirectResponse
from itr.router import init_model, init_vectordb, router

app = FastAPI(title="Text-to-image Retrieval API")

# Allow cross-origin requests from the origins and headers configured in settings
app.add_middleware(
    CORSMiddleware,
    allow_origins=settings.CORS_ORIGINS,
    allow_headers=settings.CORS_HEADERS,
    allow_credentials=True,
    allow_methods=["*"],
)


@app.exception_handler(RequestValidationError)
async def validation_exception_handler(request: Request, exc: RequestValidationError):
    # Flatten the original 'detail' list of errors into simple message/location pairs
    details = exc.errors()
    error_details = []
    for error in details:
        error_details.append({"error": f"{error['msg']} {str(error['loc'])}"})
    return JSONResponse(
        content={"message": error_details},
        status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
    )


@app.on_event("startup")
async def startup_event():
    # Load the vector index and keyframe/subframe metadata before serving requests
    init_vectordb(
        index_file_path=settings.INDEX_FILE_PATH,
        index_subframes_file_path=settings.INDEX_SUBFRAMES_FILE_PATH,
        keyframes_groups_json_path=settings.KEYFRAMES_GROUPS_JSON_PATH,
        subframes_groups_json_path=settings.SUBFRAMES_GROUPS_JSON_PATH,
    )
    # Fall back to CPU when CUDA is requested but unavailable
    device = (
        "cuda" if settings.DEVICE == "cuda" and torch.cuda.is_available() else "cpu"
    )
    init_model(model_name=settings.MODEL_NAME, device=device)
@app.get("/", include_in_schema=False)
async def root() -> None:
return RedirectResponse("/docs")
@app.get("/health", status_code=status.HTTP_200_OK, tags=["health"])
async def perform_healthcheck() -> None:
return JSONResponse(content={"message": "success"})


app.include_router(router)

# Start API
if __name__ == "__main__":
    import uvicorn

    uvicorn.run("main:app", host=settings.HOST, port=settings.PORT, reload=True)