import os
import uvicorn
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from typing import List, Dict, Union
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch
# Definition of Pydantic data models
class ProblematicItem(BaseModel):
    text: str

class ProblematicList(BaseModel):
    problematics: List[str]

class PredictionResponse(BaseModel):
    predicted_class: str
    score: float

class PredictionsResponse(BaseModel):
    results: List[Dict[str, Union[str, float]]]
# FastAPI Configuration
app = FastAPI(
    title="Problematic Specificity Classification API",
    description="This API classifies problematics using a fine-tuned model hosted on Hugging Face.",
    version="1.0.0"
)
# Model environment variables
MODEL_NAME = os.getenv("MODEL_NAME", "your-account/your-model")
LABEL_0 = os.getenv("LABEL_0", "Class A")
LABEL_1 = os.getenv("LABEL_1", "Class B")
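# These defaults are placeholders; in a real deployment they would be overridden via environment
# variables (the values below are purely illustrative):
#   export MODEL_NAME="my-org/my-finetuned-model"
#   export LABEL_0="Specific"
#   export LABEL_1="Generic"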
# Loading the model and tokenizer
tokenizer = None
model = None

def load_model():
    global tokenizer, model
    try:
        tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
        model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME)
        return True
    except Exception as e:
        print(f"Error loading model: {e}")
        return False
# API state check route
@app.get("/")
def read_root():
return {"status": "ok", "model": MODEL_NAME}
# Route for checking model status
@app.get("/health")
def health_check():
global model, tokenizer
if model is None or tokenizer is None:
success = load_model()
if not success:
raise HTTPException(status_code=503, detail="Model not available")
return {"status": "ok", "model": MODEL_NAME}
# Route to predict a single problematic at a time
@app.post("/predict", response_model=PredictionResponse)
def predict_single(item: ProblematicItem):
global model, tokenizer
if model is None or tokenizer is None:
success = load_model()
if not success:
raise HTTPException(status_code=503, detail="Model not available")
try:
# Tokenization
inputs = tokenizer(item.text, padding=True, truncation=True, return_tensors="pt")
# Prediction
with torch.no_grad():
outputs = model(**inputs)
probabilities = torch.nn.functional.softmax(outputs.logits, dim=-1)
predicted_class = torch.argmax(probabilities, dim=1).item()
confidence_score = probabilities[0][predicted_class].item()
# Associate the correct label
predicted_label = LABEL_0 if predicted_class == 0 else LABEL_1
return PredictionResponse(predicted_class=predicted_label, score=confidence_score)
except Exception as e:
raise HTTPException(status_code=500, detail=f"Error during prediction: {str(e)}")
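# Example call to the single-prediction route above (a usage sketch, assuming the server is running
# locally on the port configured at the bottom of this file, 7860):
#   curl -X POST http://localhost:7860/predict \
#        -H "Content-Type: application/json" \
#        -d '{"text": "How can industrial energy consumption be reduced?"}'
# The response follows PredictionResponse, e.g. {"predicted_class": "Class A", "score": 0.93}.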
# Route for predicting several problematics at once
@app.post("/predict-batch", response_model=PredictionsResponse)
def predict_batch(items: ProblematicList):
global model, tokenizer
if model is None or tokenizer is None:
success = load_model()
if not success:
raise HTTPException(status_code=503, detail="Model not available")
try:
results = []
# Batch processing
batch_size = 16
for i in range(0, len(items.problematics), batch_size):
batch_texts = items.problematics[i:i+batch_size]
# Tokenization
inputs = tokenizer(batch_texts, padding=True, truncation=True, return_tensors="pt")
# Prediction
with torch.no_grad():
outputs = model(**inputs)
probabilities = torch.nn.functional.softmax(outputs.logits, dim=-1)
predicted_classes = torch.argmax(probabilities, dim=1).tolist()
confidence_scores = [probabilities[j][predicted_classes[j]].item() for j in range(len(predicted_classes))]
# Converting numerical predictions into labels
for j, (pred_class, score) in enumerate(zip(predicted_classes, confidence_scores)):
predicted_label = LABEL_0 if pred_class == 0 else LABEL_1
results.append({
"text": batch_texts[j],
"class": predicted_label,
"score": score
})
return PredictionsResponse(results=results)
except Exception as e:
raise HTTPException(status_code=500, detail=f"Error during prediction: {str(e)}")
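# Example call to the batch route above (illustrative; the payload must match ProblematicList):
#   curl -X POST http://localhost:7860/predict-batch \
#        -H "Content-Type: application/json" \
#        -d '{"problematics": ["First problematic to classify", "Second problematic to classify"]}'
# Each entry in "results" carries the original text, its predicted label, and its confidence score.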
# Model loading at startup
@app.on_event("startup")
async def startup_event():
    load_model()
# Entry point for uvicorn
if __name__ == "__main__":
    # Starting the server with uvicorn
    uvicorn.run("app:app", host="0.0.0.0", port=7860, reload=True)
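# Alternatively, the server can be started with the uvicorn CLI (this assumes the file is saved as
# app.py, which the "app:app" import string above implies):
#   uvicorn app:app --host 0.0.0.0 --port 7860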