hasnanmr commited on
Commit
654e088
·
1 Parent(s): d1d6890

modify model

Browse files
Files changed (1) hide show
  1. app.py +12 -12
app.py CHANGED
@@ -6,19 +6,19 @@ from torchvision import models, transforms
6
  from torch.utils.data import DataLoader
7
  from torchvision.datasets import ImageFolder
8
 
9
- vgg16 = models.vgg16(pretrained=True)
10
-
11
- # Freeze the convolutional base to prevent updating weights during training
12
- for param in vgg16.features.parameters():
13
  param.requires_grad = False
14
 
15
- num_features = vgg16.classifier[6].in_features
16
- num_classes = 3
17
- vgg16.classifier[6] = torch.nn.Linear(num_features, num_classes)
 
18
 
19
  # Load the model
20
- model = vgg16
21
- state_dict = torch.load('vgg16_transfer_learning.pth', map_location=torch.device('cpu'))
22
  model.load_state_dict(state_dict)
23
  model.eval()
24
 
@@ -26,10 +26,10 @@ model.eval()
26
  transform = transforms.Compose([
27
  transforms.Resize((224, 224)),
28
  transforms.ToTensor(),
29
- transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
30
  ])
31
 
32
- classes = ('broccoli', 'cabbage', 'cauliflower')
33
 
34
  def predict(image):
35
  input_tensor = transform(image)
@@ -42,7 +42,7 @@ def predict(image):
42
  max_value, predicted_class = torch.max(probabilities, 0)
43
  return classes[predicted_class.item()], max_value.item() * 100
44
 
45
- st.title('Vegetable Classification')
46
  st.write('you can upload your image of veggies below')
47
 
48
  uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
 
6
  from torch.utils.data import DataLoader
7
  from torchvision.datasets import ImageFolder
8
 
9
+ #define the model architecture
10
+ model_resnet = models.resnet18(weights='IMAGENET1K_V1')
11
+ for param in model_resnet.parameters():
 
12
  param.requires_grad = False
13
 
14
+ # Parameters of newly constructed modules have requires_grad=True by default
15
+ num_ftrs = model_resnet.fc.in_features
16
+ model_resnet.fc = nn.Linear(num_ftrs, 15) #mengganti jumlah classifier sesuai output kelas
17
+
18
 
19
  # Load the model
20
+ model = model_resnet
21
+ state_dict = torch.load('transfer_learning_resnet_15class.pth', map_location=torch.device('cpu'))
22
  model.load_state_dict(state_dict)
23
  model.eval()
24
 
 
26
  transform = transforms.Compose([
27
  transforms.Resize((224, 224)),
28
  transforms.ToTensor(),
29
+ transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
30
  ])
31
 
32
+ classes = ('Bean', 'Bitter_Gourd', 'Bottle_Gourd', 'Brinjal', 'Broccoli', 'Cabbage', 'Capsicum', 'Carrot', 'Cauliflower', 'Cucumber', 'Papaya', 'Potato', 'Pumpkin', 'Radish', 'Tomato')
33
 
34
  def predict(image):
35
  input_tensor = transform(image)
 
42
  max_value, predicted_class = torch.max(probabilities, 0)
43
  return classes[predicted_class.item()], max_value.item() * 100
44
 
45
+ st.title('Vegetable Classification for learning')
46
  st.write('you can upload your image of veggies below')
47
 
48
  uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])