adrienbrdne committed
Commit 330e067 · verified · 1 Parent(s): 120d9c7

Upload scoring_specificity.py

Files changed (1)
  1. scoring_specificity.py +118 -0
scoring_specificity.py ADDED
@@ -0,0 +1,118 @@
import os
import uvicorn
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from typing import List, Dict, Union
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch

# Definition of Pydantic data models
class ProblematicItem(BaseModel):
    text: str

class ProblematicList(BaseModel):
    problematics: List[str]

class PredictionResponse(BaseModel):
    predicted_class: str
    score: float

class PredictionsResponse(BaseModel):
    results: List[Dict[str, Union[str, float]]]
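
# Illustrative payloads implied by the models above (example values,
# not taken from the commit):
#   ProblematicItem     -> {"text": "Some problem statement"}
#   ProblematicList     -> {"problematics": ["first text", "second text"]}
#   PredictionResponse  -> {"predicted_class": "Class A", "score": 0.97}
#   PredictionsResponse -> {"results": [{"text": "...", "class": "Class A", "score": 0.97}]}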

# Model configuration from environment variables
MODEL_NAME = os.getenv("MODEL_NAME", "your-account/your-model")
LABEL_0 = os.getenv("LABEL_0", "Class A")
LABEL_1 = os.getenv("LABEL_1", "Class B")
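
# Hypothetical deployment configuration (model id and label names are
# illustrative; the real values are deployment-specific):
#   export MODEL_NAME="org/model-name"
#   export LABEL_0="Not specific"
#   export LABEL_1="Specific"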

# Loading the model and tokenizer
tokenizer = None
model = None

def load_model():
    global tokenizer, model
    try:
        tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
        model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME)
        return True
    except Exception as e:
        print(f"Error loading model: {e}")
        return False


def health_check():
    global model, tokenizer
    if model is None or tokenizer is None:
        success = load_model()
        if not success:
            raise HTTPException(status_code=503, detail="Model not available")
    return {"status": "ok", "model": MODEL_NAME}


def predict_single(item: ProblematicItem):
    global model, tokenizer

    if model is None or tokenizer is None:
        # Fail fast with a 503 instead of printing and continuing with a None model
        if not load_model():
            raise HTTPException(status_code=503, detail="Model not available")

    try:
        # Tokenization
        inputs = tokenizer(item.text, padding=True, truncation=True, return_tensors="pt")

        # Prediction
        with torch.no_grad():
            outputs = model(**inputs)
            probabilities = torch.nn.functional.softmax(outputs.logits, dim=-1)
            predicted_class = torch.argmax(probabilities, dim=1).item()
            confidence_score = probabilities[0][predicted_class].item()

        # Associate the correct label
        predicted_label = LABEL_0 if predicted_class == 0 else LABEL_1

        return PredictionResponse(predicted_class=predicted_label, score=confidence_score)

    except Exception as e:
        # Propagate the error instead of returning None, which would fail response validation
        raise HTTPException(status_code=500, detail=f"Error during prediction: {e}")


def predict_batch(items: ProblematicList):
    global model, tokenizer

    if model is None or tokenizer is None:
        if not load_model():
            raise HTTPException(status_code=503, detail="Model not available")

    try:
        results = []

        # Batch processing
        batch_size = 8
        for i in range(0, len(items.problematics), batch_size):
            batch_texts = items.problematics[i:i+batch_size]

            # Tokenization
            inputs = tokenizer(batch_texts, padding=True, truncation=True, return_tensors="pt")

            # Prediction
            with torch.no_grad():
                outputs = model(**inputs)
                probabilities = torch.nn.functional.softmax(outputs.logits, dim=-1)
                predicted_classes = torch.argmax(probabilities, dim=1).tolist()
                confidence_scores = [probabilities[j][predicted_classes[j]].item() for j in range(len(predicted_classes))]

            # Converting numerical predictions into labels
            for j, (pred_class, score) in enumerate(zip(predicted_classes, confidence_scores)):
                predicted_label = LABEL_0 if pred_class == 0 else LABEL_1
                results.append({
                    "text": batch_texts[j],
                    "class": predicted_label,
                    "score": score
                })

        return PredictionsResponse(results=results)

    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error during prediction: {e}")
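
The file imports FastAPI and uvicorn but never instantiates an app or registers routes, so it is not runnable as a service on its own. A minimal wiring sketch follows; the route paths, port, and module layout are assumptions for illustration, not part of this commit:

import uvicorn
from fastapi import FastAPI

from scoring_specificity import (
    PredictionResponse,
    PredictionsResponse,
    health_check,
    predict_batch,
    predict_single,
)

app = FastAPI(title="scoring-specificity")

# Register the existing functions as routes (paths are assumed)
app.get("/health")(health_check)
app.post("/predict", response_model=PredictionResponse)(predict_single)
app.post("/predict-batch", response_model=PredictionsResponse)(predict_batch)

if __name__ == "__main__":
    # Port 7860 follows the Hugging Face Spaces convention; adjust as needed
    uvicorn.run(app, host="0.0.0.0", port=7860)

Under those assumptions, the single-prediction path can be exercised with:

curl -X POST http://localhost:7860/predict -H "Content-Type: application/json" -d '{"text": "Example problem statement"}'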