Giulio Rossi commited on
Commit
9f50754
·
verified ·
1 Parent(s): 806a1b9

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +7 -6
app.py CHANGED
@@ -5,7 +5,7 @@ from PIL import Image
5
  import numpy as np
6
  import gradio as gr
7
 
8
- # Definisci il modello e il numero di classi
9
  class PretrainedModel(nn.Module):
10
  def __init__(self, num_classes=19):
11
  super(PretrainedModel, self).__init__()
@@ -27,11 +27,12 @@ class PretrainedModel(nn.Module):
27
  # Crea un'istanza del modello
28
  model = PretrainedModel(num_classes=19)
29
 
30
- # Carica i pesi
31
- model.load_state_dict(torch.load('model_v11.pt', map_location=torch.device('cpu')))
 
32
  model.eval() # Imposta il modello in modalità valutazione
33
 
34
- # Trasformazioni
35
  preprocess = transforms.Compose([
36
  transforms.Resize((224, 224)),
37
  transforms.ToTensor(),
@@ -49,10 +50,10 @@ def classify_image(img):
49
 
50
  return f"Class {predicted_class_index}, Confidence: {predicted_probability:.4f}"
51
 
52
- # Configura Gradio
53
  iface = gr.Interface(
54
  fn=classify_image,
55
- inputs=gr.inputs.Image(type="pil"),
56
  outputs="text"
57
  )
58
 
 
5
  import numpy as np
6
  import gradio as gr
7
 
8
+ # Definizione del modello pre-addestrato
9
  class PretrainedModel(nn.Module):
10
  def __init__(self, num_classes=19):
11
  super(PretrainedModel, self).__init__()
 
27
  # Crea un'istanza del modello
28
  model = PretrainedModel(num_classes=19)
29
 
30
+ # Carica i pesi con `weights_only=True` per evitare problemi di sicurezza
31
+ state_dict = torch.load('model_v11.pt', map_location=torch.device('cpu'), weights_only=True)
32
+ model.load_state_dict(state_dict)
33
  model.eval() # Imposta il modello in modalità valutazione
34
 
35
+ # Trasformazioni per l'immagine
36
  preprocess = transforms.Compose([
37
  transforms.Resize((224, 224)),
38
  transforms.ToTensor(),
 
50
 
51
  return f"Class {predicted_class_index}, Confidence: {predicted_probability:.4f}"
52
 
53
+ # Configura Gradio con l'API aggiornata
54
  iface = gr.Interface(
55
  fn=classify_image,
56
+ inputs=gr.Image(type="pil"),
57
  outputs="text"
58
  )
59