Spaces:
Running
Running
from fastapi import FastAPI, HTTPException | |
from pydantic import BaseModel | |
from transformers import AutoTokenizer, AutoModelForSequenceClassification, pipeline | |
import torch | |
class Guardrail: | |
def __init__(self): | |
tokenizer = AutoTokenizer.from_pretrained("ProtectAI/deberta-v3-base-prompt-injection") | |
model = AutoModelForSequenceClassification.from_pretrained("ProtectAI/deberta-v3-base-prompt-injection") | |
self.classifier = pipeline( | |
"text-classification", | |
model=model, | |
tokenizer=tokenizer, | |
truncation=True, | |
max_length=512, | |
device=torch.device("cuda" if torch.cuda.is_available() else "cpu") | |
) | |
def guard(self, prompt): | |
return self.classifier(prompt) | |
class TextPrompt(BaseModel): | |
prompt: str | |
app = FastAPI() | |
guardrail = Guardrail() | |
def classify_text(text_prompt: TextPrompt): | |
try: | |
result = guardrail.guard(text_prompt.prompt) | |
return result | |
except Exception as e: | |
raise HTTPException(status_code=500, detail=str(e)) | |
if __name__ == "__main__": | |
import uvicorn | |
uvicorn.run(app, host="0.0.0.0", port=8000) | |