randomshit11 commited on
Commit
590bfb0
·
verified ·
1 Parent(s): 4d5d826

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +150 -51
app.py CHANGED
@@ -1,53 +1,152 @@
1
- import torch
2
- from torchvision import transforms
3
- from PIL import Image
4
- from torchvision import models
5
- import gradio.inputs as gi
6
- import gradio.outputs as go
7
  import gradio as gr
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
 
9
- # Define the ResNet50 model
10
- class ResNet50(torch.nn.Module):
11
- def __init__(self):
12
- super(ResNet50, self).__init__()
13
- self.resnet = models.resnet50(pretrained=True)
14
- for param in self.resnet.parameters():
15
- param.requires_grad = False
16
- self.resnet.fc = torch.nn.Sequential(
17
- torch.nn.Linear(2048, 2)
18
- )
19
-
20
- def forward(self, x):
21
- x = self.resnet(x)
22
- return x
23
-
24
- # Load the pre-trained model
25
- model = ResNet50()
26
- model.load_state_dict(torch.load('best_modelv2.pth', map_location=torch.device('cpu')))
27
- model.eval()
28
-
29
- # Define transform for input images
30
- data_transforms = transforms.Compose([
31
- transforms.Resize((224, 224)),
32
- transforms.ToTensor(),
33
- transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
34
- ])
35
-
36
- # Function to predict image label
37
- def predict_image_label(image):
38
- # Preprocess the image
39
- image = data_transforms(image).unsqueeze(0)
40
-
41
- # Make prediction
42
- with torch.no_grad():
43
- output = model(image)
44
- _, predicted = torch.max(output, 1)
45
-
46
- label = 'Leaf' if predicted.item() == 0 else 'Plant'
47
- return label
48
-
49
- # Create Gradio interface
50
- # image = gi.Image(shape=(224, 224))
51
- label = go.Label(num_top_classes=2)
52
-
53
- gr.Interface(fn=predict_image_label,inputs="image", outputs=label, title="Leaf or Plant Classifier").launch()
 
1
+
2
+
3
+
 
 
 
4
  import gradio as gr
5
+ from transformers import pipeline
6
+
7
+ # Load the model pipeline
8
+ pipe = pipeline("image-classification", "dima806/medicinal_plants_image_detection")
9
+
10
+ # Define the image classification function
11
+ def image_classifier(image):
12
+ # Perform image classification
13
+ outputs = pipe(image)
14
+ results = {}
15
+ for result in outputs:
16
+ results[result['label']] = result['score']
17
+ return results
18
+
19
+ # Define app title and description with HTML formatting
20
+ title = "<h1 style='text-align: center; color: #4CAF50;'>Image Classification</h1>"
21
+ description = "<p style='text-align: center; font-size: 18px;'>This application serves to classify skin lesion images based on their skin cancer type. Trained using Vision Transformer (ViT), it has achieved a validation accuracy of 86%.</p>"
22
+
23
+ # Define custom CSS styles for the Gradio app
24
+ custom_css = """
25
+ .gradio-interface {
26
+ max-width: 600px;
27
+ margin: auto;
28
+ border-radius: 10px;
29
+ box-shadow: 0px 0px 10px rgba(0, 0, 0, 0.1);
30
+ }
31
+ .title-container {
32
+ padding: 20px;
33
+ background-color: #f0f0f0;
34
+ border-top-left-radius: 10px;
35
+ border-top-right-radius: 10px;
36
+ }
37
+ .description-container {
38
+ padding: 20px;
39
+ }
40
+ """
41
+
42
+ # Launch the Gradio interface with custom HTML and CSS
43
+ demo = gr.Interface(fn=image_classifier, inputs=gr.Image(type="pil"), outputs="label", title=title, description=description,
44
+ theme="gstaff/sketch", css=custom_css,
45
+ )
46
+ demo.launch()
47
+
48
+
49
+
50
+
51
+
52
+
53
+
54
+
55
+
56
+
57
+
58
+
59
+
60
+
61
+
62
+
63
+
64
+
65
+
66
+
67
+
68
+
69
+
70
+
71
+
72
+
73
+
74
+
75
+
76
+
77
+
78
+
79
+
80
+
81
+
82
+
83
+
84
+
85
+
86
+
87
+
88
+
89
+
90
+
91
+
92
+
93
+
94
+
95
+
96
+
97
+
98
+
99
+
100
+ # import torch
101
+ # from torchvision import transforms
102
+ # from PIL import Image
103
+ # from torchvision import models
104
+ # import gradio.inputs as gi
105
+ # import gradio.outputs as go
106
+ # import gradio as gr
107
+
108
+ # # Define the ResNet50 model
109
+ # class ResNet50(torch.nn.Module):
110
+ # def __init__(self):
111
+ # super(ResNet50, self).__init__()
112
+ # self.resnet = models.resnet50(pretrained=True)
113
+ # for param in self.resnet.parameters():
114
+ # param.requires_grad = False
115
+ # self.resnet.fc = torch.nn.Sequential(
116
+ # torch.nn.Linear(2048, 2)
117
+ # )
118
+
119
+ # def forward(self, x):
120
+ # x = self.resnet(x)
121
+ # return x
122
+
123
+ # # Load the pre-trained model
124
+ # model = ResNet50()
125
+ # model.load_state_dict(torch.load('best_modelv2.pth', map_location=torch.device('cpu')))
126
+ # model.eval()
127
+
128
+ # # Define transform for input images
129
+ # data_transforms = transforms.Compose([
130
+ # transforms.Resize((224, 224)),
131
+ # transforms.ToTensor(),
132
+ # transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
133
+ # ])
134
+
135
+ # # Function to predict image label
136
+ # def predict_image_label(image):
137
+ # # Preprocess the image
138
+ # image = data_transforms(image).unsqueeze(0)
139
+
140
+ # # Make prediction
141
+ # with torch.no_grad():
142
+ # output = model(image)
143
+ # _, predicted = torch.max(output, 1)
144
+
145
+ # label = 'Leaf' if predicted.item() == 0 else 'Plant'
146
+ # return label
147
+
148
+ # # Create Gradio interface
149
+ # # image = gi.Image(shape=(224, 224))
150
+ # label = go.Label(num_top_classes=2)
151
 
152
+ # gr.Interface(fn=predict_image_label,inputs="image", outputs=label, title="Leaf or Plant Classifier").launch()