# codebertBase / app.py
# Source: Hugging Face Space by Forrest99, commit eb9892c (verified), 1.8 kB.
# (Web-page chrome — "Update app.py", "raw", "history blame" — converted to
# this comment header so the file is valid Python.)
# Standard-library imports first (PEP 8 grouping).
import logging
import os

# Force the Hugging Face cache path BEFORE importing transformers, so the
# library sees it when it computes its default cache locations at import time.
# (The explicit cache_dir= passed at load time below also relies on this value.)
os.environ["HF_HOME"] = "/app/.cache/huggingface"

from fastapi import FastAPI
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch

# Application-wide logging configuration.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("CodeSecurityAPI")

app = FastAPI()
# === Root-path response ===
@app.get("/")
async def read_root():
    """Return a welcome message plus a map of the available endpoints."""
    endpoints = {
        "detect": "POST /detect",
        "health": "GET /health",
    }
    return {
        "message": "欢迎使用代码安全检测API",
        "endpoints": endpoints,
    }
# === Load the model (must come after the FastAPI instance) ===
try:
    logger.info("Loading model...")
    # Fine-tuned CodeBERT sequence classifier for insecure-code detection.
    model = AutoModelForSequenceClassification.from_pretrained(
        "mrm8488/codebert-base-finetuned-detect-insecure-code",
        cache_dir=os.getenv("HF_HOME"),
    )
    tokenizer = AutoTokenizer.from_pretrained(
        "mrm8488/codebert-base-finetuned-detect-insecure-code",
        cache_dir=os.getenv("HF_HOME"),
    )
    logger.info("Model loaded successfully")
except Exception as e:
    # logger.exception captures the full traceback; lazy %-args avoid
    # formatting work when the log level filters the record out.
    logger.exception("Model load failed: %s", e)
    # Chain the original cause so the startup failure is diagnosable.
    raise RuntimeError("模型加载失败,请检查网络连接或模型路径") from e
@app.post("/detect")
async def detect(code: str):
    """Classify a code snippet with the insecure-code detection model.

    Returns ``{"label": ..., "score": ...}`` on success, or
    ``{"error": ...}`` (HTTP 200) on failure — the lenient error contract
    existing callers depend on is preserved.

    NOTE(review): ``code`` is a bare ``str`` parameter, so FastAPI reads it
    from the query string, not the request body — confirm callers send it
    that way.
    """
    try:
        # Cap raw input length; the tokenizer additionally truncates to
        # 512 tokens below.
        code = code[:2000]
        inputs = tokenizer(code, return_tensors="pt", truncation=True, max_length=512)
        # Inference only — no gradient bookkeeping needed.
        with torch.no_grad():
            outputs = model(**inputs)
        # Explicit dim=-1: identical for a single-item batch, and correct
        # should batched inputs ever be passed.
        label_id = outputs.logits.argmax(dim=-1).item()
        return {
            "label": model.config.id2label[label_id],
            # Softmax probability of the predicted class.
            "score": outputs.logits.softmax(dim=-1)[0][label_id].item(),
        }
    except Exception as e:
        # Log the traceback server-side instead of silently swallowing it;
        # response shape is unchanged.
        logger.exception("Detection failed")
        return {"error": str(e)}
@app.get("/health")
async def health():
    """Liveness probe — always reports the service as up."""
    return {"status": "ok"}