Forrest99 commited on
Commit
b6af10b
·
verified ·
1 Parent(s): 81e40a8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +26 -32
app.py CHANGED
@@ -1,43 +1,37 @@
1
  from fastapi import FastAPI
2
- from pydantic import BaseModel
3
- from transformers import pipeline
4
  import os
5
 
6
- import os
7
- from pathlib import Path
8
 
 
 
9
 
10
- # 验证缓存目录可写
11
- cache_dir = Path(os.getenv("HF_HOME", ""))
12
- if not cache_dir.exists():
13
- cache_dir.mkdir(parents=True, exist_ok=True)
14
-
15
- test_file = cache_dir / "permission_test.txt"
16
  try:
17
- with open(test_file, "w") as f:
18
- f.write("test")
19
- os.remove(test_file)
20
- print("✅ Cache directory is writable")
21
  except Exception as e:
22
- print(f" Cache directory write failed: {str(e)}")
23
- raise
24
-
25
- # 正确加载模型(从缓存或下载)
26
- classifier = pipeline(
27
- "text-classification",
28
- model="mrm8488/codebert-base-finetuned-detect-insecure-code"
29
- )
30
-
31
- app = FastAPI()
32
-
33
- class CodeRequest(BaseModel):
34
- code: str # 输入参数定义
35
 
 
36
  @app.post("/detect")
37
- async def detect_insecure_code(request: CodeRequest):
38
  try:
39
- # 直接传递代码字符串到分类器
40
- result = classifier(request.code)
41
- return {"status": "success", "result": result[0]}
 
 
 
 
 
 
 
 
 
 
42
  except Exception as e:
43
- return {"status": "error", "message": str(e)}
 
1
  from fastapi import FastAPI
2
+ from puggingface import AutoTokenizer, AutoModelForSequenceClassification
3
+ import torch
4
  import os
5
 
6
+ # 1. 基础配置
7
+ app = FastAPI()
8
 
9
+ # 2. 强制设置缓存路径(解决权限问题)
10
+ os.environ["HF_HOME"] = "/app/.cache/huggingface"
11
 
12
+ # 3. 加载模型(自动缓存到指定路径)
 
 
 
 
 
13
  try:
14
+ model = AutoModelForSequenceClassification.from_pretrained("mrm8488/codebert-base-finetuned-detect-insecure-code")
15
+ tokenizer = AutoTokenizer.from_pretrained("mrm8488/codebert-base-finetuned-detect-insecure-code")
 
 
16
  except Exception as e:
17
+ raise RuntimeError(f"模型加载失败: {str(e)}")
 
 
 
 
 
 
 
 
 
 
 
 
18
 
19
+ # 4. 接口定义
20
  @app.post("/detect")
21
+ async def detect(code: str):
22
  try:
23
+ # 简单处理超长输入
24
+ if len(code) > 2000:
25
+ code = code[:2000]
26
+
27
+ inputs = tokenizer(code, return_tensors="pt", truncation=True)
28
+ with torch.no_grad():
29
+ outputs = model(**inputs)
30
+
31
+ return {
32
+ "label": model.config.id2label[outputs.logits.argmax().item()],
33
+ "score": outputs.logits.softmax(dim=-1).max().item()
34
+ }
35
+
36
  except Exception as e:
37
+ return {"error": str(e)}