brightlembo commited on
Commit
06c11b8
·
verified ·
1 Parent(s): 3daaf08

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +81 -51
app.py CHANGED
@@ -1,51 +1,81 @@
1
- import torch
2
- from torchvision import transforms
3
- from PIL import Image
4
- import gradio as gr
5
- import json
6
-
7
- # Charger les noms des classes
8
- with open("class_names.json", "r") as f:
9
- class_names = json.load(f)
10
-
11
- # Charger le modèle
12
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
13
- model = torch.load("efficientnet_b7_best.pth", map_location=device)
14
- model.eval() # Mode évaluation
15
-
16
- # Définir la taille de l'image
17
- image_size = (224, 224)
18
-
19
- # Transformation pour l'image
20
- class GrayscaleToRGB:
21
- def __call__(self, img):
22
- return img.convert("RGB")
23
-
24
- valid_test_transforms = transforms.Compose([
25
- transforms.Grayscale(num_output_channels=1),
26
- transforms.Resize(image_size),
27
- GrayscaleToRGB(), # Conversion en RGB
28
- transforms.ToTensor(),
29
- transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
30
- ])
31
-
32
- # Fonction de prédiction
33
- def predict_image(image):
34
- image_tensor = valid_test_transforms(image).unsqueeze(0).to(device)
35
- with torch.no_grad():
36
- outputs = model(image_tensor)
37
- _, predicted_class = torch.max(outputs, 1)
38
- predicted_label = class_names[predicted_class.item()]
39
- return predicted_label
40
-
41
- # Interface Gradio
42
- interface = gr.Interface(
43
- fn=predict_image,
44
- inputs=gr.Image(type="pil"),
45
- outputs="text",
46
- title="Prédiction d'images avec PyTorch",
47
- description="Chargez une image pour obtenir une prédiction de classe."
48
- )
49
-
50
- if __name__ == "__main__":
51
- interface.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torchvision import transforms
3
+ from PIL import Image
4
+ import gradio as gr
5
+ import json
6
+ from torchvision.models import efficientnet_b7, EfficientNet_B7_Weights
7
+ import torch.nn as nn
8
+
9
+ # Charger les noms des classes
10
+ with open("class_names.json", "r") as f:
11
+ class_names = json.load(f)
12
+
13
+ # Charger l'architecture et les poids du modèle
14
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
15
+
16
+ # Charger EfficientNet-B7 avec des poids pré-entraînés
17
+ weights = EfficientNet_B7_Weights.DEFAULT
18
+ base_model = efficientnet_b7(weights=weights)
19
+
20
+ # Adapter le modèle pour la classification (ajout d'une couche FC finale)
21
+ class CustomEfficientNet(nn.Module):
22
+ def __init__(self, base_model, num_classes):
23
+ super(CustomEfficientNet, self).__init__()
24
+ self.base = nn.Sequential(*list(base_model.children())[:-2]) # Couper la partie classification
25
+ self.global_avg_pool = nn.AdaptiveAvgPool2d(1)
26
+ self.fc1 = nn.Linear(2560, 512) # Taille de sortie du dernier bloc
27
+ self.relu = nn.ReLU()
28
+ self.fc2 = nn.Linear(512, num_classes) # Nombre de classes pour la classification
29
+
30
+ def forward(self, x):
31
+ x = self.base(x)
32
+ x = self.global_avg_pool(x)
33
+ x = x.view(x.size(0), -1)
34
+ x = self.relu(self.fc1(x))
35
+ x = self.fc2(x)
36
+ return x
37
+
38
+ # Initialiser le modèle avec 3 classes (ajuste ce nombre selon ton cas)
39
+ num_classes = len(class_names) # Nombre de classes dans le fichier JSON
40
+ model = CustomEfficientNet(base_model, num_classes).to(device)
41
+
42
+ # Charger les poids dans le modèle
43
+ model.load_state_dict(torch.load("efficientnet_b7_best.pth", map_location=device))
44
+ model.eval() # Passer le modèle en mode évaluation
45
+
46
+ # Définir la taille de l'image
47
+ image_size = (224, 224)
48
+
49
+ # Transformation pour l'image
50
+ class GrayscaleToRGB:
51
+ def __call__(self, img):
52
+ return img.convert("RGB")
53
+
54
+ valid_test_transforms = transforms.Compose([
55
+ transforms.Grayscale(num_output_channels=1),
56
+ transforms.Resize(image_size),
57
+ GrayscaleToRGB(), # Conversion en RGB
58
+ transforms.ToTensor(),
59
+ transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
60
+ ])
61
+
62
+ # Fonction de prédiction
63
+ def predict_image(image):
64
+ image_tensor = valid_test_transforms(image).unsqueeze(0).to(device)
65
+ with torch.no_grad():
66
+ outputs = model(image_tensor)
67
+ _, predicted_class = torch.max(outputs, 1)
68
+ predicted_label = class_names[predicted_class.item()]
69
+ return predicted_label
70
+
71
+ # Interface Gradio
72
+ interface = gr.Interface(
73
+ fn=predict_image,
74
+ inputs=gr.Image(type="pil"),
75
+ outputs="text",
76
+ title="Prédiction d'images avec PyTorch",
77
+ description="Chargez une image pour obtenir une prédiction de classe."
78
+ )
79
+
80
+ if __name__ == "__main__":
81
+ interface.launch()