# Source: Hugging Face Hub file view of app.py (raw / history / blame), 1.66 kB
# Last commit: "Update app.py" by PolyakovK (1ea9c5f, verified)
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch
import numpy as np
# FastAPI application object; the /classify route in this file is registered on it.
app = FastAPI()
# Load the tokenizer and the model once, at import time, so that every request
# reuses the same objects instead of reloading them.
# NOTE(review): "xlm-roberta-base" looks like a base (not fine-tuned) checkpoint,
# in which case the 3-way classification head requested via num_labels=3 would be
# newly initialised rather than trained — confirm which checkpoint is intended.
MODEL_NAME = "xlm-roberta-base"
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
# num_labels=3 matches the three entries of the LABELS mapping in this file.
model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME, num_labels=3)
class EmailRequest(BaseModel):
    """Request body accepted by the ``/classify`` endpoint."""

    # Text to classify; it is passed to the tokenizer as-is.
    text: str
class EmailResponse(BaseModel):
    """Response body returned by the ``/classify`` endpoint."""

    # Category number; the /classify handler sets it to the predicted class
    # index plus one.
    category: int

    # Softmax probability the model assigned to the predicted class.
    confidence: float
# Human-readable (Russian) description of each class, keyed by 0-based class index.
# NOTE(review): the /classify handler returns the class index plus one, so a
# response category N presumably corresponds to LABELS[N - 1] — confirm with
# the API consumers. This mapping is not referenced elsewhere in the visible code.
LABELS = {
    # "Client wants to schedule a meeting"
    0: "Клиент хочет назначить встречу",
    # "Client is not interested / has no time / declines"
    1: "Клиент не заинтересован / нет времени / отказывается",
    # "Client asks clarifying questions"
    2: "Клиент задаёт уточняющие вопросы"
}
@app.post("/classify", response_model=EmailResponse)
async def classify_email(request: EmailRequest):
    """Classify the text of an email with the sequence-classification model.

    Args:
        request: Request body; the text to classify is read from
            ``request.text``.

    Returns:
        An ``EmailResponse`` whose ``category`` is the predicted class index
        plus one (i.e. 1-based) and whose ``confidence`` is the softmax
        probability of that class.

    Raises:
        HTTPException: With status code 500 and the error message as
            ``detail`` if tokenization, inference, or building the response
            raises any exception.
    """
    try:
        # Inputs longer than 512 tokens are truncated rather than rejected.
        inputs = tokenizer(request.text, return_tensors="pt", truncation=True, max_length=512)

        # Inference only, so gradient tracking is switched off.
        with torch.no_grad():
            outputs = model(**inputs)
            # Exactly one text is tokenized per request, so row 0 is the only
            # row. Selecting it before argmax keeps the index a class index
            # instead of an index into the flattened (batch, classes) tensor.
            probabilities = torch.nn.functional.softmax(outputs.logits, dim=-1)[0]

        predicted_class = torch.argmax(probabilities).item()
        confidence = probabilities[predicted_class].item()

        # The API reports categories 1-based, hence the +1.
        return EmailResponse(category=predicted_class + 1, confidence=confidence)
    except Exception as e:
        # Chain the original exception so server-side tracebacks keep the
        # root cause; the client-visible status and detail are unchanged.
        raise HTTPException(status_code=500, detail=str(e)) from e
if __name__ == "__main__":
    # uvicorn is imported here, inside the guard, so it is only loaded when
    # this file is executed directly as a script.
    import uvicorn

    # Serve the app on all network interfaces, port 8000.
    uvicorn.run(app, host="0.0.0.0", port=8000)