Spaces:
Running
Running
File size: 2,946 Bytes
1d117f2 c89e6e0 1d117f2 c89e6e0 b207b4c 1d117f2 c89e6e0 1d117f2 2b2ab5b a62b646 1d117f2 b207b4c af688eb 6780f80 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 |
from typing import Optional
import torch
import weave
from transformers import AutoModelForSequenceClassification, AutoTokenizer, pipeline
from transformers.pipelines.base import Pipeline
import wandb
from ..base import Guardrail
class PromptInjectionClassifierGuardrail(Guardrail):
"""
A guardrail that uses a pre-trained text-classification model to classify prompts
for potential injection attacks.
Args:
model_name (str): The name of the HuggingFace model or a WandB
checkpoint artifact path to use for classification.
"""
model_name: str = "ProtectAI/deberta-v3-base-prompt-injection-v2"
_classifier: Optional[Pipeline] = None
def model_post_init(self, __context):
if self.model_name.startswith("wandb://"):
api = wandb.Api()
artifact = api.artifact(self.model_name.removeprefix("wandb://"))
artifact_dir = artifact.download()
tokenizer = AutoTokenizer.from_pretrained(artifact_dir)
model = AutoModelForSequenceClassification.from_pretrained(artifact_dir)
else:
tokenizer = AutoTokenizer.from_pretrained(self.model_name)
model = AutoModelForSequenceClassification.from_pretrained(self.model_name)
self._classifier = pipeline(
"text-classification",
model=model,
tokenizer=tokenizer,
truncation=True,
max_length=512,
device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
)
@weave.op()
def classify(self, prompt: str):
return self._classifier(prompt)
@weave.op()
def guard(self, prompt: str):
"""
Analyzes the given prompt to determine if it is safe or potentially an injection attack.
This function uses a pre-trained text-classification model to classify the prompt.
It calls the `classify` method to get the classification result, which includes a label
and a confidence score. The function then calculates the confidence percentage and
returns a dictionary with two keys:
- "safe": A boolean indicating whether the prompt is safe (True) or an injection (False).
- "summary": A string summarizing the classification result, including the label and the
confidence percentage.
Args:
prompt (str): The input prompt to be classified.
Returns:
dict: A dictionary containing the safety status and a summary of the classification result.
"""
response = self.classify(prompt)
confidence_percentage = round(response[0]["score"] * 100, 2)
return {
"safe": response[0]["label"] != "INJECTION",
"summary": f"Prompt is deemed {response[0]['label']} with {confidence_percentage}% confidence.",
}
@weave.op()
def predict(self, prompt: str):
return self.guard(prompt)
|