File size: 2,780 Bytes
f228a1c
 
 
 
99351b6
 
f228a1c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14c8502
 
 
 
 
 
 
 
 
 
 
 
 
f228a1c
 
 
 
 
14c8502
 
 
 
 
 
 
f228a1c
 
99351b6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f228a1c
 
bc21776
f228a1c
 
 
14c8502
 
 
 
f228a1c
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModelForSequenceClassification, pipeline
import torch
from detoxify import Detoxify



class Guardrail:
    def __init__(self):
        tokenizer = AutoTokenizer.from_pretrained("ProtectAI/deberta-v3-base-prompt-injection")
        model = AutoModelForSequenceClassification.from_pretrained("ProtectAI/deberta-v3-base-prompt-injection")

        self.classifier = pipeline(
            "text-classification",
            model=model,
            tokenizer=tokenizer,
            truncation=True,
            max_length=512,
            device=torch.device("cuda" if torch.cuda.is_available() else "cpu")
        )

    def guard(self, prompt):
        return self.classifier(prompt)

    def determine_level(self, label, score):
        if label == "SAFE":
            return 0, "safe"
        else:
            if score > 0.9:
                return 4, "high"
            elif score > 0.75:
                return 3, "medium"
            elif score > 0.5:
                return 2, "low"
            else:
                return 1, "very low"


class TextPrompt(BaseModel):
    prompt: str


class ClassificationResult(BaseModel):
    label: str
    score: float
    level: int
    severity_label: str


app = FastAPI()
guardrail = Guardrail()
toxicity_classifier = Detoxify('original')

class ToxicityResult(BaseModel):
    toxicity: float
    severe_toxicity: float
    obscene: float
    threat: float
    insult: float
    identity_attack: float

@app.post("/api/models/toxicity/classify", response_model=ToxicityResult)
def classify_toxicity(text_prompt: TextPrompt):
    try:
        result = toxicity_classifier.predict(text_prompt.prompt)
        return {
            "toxicity": result['toxicity'],
            "severe_toxicity": result['severe_toxicity'],
            "obscene": result['obscene'],
            "threat": result['threat'],
            "insult": result['insult'],
            "identity_attack": result['identity_attack']
        }
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))


@app.post("/api/models/PromptInjection/classify", response_model=ClassificationResult)
def classify_text(text_prompt: TextPrompt):
    try:
        result = guardrail.guard(text_prompt.prompt)
        label = result[0]['label']
        score = result[0]['score']
        level, severity_label = guardrail.determine_level(label, score)
        return {"label": label, "score": score, "level": level, "severity_label": severity_label}
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))


if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)