PolyakovK commited on
Commit
1ea9c5f
·
verified ·
1 Parent(s): 330ed6c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +41 -33
app.py CHANGED
@@ -1,41 +1,49 @@
1
- from transformers import pipeline
2
- from flask import Flask, request, jsonify
 
 
 
3
 
4
- app = Flask(__name__)
5
 
6
- # Загружаем zero-shot классификатор
7
- classifier = pipeline("zero-shot-classification", model="typeform/distilbert-base-uncased-mnli")
 
 
8
 
9
- # Гипотезы
10
- LABELS = {
11
- "wants_meeting": "Клиент хочет назначить встречу или обсудить время",
12
- "not_interested": "Клиент не заинтересован во встрече или у него нет времени",
13
- "asking_questions": "Клиент задает уточняющие вопросы по теме"
14
- }
15
-
16
- @app.route("/analyze", methods=["POST"])
17
- def analyze():
18
- data = request.get_json()
19
- emails = data.get("emails", [])
20
 
21
- if not emails or not isinstance(emails, list):
22
- return jsonify({"error": "Field 'emails' must be a non-empty list"}), 400
 
23
 
24
- results = []
25
- for email in emails:
26
- prediction = classifier(email, list(LABELS.values()), multi_label=True)
27
- scored_labels = dict(zip(prediction["labels"], prediction["scores"]))
28
-
29
- # Вывод в формате {label: score}
30
- result = {
31
- "text": email,
32
- "intents": {
33
- key: round(scored_labels[label], 4) for key, label in LABELS.items()
34
- }
35
- }
36
- results.append(result)
37
 
38
- return jsonify(results)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
39
 
40
  if __name__ == "__main__":
41
- app.run(host="0.0.0.0", port=7860)
 
 
1
+ from fastapi import FastAPI, HTTPException
2
+ from pydantic import BaseModel
3
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification
4
+ import torch
5
+ import numpy as np
6
 
7
+ app = FastAPI()
8
 
9
+ # Load model and tokenizer
10
+ MODEL_NAME = "xlm-roberta-base"
11
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
12
+ model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME, num_labels=3)
13
 
14
+ class EmailRequest(BaseModel):
15
+ text: str
 
 
 
 
 
 
 
 
 
16
 
17
+ class EmailResponse(BaseModel):
18
+ category: int
19
+ confidence: float
20
 
21
+ LABELS = {
22
+ 0: "Клиент хочет назначить встречу",
23
+ 1: "Клиент не заинтересован / нет времени / отказывается",
24
+ 2: "Клиент задаёт уточняющие вопросы"
25
+ }
 
 
 
 
 
 
 
 
26
 
27
+ @app.post("/classify", response_model=EmailResponse)
28
+ async def classify_email(request: EmailRequest):
29
+ try:
30
+ # Tokenize the input text
31
+ inputs = tokenizer(request.text, return_tensors="pt", truncation=True, max_length=512)
32
+
33
+ # Get model predictions
34
+ with torch.no_grad():
35
+ outputs = model(**inputs)
36
+ predictions = torch.nn.functional.softmax(outputs.logits, dim=-1)
37
+
38
+ # Get the predicted class and confidence
39
+ predicted_class = torch.argmax(predictions).item()
40
+ confidence = predictions[0][predicted_class].item()
41
+
42
+ return EmailResponse(category=predicted_class + 1, confidence=confidence)
43
+
44
+ except Exception as e:
45
+ raise HTTPException(status_code=500, detail=str(e))
46
 
47
  if __name__ == "__main__":
48
+ import uvicorn
49
+ uvicorn.run(app, host="0.0.0.0", port=8000)