File size: 1,745 Bytes
bac242d
fc986a8
 
a367787
b6af10b
bac242d
81e40a8
fc986a8
eb9892c
 
fc986a8
 
 
 
 
 
 
eb9892c
fc986a8
 
 
 
113ca35
fc986a8
bac242d
fc986a8
 
bac242d
fc986a8
b6af10b
 
113ca35
b6af10b
fc986a8
113ca35
b6af10b
bac242d
113ca35
 
fc986a8
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
import asyncio
import inspect
import logging
import os

import gradio as gr
import torch
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware  # CORS support
from transformers import AutoTokenizer, AutoModelForSequenceClassification

# === FastAPI initialization ===
app = FastAPI()

# Wide-open CORS policy: browsers may call this API from any origin, with any
# HTTP method and any request header. allow_credentials is not passed, so the
# middleware's default for it applies.
_CORS_OPTIONS = {
    "allow_origins": ["*"],
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_CORS_OPTIONS)

# === Model loading ===
# huggingface_hub resolves HF_HOME once, when it is first imported, and the
# `transformers` import at the top of this file has already pulled it in.
# Assigning os.environ["HF_HOME"] here therefore does not move the download
# cache by itself, so the cache directory is also passed explicitly below.
HF_HOME_DIR = "/app/.cache/huggingface"
os.environ["HF_HOME"] = HF_HOME_DIR  # still exported for anything that reads it later (e.g. child processes)
MODEL_ID = "mrm8488/codebert-base-finetuned-detect-insecure-code"
# Mirrors huggingface_hub's default layout, where the hub cache lives in $HF_HOME/hub.
MODEL_CACHE_DIR = os.path.join(HF_HOME_DIR, "hub")

model = AutoModelForSequenceClassification.from_pretrained(MODEL_ID, cache_dir=MODEL_CACHE_DIR)
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, cache_dir=MODEL_CACHE_DIR)
model.eval()  # inference only; stated explicitly rather than relying on from_pretrained's default

# === HTTP API ===
@app.post("/detect")
def api_detect(code: str):
    """Classify a code snippet with the CodeBERT insecure-code model.

    Declared as a plain ``def`` rather than ``async def`` for two reasons:

    * FastAPI runs sync endpoints in a worker thread, so the blocking
      tokenizer/model forward pass does not stall the event loop.
    * ``gradio_predict`` calls this function directly and needs a dict back,
      not an un-awaited coroutine.

    Args:
        code: Source text to classify. FastAPI binds a bare ``str`` parameter
            as a query parameter (``POST /detect?code=...``). Only the first
            2000 characters are tokenized, and the token sequence is further
            truncated to 512 tokens.

    Returns:
        ``{"label": int, "score": float}``. ``label`` is the arg-max class
        index (0/1 for this model; presumably 1 = insecure, so confirm against
        the model card) and ``score`` is its softmax probability.
        ``{"error": str}`` is returned instead if tokenization or inference
        raises; the traceback is logged.
    """
    try:
        inputs = tokenizer(code[:2000], return_tensors="pt", truncation=True, max_length=512)
        with torch.no_grad():
            # A single input string gives a batch of one: logits[0] is that row.
            logits = model(**inputs).logits
        label_id = int(logits[0].argmax().item())
        score = logits.softmax(dim=-1)[0, label_id].item()
        return {
            "label": label_id,  # plain int so it serializes as a 0/1 number
            "score": score,
        }
    except Exception as exc:
        # Top-level request boundary: keep the original {"error": ...}
        # response contract, but record the failure instead of hiding it.
        logging.getLogger(__name__).exception("Inference failed in /detect")
        return {"error": str(exc)}

# === Gradio UI (optional) ===
def gradio_predict(code: str):
    """Run the detector on *code* and format the outcome for a text output.

    Returns ``"Prediction: <label> (Confidence: <score>)"`` on success. If the
    detector reports a failure through its ``{"error": ...}`` dict, returns
    ``"Error: <message>"`` instead of raising a KeyError.
    """
    result = api_detect(code)
    # If api_detect is declared `async def`, calling it yields a coroutine
    # rather than a dict. Drive it to completion here. This assumes the sync
    # handler runs in a worker thread with no running event loop (how Gradio
    # is expected to run sync functions), so confirm if that changes.
    if inspect.iscoroutine(result):
        result = asyncio.run(result)
    if "error" in result:
        return f"Error: {result['error']}"
    return f"Prediction: {result['label']} (Confidence: {result['score']:.2f})"

# Multi-line input box for the code snippet to analyse.
_code_input = gr.Textbox(lines=10, placeholder="Paste code here...")

# Simple text-in / text-out UI wrapping gradio_predict.
gr_interface = gr.Interface(
    fn=gradio_predict,
    inputs=_code_input,
    outputs="text",
    title="Code Security Detector",
)

# Mount the Gradio UI at the site root. POST /detect was registered on `app`
# before this mount, so it should still be matched first (Starlette routes in
# registration order; confirm if routes are added later).
app = gr.mount_gradio_app(app, gr_interface, path="/")