import os
from pydantic import BaseModel
from typing import List, Dict, Union
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch


# Definition of Pydantic data models
class ProblematicItem(BaseModel):
    text: str

class ProblematicList(BaseModel):
    problematics: List[str]

class PredictionResponse(BaseModel):
    predicted_class: str
    score: float

class PredictionsResponse(BaseModel):
    results: List[Dict[str, Union[str, float]]]

# Model environment variables
MODEL_NAME = os.getenv("MODEL_NAME")
LABEL_0 = os.getenv("LABEL_0")
LABEL_1 = os.getenv("LABEL_1")

if not MODEL_NAME:
    raise ValueError("Environment variable MODEL_NAME is not set.")
if not LABEL_0 or not LABEL_1:
    raise ValueError("Environment variables LABEL_0 and LABEL_1 must be set.")

# Model and tokenizer, loaded lazily on first use
tokenizer = None
model = None


def load_model():
    """Load the tokenizer and model; return True on success, False on failure."""
    global tokenizer, model
    try:
        tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
        model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME)
        return True
    except Exception as e:
        print(f"Error loading model: {e}")
        return False


def health_check():
    """Report whether the model is loaded, loading it on demand if needed."""
    global model, tokenizer
    if model is None or tokenizer is None:
        if not load_model():
            return {"status": "error", "model": MODEL_NAME, "detail": "Model not available"}
    return {"status": "ok", "model": MODEL_NAME}


def predict_single(item: ProblematicItem):
    """Classify a single text and return a PredictionResponse."""
    global model, tokenizer

    if model is None or tokenizer is None:
        if not load_model():
            raise RuntimeError("Model not available")

    try:
        # Tokenization
        inputs = tokenizer(item.text, padding=True, truncation=True, return_tensors="pt")

        # Prediction
        with torch.no_grad():
            outputs = model(**inputs)
            probabilities = torch.nn.functional.softmax(outputs.logits, dim=-1)
            predicted_class = torch.argmax(probabilities, dim=1).item()
            confidence_score = probabilities[0][predicted_class].item()

        # Map the predicted class index to its label
        predicted_label = LABEL_0 if predicted_class == 0 else LABEL_1

        return PredictionResponse(predicted_class=predicted_label, score=confidence_score)

    except Exception as e:
        print(f"Error during prediction: {e}")
        raise

def predict_batch(items: ProblematicList):
    """Classify a list of texts in fixed-size batches and return a PredictionsResponse."""
    global model, tokenizer

    if model is None or tokenizer is None:
        if not load_model():
            raise RuntimeError("Model not available")

    try:
        results = []

        # Batch processing
        batch_size = 8
        for i in range(0, len(items.problematics), batch_size):
            batch_texts = items.problematics[i:i + batch_size]

            # Tokenization
            inputs = tokenizer(batch_texts, padding=True, truncation=True, return_tensors="pt")

            # Prediction
            with torch.no_grad():
                outputs = model(**inputs)
                probabilities = torch.nn.functional.softmax(outputs.logits, dim=-1)
                predicted_classes = torch.argmax(probabilities, dim=1).tolist()
                confidence_scores = [probabilities[j][predicted_classes[j]].item() for j in range(len(predicted_classes))]

            # Convert numeric predictions into labels
            for j, (pred_class, score) in enumerate(zip(predicted_classes, confidence_scores)):
                predicted_label = LABEL_0 if pred_class == 0 else LABEL_1
                results.append({
                    "text": batch_texts[j],
                    "class": predicted_label,
                    "score": score
                })

        return PredictionsResponse(results=results)

    except Exception as e:
        print(f"Error during prediction: {e}")
        raise