import os
import uvicorn
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from typing import List, Dict, Union
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch


# Definition of Pydantic data models
class ProblematicItem(BaseModel):
    text: str

class ProblematicList(BaseModel):
    problematics: List[str]

class PredictionResponse(BaseModel):
    predicted_class: str
    score: float

class PredictionsResponse(BaseModel):
    results: List[Dict[str, Union[str, float]]]
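
# Illustrative request/response shapes for these models (values are examples only;
# the returned labels come from the LABEL_0 / LABEL_1 environment variables below):
#   ProblematicItem     -> {"text": "How does soil erosion affect crop yields?"}
#   ProblematicList     -> {"problematics": ["First problematic...", "Second problematic..."]}
#   PredictionResponse  -> {"predicted_class": "Class A", "score": 0.93}
#   PredictionsResponse -> {"results": [{"text": "...", "class": "Class A", "score": 0.93}]}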

# FastAPI Configuration
app = FastAPI(
    title="Problematic Specificity Classification API",
    description="This API classifies problematics using a fine-tuned model hosted on Hugging Face.",
    version="1.0.0"
)

# Model environment variables
MODEL_NAME = os.getenv("MODEL_NAME", "your-account/your-model")
LABEL_0 = os.getenv("LABEL_0", "Class A")
LABEL_1 = os.getenv("LABEL_1", "Class B")
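
# Example of overriding these defaults when launching the server (a sketch;
# MODEL_NAME must point to a sequence-classification checkpoint on the Hugging Face Hub,
# and the file is assumed to be saved as app.py, as the uvicorn.run target below suggests):
#   MODEL_NAME="your-account/your-model" LABEL_0="Class A" LABEL_1="Class B" python app.py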

# Model and tokenizer placeholders, loaded lazily by load_model()
tokenizer = None
model = None

def load_model():
    global tokenizer, model
    try:
        tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
        model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME)
        return True
    except Exception as e:
        print(f"Error loading model: {e}")
        return False

# Root route for a basic API liveness check
@app.get("/")
def read_root():
    return {"status": "ok", "model": MODEL_NAME}

# Health-check route that loads the model if it is not yet available
@app.get("/health")
def health_check():
    global model, tokenizer
    if model is None or tokenizer is None:
        success = load_model()
        if not success:
            raise HTTPException(status_code=503, detail="Model not available")
    return {"status": "ok", "model": MODEL_NAME}

# Route for classifying a single problematic
@app.post("/predict", response_model=PredictionResponse)
def predict_single(item: ProblematicItem):
    global model, tokenizer
    
    if model is None or tokenizer is None:
        success = load_model()
        if not success:
            raise HTTPException(status_code=503, detail="Model not available")
    
    try:
        # Tokenization
        inputs = tokenizer(item.text, padding=True, truncation=True, return_tensors="pt")
        
        # Prediction
        with torch.no_grad():
            outputs = model(**inputs)
            probabilities = torch.nn.functional.softmax(outputs.logits, dim=-1)
            predicted_class = torch.argmax(probabilities, dim=1).item()
            confidence_score = probabilities[0][predicted_class].item()
        
        # Associate the correct label
        predicted_label = LABEL_0 if predicted_class == 0 else LABEL_1
        
        return PredictionResponse(predicted_class=predicted_label, score=confidence_score)
    
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error during prediction: {str(e)}")
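
# Example call (a sketch, assuming the server is running locally on port 7860):
#   curl -X POST http://localhost:7860/predict \
#        -H "Content-Type: application/json" \
#        -d '{"text": "How does soil erosion affect crop yields?"}'
# Expected response shape: {"predicted_class": "<LABEL_0 or LABEL_1>", "score": 0.97}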

# Route for classifying several problematics in one request
@app.post("/predict-batch", response_model=PredictionsResponse)
def predict_batch(items: ProblematicList):
    global model, tokenizer
    
    if model is None or tokenizer is None:
        success = load_model()
        if not success:
            raise HTTPException(status_code=503, detail="Model not available")
    
    try:
        results = []
        
        # Batch processing
        batch_size = 16
        for i in range(0, len(items.problematics), batch_size):
            batch_texts = items.problematics[i:i+batch_size]
            
            # Tokenization
            inputs = tokenizer(batch_texts, padding=True, truncation=True, return_tensors="pt")
            
            # Prediction
            with torch.no_grad():
                outputs = model(**inputs)
                probabilities = torch.nn.functional.softmax(outputs.logits, dim=-1)
                predicted_classes = torch.argmax(probabilities, dim=1).tolist()
                confidence_scores = [probabilities[j][predicted_classes[j]].item() for j in range(len(predicted_classes))]
            
            # Converting numerical predictions into labels
            for j, (pred_class, score) in enumerate(zip(predicted_classes, confidence_scores)):
                predicted_label = LABEL_0 if pred_class == 0 else LABEL_1
                results.append({
                    "text": batch_texts[j],
                    "class": predicted_label,
                    "score": score
                })
        
        return PredictionsResponse(results=results)
    
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error during prediction: {str(e)}")
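
# Example call (a sketch, assuming the server is running locally on port 7860):
#   curl -X POST http://localhost:7860/predict-batch \
#        -H "Content-Type: application/json" \
#        -d '{"problematics": ["First problematic...", "Second problematic..."]}'
# Each entry in "results" carries the original text, the predicted label, and its score.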

# Model loading at startup
@app.on_event("startup")
async def startup_event():
    load_model()

# Entry point for uvicorn
if __name__ == "__main__":
    # Starting the server with uvicorn
    uvicorn.run("app:app", host="0.0.0.0", port=7860, reload=True)