File size: 1,657 Bytes
1ea9c5f
 
 
 
 
3d07812
1ea9c5f
64f605f
1ea9c5f
 
 
 
3d07812
1ea9c5f
 
64f605f
1ea9c5f
 
 
b4e675c
1ea9c5f
 
 
 
 
b4e675c
1ea9c5f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b4e675c
 
1ea9c5f
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch
import numpy as np

# ASGI application object; the /classify route below is registered on it.
app = FastAPI()

# Load the tokenizer and a 3-label sequence-classification model once at
# import time so every request reuses the same objects.
# NOTE(review): "xlm-roberta-base" looks like a base pretrained checkpoint,
# so the classification head may be freshly initialized — confirm that
# fine-tuned weights are what is intended here.
MODEL_NAME = "xlm-roberta-base"
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME, num_labels=3)

# Request body for POST /classify.
class EmailRequest(BaseModel):
    # Raw email text to classify; the handler truncates it to 512 tokens.
    text: str

# Response body for POST /classify.
class EmailResponse(BaseModel):
    # 1-based category number: the handler returns the model's class index + 1.
    category: int
    # Softmax probability of the predicted class.
    confidence: float

# Category descriptions (Russian), keyed 0..2; presumably these correspond to
# the model's three class indices. English glosses, in order:
#   0 - the client wants to schedule a meeting
#   1 - the client is not interested / has no time / declines
#   2 - the client asks clarifying questions
# NOTE(review): keys are 0-based, whereas the /classify response reports
# class index + 1 — confirm which numbering consumers of LABELS expect.
LABELS = dict(
    enumerate(
        (
            "Клиент хочет назначить встречу",
            "Клиент не заинтересован / нет времени / отказывается",
            "Клиент задаёт уточняющие вопросы",
        )
    )
)

@app.post("/classify", response_model=EmailResponse)
async def classify_email(request: EmailRequest) -> EmailResponse:
    """Classify an email text into one of three categories.

    The text is tokenized (truncated to at most 512 tokens), passed through
    the module-level ``model``, and the logits are converted to probabilities
    with a softmax over the label dimension.

    Args:
        request: Request body carrying the raw email text.

    Returns:
        An ``EmailResponse`` whose ``category`` is the 1-based number of the
        most probable class (model class index + 1) and whose ``confidence``
        is that class's softmax probability.

    Raises:
        HTTPException: With status 500 and the original error message as
            ``detail`` if tokenization, inference, or building the response
            fails. The original exception is attached as ``__cause__``.
    """
    try:
        # A single string produces a batch of one; truncation caps the input
        # at the max_length given here.
        inputs = tokenizer(request.text, return_tensors="pt", truncation=True, max_length=512)

        # Inference only, so gradient tracking is switched off.
        with torch.no_grad():
            logits = model(**inputs).logits
            # Select the single batch row explicitly so the arg-max below is
            # taken over the label dimension, not over a flattened tensor.
            probabilities = torch.nn.functional.softmax(logits, dim=-1)[0]

        predicted_class = int(torch.argmax(probabilities, dim=-1).item())
        confidence = float(probabilities[predicted_class].item())

        # The API reports categories starting at 1, hence the +1 offset.
        return EmailResponse(category=predicted_class + 1, confidence=confidence)

    except Exception as e:
        # NOTE(review): `detail` echoes the internal error text to the client;
        # kept for backward compatibility — consider a generic message.
        # Chain the cause so the original traceback is not lost.
        raise HTTPException(status_code=500, detail=str(e)) from e

if __name__ == "__main__":
    # Imported here so uvicorn is only needed when the file is run as a script.
    from uvicorn import run as serve

    # Serve the app on all interfaces, port 8000.
    serve(app, host="0.0.0.0", port=8000)