from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch

app = FastAPI()

# Load model and tokenizer
# NOTE: xlm-roberta-base ships without a trained classification head, so the head
# created by num_labels=3 is randomly initialized; fine-tuned weights should be
# loaded here before the service can return meaningful predictions.
MODEL_NAME = "xlm-roberta-base"
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME, num_labels=3)
model.eval()


class EmailRequest(BaseModel):
    text: str


class EmailResponse(BaseModel):
    category: int
    confidence: float


# Human-readable descriptions of the category codes returned by the API (1-3)
LABELS = {
    1: "Client wants to schedule a meeting",
    2: "Client is not interested / has no time / declines",
    3: "Client is asking clarifying questions",
}


@app.post("/classify", response_model=EmailResponse)
async def classify_email(request: EmailRequest):
    try:
        # Tokenize the input text
        inputs = tokenizer(
            request.text, return_tensors="pt", truncation=True, max_length=512
        )

        # Get model predictions
        with torch.no_grad():
            outputs = model(**inputs)
            predictions = torch.nn.functional.softmax(outputs.logits, dim=-1)

        # Get the predicted class and its confidence
        predicted_class = torch.argmax(predictions, dim=-1).item()
        confidence = predictions[0][predicted_class].item()

        # Shift the 0-based model index to the 1-based category codes used in LABELS
        return EmailResponse(category=predicted_class + 1, confidence=confidence)
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))


if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)
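
# Example client call against the running service (a minimal sketch; assumes the API
# is reachable at http://localhost:8000, matching the uvicorn host/port above, and
# that the `requests` package is installed on the client side):
#
#   import requests
#
#   resp = requests.post(
#       "http://localhost:8000/classify",
#       json={"text": "Could we set up a call tomorrow afternoon?"},
#   )
#   print(resp.json())  # -> {"category": <1-3>, "confidence": <float>}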