# bias-detector/api.py
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch

app = FastAPI()
# Allow your frontend (adjust if deployed elsewhere)
origins = [
    "http://localhost:3000",  # Next.js frontend
]
app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Load the trained model (swap in your own checkpoint if needed)
MODEL_PATH = "./bert-bias-detector/checkpoint-4894"  # or wherever your checkpoint is saved
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
model = AutoModelForSequenceClassification.from_pretrained(MODEL_PATH)
# Move model to GPU if available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
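model.eval()  # inference mode; from_pretrained already loads in eval mode, but being explicit is harmless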
# Input schema for the /predict endpoint
class InputText(BaseModel):
    text: str
@app.post("/predict")
async def predict_text(payload: InputText):
    # Tokenize and move the tensors to the same device as the model
    inputs = tokenizer(
        payload.text,
        return_tensors="pt",
        truncation=True,
        padding=True,
        max_length=512,
    ).to(device)

    with torch.no_grad():
        outputs = model(**inputs)
        logits = outputs.logits
        probs = logits.softmax(dim=-1)[0].tolist()

    # Label order must match the model's id2label mapping
    labels = ["Left", "Center", "Right"]
    predicted_label = labels[torch.argmax(logits).item()]

    return {
        "bias_scores": probs,
        "predicted": predicted_label,
    }
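
# --- Example usage (a sketch; assumes this file is named api.py and the server runs locally) ---
# Start the server:
#   uvicorn api:app --host 0.0.0.0 --port 8000
# Then query it, e.g. with the requests library:
#   import requests
#   resp = requests.post(
#       "http://localhost:8000/predict",
#       json={"text": "The senator's plan drew sharp criticism from both parties."},
#   )
#   print(resp.json())  # e.g. {"bias_scores": [p_left, p_center, p_right], "predicted": "Center"}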