import os
import secrets
from pathlib import Path

import uvicorn
from fastapi import FastAPI, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from loguru import logger

from iopaint.api import Api
from iopaint.const import DEFAULT_MODEL_DIR
from iopaint.runtime import setup_model_dir, check_device, dump_environment_info
from iopaint.schema import ApiConfig, Device

def _parse_origins(raw: str) -> list:
    """Split a comma-separated origin list into clean entries.

    Whitespace around each entry is stripped and empty entries are dropped,
    so ``"https://a.example, https://b.example"`` yields two exact origins
    instead of one with a leading space that could never equal a browser's
    ``Origin`` header.

    Args:
        raw: The raw comma-separated value (e.g. from ``ALLOWED_ORIGINS``).

    Returns:
        The list of non-empty, stripped origins, in their original order.
        May be empty if *raw* contains no usable entry.
    """
    return [origin.strip() for origin in raw.split(",") if origin.strip()]


# --- Settings read from the environment (defaults in the second argument) ---

# Address and port the HTTP server binds to.
host = os.environ.get("IOPAINT_HOST", "0.0.0.0")
# int() raises ValueError at startup if IOPAINT_PORT is not an integer.
port = int(os.environ.get("IOPAINT_PORT", "7860"))

# Model name handed to ApiConfig.
model = os.environ.get("IOPAINT_MODEL", "lama")

# Directory for model files; turned into a Path and created further down.
model_dir_str = os.environ.get("IOPAINT_MODEL_DIR", "/app/models")

# Device name; converted to a Device further down.
device_str = os.environ.get("IOPAINT_DEVICE", "cpu")

# Optional shared secret; when unset (None) no API-key check is installed.
api_key = os.environ.get("IOPAINT_API_KEY", None)

# Origins allowed by CORS; "*" (the default) means any origin.
allowed_origins = _parse_origins(os.environ.get("ALLOWED_ORIGINS", "*"))

def _ensure_model_dir(preferred: Path, fallback: Path) -> Path:
    """Create the model directory, falling back to a second location.

    Args:
        preferred: Directory to create first (parents included).
        fallback: Directory to create instead when *preferred* cannot be
            created.

    Returns:
        The directory that was successfully created (or already existed).

    Raises:
        OSError: If the fallback directory cannot be created either.
    """
    try:
        preferred.mkdir(parents=True, exist_ok=True)
    except OSError as e:
        # mkdir only fails with OSError subclasses (permissions, read-only
        # file system, a file in the way, ...); anything else is a real bug
        # and should not be silently redirected to the fallback.
        logger.error(f"Failed to create model directory: {e}")
        fallback.mkdir(parents=True, exist_ok=True)
        logger.info(f"Using alternative model directory: {fallback}")
        return fallback
    logger.info(f"Successfully created model directory: {preferred}")
    return preferred


# NOTE(review): setup_model_dir is imported at the top of this file but is
# never called in the code shown here - confirm whether the directory should
# also be passed through it.
model_dir = _ensure_model_dir(Path(model_dir_str), Path("/tmp/iopaint/models"))

# Build a Device from the configured string and pass it through check_device;
# whatever check_device returns is the device used from here on.
_requested_device = Device(device_str)
device = check_device(_requested_device)

dump_environment_info()

# Startup summary of the effective settings.
for _startup_line in (
    f"Starting API server with model: {model} on device: {device}",
    f"Model directory: {model_dir}",
    f"Allowed origins: {allowed_origins}",
):
    logger.info(_startup_line)

def _env_flag(name: str, default: str) -> bool:
    """Return True when env var *name* equals "true", ignoring case.

    *default* is the string used when the variable is unset; any value other
    than "true" (in any letter case) counts as False.
    """
    return os.environ.get(name, default).lower() == "true"


# The ASGI application that the middleware and the Api below attach to.
app = FastAPI(title="IOPaint API")

# NOTE(review): the keyword names below are passed straight to ApiConfig;
# confirm them against the installed iopaint.schema.ApiConfig definition.
config = ApiConfig(
    host=host,
    port=port,
    model=model,
    device=device,
    model_dir=model_dir,
    input=None,
    output_dir=None,
    low_mem=_env_flag("IOPAINT_LOW_MEM", "true"),
    no_half=_env_flag("IOPAINT_NO_HALF", "false"),
    cpu_offload=_env_flag("IOPAINT_CPU_OFFLOAD", "false"),
    disable_nsfw=_env_flag("IOPAINT_DISABLE_NSFW", "false"),
)

# --- HTTP middleware ---------------------------------------------------------
# Each middleware added to a FastAPI/Starlette app wraps everything added
# before it, so the one registered LAST runs FIRST. The API-key check is
# registered before CORS on purpose: CORSMiddleware ends up outermost, which
# means the 401 produced below still receives CORS headers and a browser
# client sees "401 Invalid API key" instead of an opaque CORS failure.

if api_key:
    # Encode once. compare_digest on bytes accepts any header content, whereas
    # the str form raises TypeError on non-ASCII input.
    _expected_api_key = api_key.encode("utf-8")

    @app.middleware("http")
    async def api_key_validation(request: Request, call_next):
        """Reject requests whose ``X-API-Key`` header does not match the key.

        ``OPTIONS`` requests are forwarded without a key check. For any other
        method, a missing, empty or wrong ``X-API-Key`` header yields a 401
        JSON response ``{"detail": "Invalid API key"}``; otherwise the request
        is forwarded and the downstream response is returned unchanged.
        """
        if request.method == "OPTIONS":
            return await call_next(request)

        req_api_key = request.headers.get("X-API-Key")
        # Constant-time comparison so response timing does not leak how much
        # of a guessed key was correct.
        if not req_api_key or not secrets.compare_digest(
            req_api_key.encode("utf-8"), _expected_api_key
        ):
            return JSONResponse(
                status_code=401,
                content={"detail": "Invalid API key"},
            )
        return await call_next(request)


# A wildcard origin must not be combined with credentialed CORS: browsers
# reject "Access-Control-Allow-Origin: *" on credentialed requests, and
# echoing arbitrary origins with credentials enabled is unsafe. Credentials
# are therefore only enabled when every allowed origin is explicit.
_wildcard_origin = "*" in allowed_origins
if _wildcard_origin:
    logger.warning(
        "ALLOWED_ORIGINS contains '*'; CORS credentials are disabled. "
        "List explicit origins to allow credentialed requests."
    )

cors_options = {
    "allow_methods": ["*"],
    "allow_headers": ["*", "X-API-Key"],
    "allow_origins": allowed_origins,
    "allow_credentials": not _wildcard_origin,
}
app.add_middleware(CORSMiddleware, **cors_options)

# Construct the IOPaint Api around the app and keep a module-level reference.
# NOTE(review): presumably this is what registers the routes on `app` -
# confirm against iopaint.api.Api.
api = Api(app, config)


def main() -> None:
    """Serve ``app`` with uvicorn on the configured host and port."""
    uvicorn.run(app, host=host, port=port)


if __name__ == "__main__":
    main()