File size: 3,051 Bytes
4161fe6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
import os
import secrets
from pathlib import Path
from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse
from fastapi.middleware.cors import CORSMiddleware
import uvicorn
from loguru import logger

from iopaint.api import Api
from iopaint.schema import ApiConfig, Device
from iopaint.runtime import setup_model_dir, check_device, dump_environment_info
from iopaint.const import DEFAULT_MODEL_DIR

# Read server configuration from environment variables.
host = os.environ.get("IOPAINT_HOST", "0.0.0.0")
port_str = os.environ.get("IOPAINT_PORT", "7860")
try:
    port = int(port_str)
except ValueError as err:
    # Name the offending variable; a bare int() error does not say which
    # setting was malformed.
    raise ValueError(f"IOPAINT_PORT must be an integer, got {port_str!r}") from err
model = os.environ.get("IOPAINT_MODEL", "lama")

# Model directory: defaults to /app/models (a /tmp location is used as a
# fallback if it cannot be created).
model_dir_str = os.environ.get("IOPAINT_MODEL_DIR", "/app/models")

device_str = os.environ.get("IOPAINT_DEVICE", "cpu")
# An unset or empty IOPAINT_API_KEY disables API-key checking.
api_key = os.environ.get("IOPAINT_API_KEY", None)
# Comma-separated CORS origins. Surrounding whitespace is stripped and empty
# entries dropped, so "https://a.example, https://b.example" works as intended
# (an unstripped " https://b.example" would never match a request Origin).
allowed_origins = [
    origin.strip()
    for origin in os.environ.get("ALLOWED_ORIGINS", "*").split(",")
    if origin.strip()
]

# Initialise the model directory and runtime environment.
model_dir = Path(model_dir_str)
try:
    model_dir.mkdir(parents=True, exist_ok=True)
except OSError as e:
    # mkdir failures (e.g. permission denied, read-only filesystem) surface as
    # OSError subclasses; fall back to a directory under /tmp instead.
    logger.error(f"Failed to create model directory: {e}")
    model_dir = Path("/tmp/iopaint/models")
    model_dir.mkdir(parents=True, exist_ok=True)
    logger.info(f"Using alternative model directory: {model_dir}")
else:
    logger.info(f"Successfully created model directory: {model_dir}")

# Resolve the configured device string through check_device and dump
# environment information for diagnostics.
device = check_device(Device(device_str))
dump_environment_info()

logger.info(f"Starting API server with model: {model} on device: {device}")
logger.info(f"Model directory: {model_dir}")
logger.info(f"Allowed origins: {allowed_origins}")

def _env_flag(name: str, default: str) -> bool:
    """Return whether environment variable *name* is set to "true".

    The comparison is case-insensitive and *default* is used when the
    variable is unset. Any other value (including "1" or "yes") is False.
    """
    return os.environ.get(name, default).lower() == "true"


# Create the FastAPI application.
app = FastAPI(title="IOPaint API")

# Assemble the IOPaint API configuration. input and output_dir are left
# unset; the boolean options are read from IOPAINT_* environment variables.
config = ApiConfig(
    host=host,
    port=port,
    model=model,
    device=device,
    model_dir=model_dir,
    input=None,
    output_dir=None,
    low_mem=_env_flag("IOPAINT_LOW_MEM", "true"),
    no_half=_env_flag("IOPAINT_NO_HALF", "false"),
    cpu_offload=_env_flag("IOPAINT_CPU_OFFLOAD", "false"),
    disable_nsfw=_env_flag("IOPAINT_DISABLE_NSFW", "false"),
)

# CORS configuration: origins come from ALLOWED_ORIGINS, every method and
# header is permitted, and credentialed requests are enabled.
# NOTE(review): "*" in allow_headers already covers X-API-Key, so listing it
# is redundant (harmless). Combining a "*" origin with allow_credentials=True
# is disallowed by the CORS spec — confirm how the middleware handles that
# combination before deploying with the default "*" origin.
cors_options = dict(
    allow_origins=allowed_origins,
    allow_methods=["*"],
    allow_headers=["*", "X-API-Key"],
    allow_credentials=True,
)
app.add_middleware(CORSMiddleware, **cors_options)

# API-key authentication, registered only when IOPAINT_API_KEY is non-empty.
if api_key:
    @app.middleware("http")
    async def api_key_validation(request: Request, call_next):
        """Reject requests whose X-API-Key header does not match ``api_key``.

        CORS preflight (OPTIONS) requests are passed through unchecked because
        browsers send preflights without custom headers such as X-API-Key.
        Returns a 401 JSON response on a missing or wrong key; otherwise
        forwards the request to the next handler.
        """
        if request.method == "OPTIONS":
            return await call_next(request)

        req_api_key = request.headers.get("X-API-Key")
        # Constant-time comparison so response timing does not reveal how much
        # of the key matched. Compare bytes: compare_digest raises TypeError on
        # str arguments that contain non-ASCII characters.
        if not req_api_key or not secrets.compare_digest(
            req_api_key.encode("utf-8"), api_key.encode("utf-8")
        ):
            return JSONResponse(
                status_code=401,
                content={"detail": "Invalid API key"}
            )
        return await call_next(request)

# Construct the IOPaint Api with the FastAPI app and its configuration.
api = Api(app, config)


def _serve() -> None:
    """Serve ``app`` with uvicorn on the configured host and port."""
    uvicorn.run(app, host=host, port=port)


# Start the server only when this file is executed directly.
if __name__ == "__main__":
    _serve()