Update app.py
Browse files
app.py
CHANGED
@@ -7,8 +7,6 @@ import rdflib
|
|
7 |
from rdflib.plugins.sparql.parser import parseQuery
|
8 |
from huggingface_hub import InferenceClient
|
9 |
import re
|
10 |
-
import torch
|
11 |
-
from transformers import DistilBertForSequenceClassification, DistilBertTokenizer
|
12 |
# ---------------------------------------------------------------------------
|
13 |
# CONFIGURAZIONE LOGGING
|
14 |
# ---------------------------------------------------------------------------
|
@@ -18,22 +16,6 @@ logging.basicConfig(
|
|
18 |
handlers=[logging.FileHandler("app.log"), logging.StreamHandler()]
|
19 |
)
|
20 |
logger = logging.getLogger(__name__)
|
21 |
-
|
22 |
-
|
23 |
-
# Determina il device (GPU se disponibile, altrimenti CPU)
|
24 |
-
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
25 |
-
logger.info(f"Device per il classificatore: {device}")
|
26 |
-
|
27 |
-
# Carica il modello e il tokenizer del classificatore fine-tuned
|
28 |
-
try:
|
29 |
-
logger.info("Caricamento del modello di classificazione fine-tuned da 'finetuned-bert-model'.")
|
30 |
-
classifier_model = DistilBertForSequenceClassification.from_pretrained("finetuned-bert-model")
|
31 |
-
classifier_tokenizer = DistilBertTokenizer.from_pretrained("finetuned-bert-model")
|
32 |
-
classifier_model.to(device)
|
33 |
-
logger.info("Modello di classificazione caricato correttamente.")
|
34 |
-
except Exception as e:
|
35 |
-
logger.error(f"Errore nel caricamento del modello di classificazione: {e}")
|
36 |
-
classifier_model = None
|
37 |
explanation_dict = {}
|
38 |
# ---------------------------------------------------------------------------
|
39 |
# COSTANTI / CHIAVI / MODELLI
|
@@ -402,32 +384,6 @@ def assistant_endpoint(req: AssistantRequest):
|
|
402 |
max_tokens = req.max_tokens
|
403 |
temperature = req.temperature
|
404 |
logger.debug(f"Parametri utente: message='{user_message}', max_tokens={max_tokens}, temperature={temperature}")
|
405 |
-
# -------------------------------
|
406 |
-
# CLASSIFICAZIONE DEL TESTO RICEVUTO
|
407 |
-
# -------------------------------
|
408 |
-
if classifier_model is not None:
|
409 |
-
try:
|
410 |
-
# Prepara l'input per il modello di classificazione
|
411 |
-
inputs = classifier_tokenizer(user_message, return_tensors="pt", truncation=True, padding=True)
|
412 |
-
inputs = {k: v.to(device) for k, v in inputs.items()}
|
413 |
-
|
414 |
-
# Disattiva il calcolo del gradiente per velocizzare l'inferenza
|
415 |
-
with torch.no_grad():
|
416 |
-
outputs = classifier_model(**inputs)
|
417 |
-
logits = outputs.logits
|
418 |
-
pred = torch.argmax(logits, dim=1).item()
|
419 |
-
|
420 |
-
# Mappa l'etichetta numerica a una stringa (modifica secondo la tua logica)
|
421 |
-
label_mapping = {0: "NON PERTINENTE", 1: "PERTINENTE"}
|
422 |
-
classification_result = label_mapping.get(pred, f"Etichetta {pred}")
|
423 |
-
logger.info(f"[Classificazione] La domanda classificata come: {classification_result}")
|
424 |
-
explanation_dict['classification'] = f"Risultato classificazione: {classification_result}"
|
425 |
-
except Exception as e:
|
426 |
-
logger.error(f"Errore durante la classificazione della domanda: {e}")
|
427 |
-
explanation_dict['classification'] = f"Errore classificazione: {e}"
|
428 |
-
else:
|
429 |
-
logger.warning("Modello di classificazione non disponibile.")
|
430 |
-
explanation_dict['classification'] = "Modello di classificazione non disponibile."
|
431 |
# -----------------------------------------------------------------------
|
432 |
# STEP 1: Generazione della query SPARQL
|
433 |
# -----------------------------------------------------------------------
|
|
|
7 |
from rdflib.plugins.sparql.parser import parseQuery
|
8 |
from huggingface_hub import InferenceClient
|
9 |
import re
|
|
|
|
|
10 |
# ---------------------------------------------------------------------------
|
11 |
# CONFIGURAZIONE LOGGING
|
12 |
# ---------------------------------------------------------------------------
|
|
|
16 |
handlers=[logging.FileHandler("app.log"), logging.StreamHandler()]
|
17 |
)
|
18 |
logger = logging.getLogger(__name__)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
19 |
explanation_dict = {}
|
20 |
# ---------------------------------------------------------------------------
|
21 |
# COSTANTI / CHIAVI / MODELLI
|
|
|
384 |
max_tokens = req.max_tokens
|
385 |
temperature = req.temperature
|
386 |
logger.debug(f"Parametri utente: message='{user_message}', max_tokens={max_tokens}, temperature={temperature}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
387 |
# -----------------------------------------------------------------------
|
388 |
# STEP 1: Generazione della query SPARQL
|
389 |
# -----------------------------------------------------------------------
|