from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch

app = FastAPI()

# Load the tokenizer and the sequence-classification model (3 target classes)
MODEL_NAME = "xlm-roberta-base"
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME, num_labels=3)
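# NOTE (assumption): "xlm-roberta-base" is a base checkpoint, so the 3-label
# classification head created above is randomly initialized and will not give
# meaningful predictions until it is fine-tuned on the three email categories.
# A fine-tuned checkpoint would be loaded the same way; the repo id below is a
# placeholder, not a real model:
#
#     MODEL_NAME = "your-org/xlmr-email-intent"   # hypothetical fine-tuned repo
#     tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
#     model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME)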
class EmailRequest(BaseModel):
    text: str


class EmailResponse(BaseModel):
    category: int
    confidence: float


# Human-readable descriptions of the three email categories
LABELS = {
    0: "The client wants to schedule a meeting",
    1: "The client is not interested / has no time / declines",
    2: "The client asks clarifying questions",
}
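# LABELS maps the model's 0-based class indices to human-readable descriptions;
# the API response below reports categories as 1-3 (index + 1), so subtract 1
# before looking a returned category up in this dict.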
@app.post("/classify", response_model=EmailResponse)  # route path is assumed; not specified in the original
async def classify_email(request: EmailRequest):
    try:
        # Tokenize the input text
        inputs = tokenizer(request.text, return_tensors="pt", truncation=True, max_length=512)

        # Run the model without tracking gradients
        with torch.no_grad():
            outputs = model(**inputs)
            predictions = torch.nn.functional.softmax(outputs.logits, dim=-1)

        # Pick the most probable class and its probability
        predicted_class = torch.argmax(predictions, dim=-1).item()
        confidence = predictions[0][predicted_class].item()

        # Categories are reported 1-based (model index + 1)
        return EmailResponse(category=predicted_class + 1, confidence=confidence)
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)
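# Example client call once the server is running. This is a sketch: the
# /classify path and port 8000 match the assumptions used above; adjust them
# if the deployment differs.
#
#     import requests
#
#     resp = requests.post(
#         "http://localhost:8000/classify",
#         json={"text": "Can we schedule a call tomorrow at 3 pm?"},
#     )
#     print(resp.json())  # shape: {"category": <1-3>, "confidence": <float>}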