codebertBase / app.py
Forrest99's picture
Update app.py
113ca35 verified
raw
history blame
1.66 kB
from fastapi import FastAPI
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch
import os
import logging
# Initialise logging for the service.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("CodeSecurityAPI")

# Force the Hugging Face cache path (works around permission problems).
# NOTE(review): this assignment runs after `transformers` has been imported,
# so the library may already have resolved its default cache location; the
# explicit `cache_dir=` arguments below are what pin the path — confirm.
os.environ["HF_HOME"] = "/app/.cache/huggingface"

# Hub identifier shared by the model and its tokenizer, kept in one place so
# the two can never drift apart.
MODEL_NAME = "mrm8488/codebert-base-finetuned-detect-insecure-code"

# Load the model and tokenizer once, at import time, so that every request
# handled by this process reuses the same objects.
try:
    logger.info("Loading model...")
    model = AutoModelForSequenceClassification.from_pretrained(
        MODEL_NAME,
        cache_dir=os.getenv("HF_HOME"),
    )
    tokenizer = AutoTokenizer.from_pretrained(
        MODEL_NAME,
        cache_dir=os.getenv("HF_HOME"),
    )
    logger.info("Model loaded successfully")
except Exception as e:
    # logger.exception records the full traceback, not just the message.
    logger.exception("Model load failed: %s", e)
    # Chain the original error so the root cause stays visible to whoever
    # reads the startup failure.
    raise RuntimeError("模型加载失败,请检查网络连接或模型路径") from e

app = FastAPI()
# Inputs are cut to this many characters before tokenisation.
_MAX_CODE_CHARS = 2000
# Upper bound on the number of tokens handed to the model.
_MAX_TOKENS = 512


@app.post("/detect")
async def detect(code: str):
    """Classify a code snippet with the loaded sequence-classification model.

    Args:
        code: Source code to analyse. Only the first 2000 characters are
            used, and the tokenised input is truncated to 512 tokens.

    Returns:
        On success, a dict with ``"label"`` (the model's name for the
        highest-scoring class) and ``"score"`` (the softmax probability of
        that class). On any failure, a dict with a single ``"error"`` key
        holding the exception message.
    """
    # NOTE(review): inference is synchronous inside an ``async def`` handler,
    # so it presumably blocks the event loop while the model runs — confirm
    # whether that matters for the expected load.
    try:
        # Input handling: truncate over-long input.
        code = code[:_MAX_CODE_CHARS]

        # Model inference.
        inputs = tokenizer(
            code,
            return_tensors="pt",
            truncation=True,
            max_length=_MAX_TOKENS,
        )
        with torch.no_grad():
            outputs = model(**inputs)

        # Parse the result. A single snippet is tokenised per request, so
        # take the first (only) row and pick the class along the last axis
        # rather than running argmax over the flattened tensor.
        logits = outputs.logits[0]
        label_id = int(logits.argmax(dim=-1))
        score = logits.softmax(dim=-1)[label_id].item()
        return {
            "label": model.config.id2label[label_id],
            "score": score,
        }
    except Exception as e:
        # Keep the existing error payload for clients, but also record the
        # traceback server-side so failures are not silently swallowed.
        logger.exception("Detection failed")
        return {"error": str(e)}
@app.get("/health")
async def health():
    # Health-check endpoint: always answers with a fixed "ok" status payload
    # and does not touch the model or tokenizer.
    return dict(status="ok")