ahmadalfian commited on
Commit
05b0454
·
verified ·
1 Parent(s): 8745125

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +26 -27
app.py CHANGED
@@ -97,41 +97,40 @@ def model_description():
97
  def prediction():
98
 
99
  def load_model(model_name):
100
-
101
- if model_name == "DenseNet":
102
- model_path = hf_hub_download(repo_id="ahmadalfian/mineral-classifier",
103
  filename="densenet_finetuned.pth")
104
- num_classes = 7
105
- model = models.densenet121(pretrained=False)
106
- model.classifier = torch.nn.Linear(model.classifier.in_features, num_classes)
107
 
108
- elif model_name == "MobileNet":
109
- model_path = hf_hub_download(repo_id="ahmadalfian/mineral-classifier",
110
  filename="mobileNetV2_finetuned.pth")
111
- num_classes = 7
112
- model = models.mobilenet_v2(pretrained=False)
113
- model.classifier = torch.nn.Linear(model.classifier[1].in_features, num_classes)
114
 
115
- elif model_name == "SqueezeNet":
116
- model_path = hf_hub_download(repo_id="ahmadalfian/mineral-classifier",
117
  filename="squeezenet1_finetuned.pth")
118
- num_classes = 7
119
- model = models.squeezenet1_1(pretrained=False)
120
- model.classifier = torch.nn.Sequential(
121
- torch.nn.Dropout(p=0.5),
122
- torch.nn.Conv2d(512, num_classes, kernel_size=(1, 1), stride=(1, 1)),
123
- torch.nn.ReLU(),
124
- torch.nn.AdaptiveAvgPool2d((1, 1))
125
- )
126
-
127
- else:
128
- raise ValueError("Model not supported.")
129
 
130
  # Load model weights
131
- model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))
132
- model.eval()
133
 
134
- return model
135
 
136
  def process_image(image):
137
  if image.mode == 'RGBA':
 
97
  def prediction():
98
 
99
  def load_model(model_name):
100
+ if model_name == "DenseNet":
101
+ model_path = hf_hub_download(repo_id="ahmadalfian/mineral-classifier",
 
102
  filename="densenet_finetuned.pth")
103
+ num_classes = 7
104
+ model = models.densenet121(pretrained=False)
105
+ model.classifier = torch.nn.Linear(model.classifier.in_features, num_classes)
106
 
107
+ elif model_name == "MobileNet":
108
+ model_path = hf_hub_download(repo_id="ahmadalfian/mineral-classifier",
109
  filename="mobileNetV2_finetuned.pth")
110
+ num_classes = 7
111
+ model = models.mobilenet_v2(pretrained=False)
112
+ model.classifier = torch.nn.Linear(model.classifier[1].in_features, num_classes)
113
 
114
+ elif model_name == "SqueezeNet":
115
+ model_path = hf_hub_download(repo_id="ahmadalfian/mineral-classifier",
116
  filename="squeezenet1_finetuned.pth")
117
+ num_classes = 7
118
+ model = models.squeezenet1_1(pretrained=False)
119
+ model.classifier = torch.nn.Sequential(
120
+ torch.nn.Dropout(p=0.5),
121
+ torch.nn.Conv2d(512, num_classes, kernel_size=(1, 1), stride=(1, 1)),
122
+ torch.nn.ReLU(),
123
+ torch.nn.AdaptiveAvgPool2d((1, 1))
124
+ )
125
+
126
+ else:
127
+ raise ValueError("Model not supported.")
128
 
129
  # Load model weights
130
+ model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))
131
+ model.eval()
132
 
133
+ return model
134
 
135
  def process_image(image):
136
  if image.mode == 'RGBA':