deberta_api / main.py
AISimplyExplained's picture
Rename app.py to main.py
8f80f5f verified
raw
history blame
1.2 kB
import logging

import torch
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModelForSequenceClassification, pipeline
class Guardrail:
    """Prompt-injection classifier backed by a Hugging Face sequence-classification model.

    Builds a ``transformers`` text-classification pipeline from a tokenizer
    and model loaded with ``from_pretrained`` and runs it on CUDA when torch
    reports a GPU, otherwise on CPU (unless a device is passed explicitly).
    """

    # Hub id loaded when no ``model_name`` is given.
    MODEL_NAME = "ProtectAI/deberta-v3-base-prompt-injection"
    # Token limit handed to the pipeline's tokenizer.
    MAX_LENGTH = 512

    def __init__(
        self,
        *,
        model_name: str = MODEL_NAME,
        max_length: int = MAX_LENGTH,
        device=None,
    ):
        """Load the tokenizer/model and build the classification pipeline.

        Args:
            model_name: Hugging Face model id or local path passed to
                ``from_pretrained`` for both tokenizer and model.
            max_length: Maximum number of tokens fed to the model; longer
                inputs are truncated.
            device: Value forwarded to ``pipeline(device=...)``. ``None``
                selects ``cuda`` if ``torch.cuda.is_available()`` else ``cpu``.
        """
        if device is None:
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        model = AutoModelForSequenceClassification.from_pretrained(model_name)
        self.classifier = pipeline(
            "text-classification",
            model=model,
            tokenizer=tokenizer,
            # NOTE(review): with truncation on, only the first `max_length`
            # tokens are classified; an injection placed after that point is
            # never seen. Consider chunking long prompts.
            truncation=True,
            max_length=max_length,
            device=device,
        )

    def guard(self, prompt: str):
        """Classify *prompt* and return the pipeline's raw output.

        For a single string the text-classification pipeline returns a list
        of ``{"label": ..., "score": ...}`` dicts.
        """
        return self.classifier(prompt)
class TextPrompt(BaseModel):
    """JSON request body accepted by the ``/classify/`` endpoint."""

    prompt: str  # text to be screened by the classifier
# FastAPI application object that the routes below are registered on.
app = FastAPI()

# One classifier shared by all requests; constructing it loads the model,
# so this happens once, when the module is imported.
guardrail = Guardrail()
@app.post("/classify/")
def classify_text(text_prompt: TextPrompt):
    """Run the guardrail classifier on the submitted prompt.

    Declared with plain ``def`` so FastAPI executes the blocking model call
    in its threadpool rather than on the event loop.

    Returns:
        The classifier output for ``text_prompt.prompt``.

    Raises:
        HTTPException: status 500 with the error message if classification fails.
    """
    try:
        result = guardrail.guard(text_prompt.prompt)
    except Exception as e:
        # Keep the full traceback in the server log; chaining preserves the cause.
        logging.getLogger(__name__).exception("Prompt classification failed")
        # NOTE(review): str(e) is sent to the client and may expose internals.
        raise HTTPException(status_code=500, detail=str(e)) from e
    return result
if __name__ == "__main__":
    # Direct-run entry point: serve `app` with uvicorn on all interfaces, port 8000.
    # uvicorn is imported here so it is only required when running this file.
    import uvicorn as _uvicorn

    _uvicorn.run(app, port=8000, host="0.0.0.0")