import os

# Set the Hugging Face cache path. This must happen before importing
# transformers, because the cache location is read at import time.
os.environ["HF_HOME"] = "/app/.cache/huggingface"
os.environ["TRANSFORMERS_CACHE"] = "/app/.cache/huggingface"

from fastapi import FastAPI
from transformers import AutoTokenizer, T5ForConditionalGeneration

app = FastAPI()

# Load the CodeT5 model and tokenizer at startup; fail fast if loading fails.
try:
    tokenizer = AutoTokenizer.from_pretrained("Salesforce/codet5-small")
    model = T5ForConditionalGeneration.from_pretrained("Salesforce/codet5-small")
except Exception as e:
    print(f"Model loading failed: {str(e)}")
    raise

@app.post("/analyze")
async def analyze_code(code: str):
    prompt = f"Analyze security vulnerabilities:\n{code}"
    
    inputs = tokenizer(prompt, return_tensors="pt", 
                      max_length=512, truncation=True)
    
    outputs = model.generate(
        inputs.input_ids,
        max_length=512,
        num_beams=5,
        early_stopping=True
    )
    
    return {
        "result": tokenizer.decode(outputs[0], skip_special_tokens=True)
    }

@app.get("/health")
def health_check():
    return {"status": "ok", "cache_path": os.environ["HF_HOME"]}