PolyakovK commited on
Commit
caf4318
·
verified ·
1 Parent(s): 184659e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +37 -44
app.py CHANGED
@@ -1,49 +1,42 @@
1
- from fastapi import FastAPI, HTTPException
2
- from pydantic import BaseModel
3
- from transformers import AutoTokenizer, AutoModelForSequenceClassification
4
- import torch
5
- import numpy as np
6
 
7
- app = FastAPI()
 
 
 
 
 
 
8
 
9
- # Load model and tokenizer
10
- MODEL_NAME = "xlm-roberta-base"
11
- tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
12
- model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME, num_labels=3)
13
-
14
- class EmailRequest(BaseModel):
15
- text: str
16
-
17
- class EmailResponse(BaseModel):
18
- category: int
19
- confidence: float
20
-
21
- LABELS = {
22
- 0: "Клиент хочет назначить встречу",
23
- 1: "Клиент не заинтересован / нет времени / отказывается",
24
- 2: "Клиент задаёт уточняющие вопросы"
25
- }
26
-
27
- @app.post("/classify", response_model=EmailResponse)
28
- async def classify_email(request: EmailRequest):
29
- try:
30
- # Tokenize the input text
31
- inputs = tokenizer(request.text, return_tensors="pt", truncation=True, max_length=512)
32
-
33
- # Get model predictions
34
- with torch.no_grad():
35
- outputs = model(**inputs)
36
- predictions = torch.nn.functional.softmax(outputs.logits, dim=-1)
37
-
38
- # Get the predicted class and confidence
39
- predicted_class = torch.argmax(predictions).item()
40
- confidence = predictions[0][predicted_class].item()
41
-
42
- return EmailResponse(category=predicted_class + 1, confidence=confidence)
43
 
44
- except Exception as e:
45
- raise HTTPException(status_code=500, detail=str(e))
 
 
 
 
46
 
 
47
  if __name__ == "__main__":
48
- import uvicorn
49
- uvicorn.run(app, host="0.0.0.0", port=8000)
 
 
 
 
1
+ from transformers import pipeline
 
 
 
 
2
 
3
+ def get_classifier():
4
+ classifier = pipeline(
5
+ "zero-shot-classification",
6
+ model="sberbank-ai/rugpt3small_based_on_gpt2",
7
+ framework="pt"
8
+ )
9
+ return classifier
10
 
11
+ def classify_email(text):
12
+ classifier = get_classifier()
13
+
14
+ candidate_labels = [
15
+ "Клиент хочет назначить встречу",
16
+ "Клиент не заинтересован или отказывается",
17
+ "Клиент задаёт уточняющие вопросы"
18
+ ]
19
+
20
+ result = classifier(
21
+ text,
22
+ candidate_labels,
23
+ hypothesis_template="Это письмо о том, что {}."
24
+ )
25
+
26
+ # Получаем индекс наиболее вероятной метки (0, 1 или 2)
27
+ label_index = result["labels"].index(result["labels"][0])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
28
 
29
+ # Возвращаем категорию (1, 2 или 3) и уверенность
30
+ return {
31
+ "category": label_index + 1,
32
+ "confidence": result["scores"][label_index],
33
+ "label": result["labels"][0]
34
+ }
35
 
36
+ # Пример использования
37
  if __name__ == "__main__":
38
+ test_text = "Добрый день! Можно ли узнать подробнее о ваших услугах и ценах?"
39
+ result = classify_email(test_text)
40
+ print(f"Категория: {result['category']}")
41
+ print(f"Уверенность: {result['confidence']:.2f}")
42
+ print(f"Метка: {result['label']}")