File size: 1,951 Bytes
f228a1c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14c8502
 
 
 
 
 
 
 
 
 
 
 
 
f228a1c
 
 
 
 
14c8502
 
 
 
 
 
 
f228a1c
 
 
 
bc21776
f228a1c
 
 
14c8502
 
 
 
f228a1c
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
import logging

import torch
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModelForSequenceClassification, pipeline


class Guardrail:
    """Prompt-injection detector wrapping a Hugging Face text-classification pipeline.

    The tokenizer and model are loaded once in ``__init__`` and the resulting
    pipeline is reused by every call to :meth:`guard`.
    """

    # Model used when no ``model_name`` is given (the original hard-coded id).
    DEFAULT_MODEL = "ProtectAI/deberta-v3-base-prompt-injection"

    def __init__(self, model_name: str = DEFAULT_MODEL, max_length: int = 512):
        """Load *model_name* and build the classification pipeline.

        Args:
            model_name: Hugging Face model id (or local path) passed to both
                ``AutoTokenizer.from_pretrained`` and
                ``AutoModelForSequenceClassification.from_pretrained``.
            max_length: Maximum number of tokens fed to the model; longer
                inputs are truncated.
        """
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        model = AutoModelForSequenceClassification.from_pretrained(model_name)

        self.classifier = pipeline(
            "text-classification",
            model=model,
            tokenizer=tokenizer,
            # Long prompts are truncated to max_length tokens so they can still
            # be scored. NOTE(review): tokens beyond the limit are never seen by
            # the model, so injected text placed past it would go undetected.
            truncation=True,
            max_length=max_length,
            # Prefer the GPU when one is visible to torch; otherwise run on CPU.
            device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
        )

    def guard(self, prompt):
        """Run the classifier on *prompt* and return the raw pipeline output.

        For a single string the caller reads ``result[0]["label"]`` and
        ``result[0]["score"]``, i.e. a list holding one label/score dict.
        """
        return self.classifier(prompt)

    def determine_level(self, label, score):
        """Map a classifier result to ``(level, severity_label)``.

        A ``"SAFE"`` label always yields ``(0, "safe")``. Any other label is
        treated as a detection and graded by *score* using strict ``>``:

        - score > 0.9   -> (4, "high")
        - score > 0.75  -> (3, "medium")
        - score > 0.5   -> (2, "low")
        - otherwise     -> (1, "very low")
        """
        if label == "SAFE":
            return 0, "safe"
        if score > 0.9:
            return 4, "high"
        if score > 0.75:
            return 3, "medium"
        if score > 0.5:
            return 2, "low"
        return 1, "very low"


class TextPrompt(BaseModel):
    """Request body for the prompt-injection classify endpoint."""

    # Raw text to be screened by the guardrail classifier.
    prompt: str


class ClassificationResult(BaseModel):
    """Response body of the prompt-injection classify endpoint."""

    # Top label from the classifier; "SAFE" means no injection was flagged.
    label: str
    # Classifier score for ``label`` (presumably a probability in [0, 1] — confirm).
    score: float
    # Numeric severity from Guardrail.determine_level: 0 (safe) up to 4 (high).
    level: int
    # Severity name: "safe", "very low", "low", "medium" or "high".
    severity_label: str


# ASGI application that serves the classification route defined below.
app = FastAPI()
# Built at import time so the tokenizer/model (loaded via from_pretrained in
# Guardrail.__init__) are created once and shared by every request. Importing
# this module therefore pays the model-load cost (possibly a download).
guardrail = Guardrail()


@app.post("/api/models/PromptInjection/classify", response_model=ClassificationResult)
def classify_text(text_prompt: TextPrompt):
    """Classify ``text_prompt.prompt`` for prompt injection.

    Returns the top label and score from the guardrail pipeline together with
    the ``level`` / ``severity_label`` computed by ``Guardrail.determine_level``.

    Raises:
        HTTPException: status 500 if classification fails. The underlying
            error is logged server-side with its traceback. It is deliberately
            not echoed to the client, because raw exception text can leak
            internal details (file paths, model/runtime configuration).
    """
    try:
        result = guardrail.guard(text_prompt.prompt)
        label = result[0]["label"]
        score = result[0]["score"]
        level, severity_label = guardrail.determine_level(label, score)
    except Exception as e:
        # Request boundary: record the full failure for operators, return a
        # generic 500 to the caller, and keep the original error as the cause.
        logging.getLogger(__name__).exception("Prompt-injection classification failed")
        raise HTTPException(status_code=500, detail="Classification failed") from e
    return {"label": label, "score": score, "level": level, "severity_label": severity_label}


if __name__ == "__main__":
    # Script entry point: serve the app directly with uvicorn.
    # Defaults match the original 0.0.0.0:8000 (all interfaces). Override with
    # GUARDRAIL_HOST / GUARDRAIL_PORT instead of editing code; the prefixed
    # names avoid picking up shell-level HOST/PORT variables by accident.
    import os

    import uvicorn

    host = os.environ.get("GUARDRAIL_HOST", "0.0.0.0")
    port = int(os.environ.get("GUARDRAIL_PORT", "8000"))
    uvicorn.run(app, host=host, port=port)