# Code-security detection API (Hugging Face Space).
# (Removed page-banner residue "Spaces: / Running / Running" left over from a web copy-paste.)
from fastapi import FastAPI
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch
import os
import logging

# Initialize logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("CodeSecurityAPI")

# Force the HF cache path (works around permission problems in the container)
os.environ["HF_HOME"] = "/app/.cache/huggingface"

# Single source of truth for the model id (was duplicated in both calls below)
MODEL_NAME = "mrm8488/codebert-base-finetuned-detect-insecure-code"

# Load model and tokenizer once at module import (startup)
try:
    logger.info("Loading model...")
    model = AutoModelForSequenceClassification.from_pretrained(
        MODEL_NAME,
        cache_dir=os.getenv("HF_HOME"),
    )
    tokenizer = AutoTokenizer.from_pretrained(
        MODEL_NAME,
        cache_dir=os.getenv("HF_HOME"),
    )
    logger.info("Model loaded successfully")
except Exception as e:
    logger.error(f"Model load failed: {str(e)}")
    # Chain the original cause; user-facing message kept verbatim.
    raise RuntimeError("模型加载失败,请检查网络连接或模型路径") from e

app = FastAPI()
async def detect(code: str):
    """Classify a code snippet as secure/insecure with the CodeBERT model.

    Returns ``{"label": ..., "score": ...}`` on success, or
    ``{"error": <message>}`` on any failure (best-effort endpoint).
    """
    # NOTE(review): no @app.post(...) route decorator is visible on this
    # function — it may have been lost in the copy-paste. Confirm the route
    # is actually registered with the FastAPI app.
    try:
        # Cap raw input size; the tokenizer truncates again to 512 tokens.
        code = code[:2000]
        inputs = tokenizer(code, return_tensors="pt", truncation=True, max_length=512)
        with torch.no_grad():
            outputs = model(**inputs)
        # Compute the probability distribution once; argmax over probabilities
        # equals argmax over logits (softmax is monotonic).
        probs = outputs.logits.softmax(dim=-1)[0]
        label_id = int(probs.argmax().item())
        return {
            "label": model.config.id2label[label_id],
            "score": probs[label_id].item(),
        }
    except Exception as e:
        # Best-effort API: report the error in the payload instead of a 500.
        return {"error": str(e)}
async def health():
    """Liveness probe: always reports the service as up."""
    payload = {"status": "ok"}
    return payload