brightlembo commited on
Commit
347cd1b
·
verified ·
1 Parent(s): 23747c9

Upload 3 files

Browse files
Files changed (3) hide show
  1. app.py +51 -0
  2. efficientnet_b7_best.pth +3 -0
  3. requirements.txt +4 -0
app.py ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torchvision import transforms
3
+ from PIL import Image
4
+ import gradio as gr
5
+ import json
6
+
7
+ # Charger les noms des classes
8
+ with open("class_names.json", "r") as f:
9
+ class_names = json.load(f)
10
+
11
+ # Charger le modèle
12
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
13
+ model = torch.load("efficientnet_b7_best.pth", map_location=device)
14
+ model.eval() # Mode évaluation
15
+
16
+ # Définir la taille de l'image
17
+ image_size = (224, 224)
18
+
19
+ # Transformation pour l'image
20
+ class GrayscaleToRGB:
21
+ def __call__(self, img):
22
+ return img.convert("RGB")
23
+
24
+ valid_test_transforms = transforms.Compose([
25
+ transforms.Grayscale(num_output_channels=1),
26
+ transforms.Resize(image_size),
27
+ GrayscaleToRGB(), # Conversion en RGB
28
+ transforms.ToTensor(),
29
+ transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
30
+ ])
31
+
32
+ # Fonction de prédiction
33
+ def predict_image(image):
34
+ image_tensor = valid_test_transforms(image).unsqueeze(0).to(device)
35
+ with torch.no_grad():
36
+ outputs = model(image_tensor)
37
+ _, predicted_class = torch.max(outputs, 1)
38
+ predicted_label = class_names[predicted_class.item()]
39
+ return predicted_label
40
+
41
+ # Interface Gradio
42
+ interface = gr.Interface(
43
+ fn=predict_image,
44
+ inputs=gr.Image(type="pil"),
45
+ outputs="text",
46
+ title="Prédiction d'images avec PyTorch",
47
+ description="Chargez une image pour obtenir une prédiction de classe."
48
+ )
49
+
50
+ if __name__ == "__main__":
51
+ interface.launch()
efficientnet_b7_best.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3fabd3615d41ea9cc50d46c406ce324312d78cff45faa10fb3a37352421a93e9
3
+ size 262040086
requirements.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ torch
2
+ torchvision
3
+ gradio
4
+ Pillow