ahmadalfian commited on
Commit
d04d540
·
verified ·
1 Parent(s): 8539ecf

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +3 -1
app.py CHANGED
@@ -97,17 +97,18 @@ def model_description():
97
  def prediction():
98
 
99
  def load_model(model_name):
100
- num_classes = 7 # Sesuaikan dengan jumlah kelas mineral yang digunakan
101
 
102
  if model_name == "DenseNet":
103
  model_path = hf_hub_download(repo_id="ahmadalfian/mineral-classifier",
104
  filename="densenet_finetuned.pth")
 
105
  model = models.densenet121(pretrained=False)
106
  model.classifier = torch.nn.Linear(model.classifier.in_features, num_classes)
107
 
108
  elif model_name == "MobileNet":
109
  model_path = hf_hub_download(repo_id="ahmadalfian/mineral-classifier",
110
  filename="mobileNetV2_finetuned.pth")
 
111
  model = models.mobilenet_v2(pretrained=False)
112
 
113
  # Muat state_dict, tetapi abaikan classifier lama
@@ -124,6 +125,7 @@ def prediction():
124
  elif model_name == "SqueezeNet":
125
  model_path = hf_hub_download(repo_id="ahmadalfian/mineral-classifier",
126
  filename="squeezenet1_finetuned.pth")
 
127
  model = models.squeezenet1_1(pretrained=False)
128
 
129
  # Muat state_dict, tetapi abaikan classifier lama
 
97
  def prediction():
98
 
99
  def load_model(model_name):
 
100
 
101
  if model_name == "DenseNet":
102
  model_path = hf_hub_download(repo_id="ahmadalfian/mineral-classifier",
103
  filename="densenet_finetuned.pth")
104
+ num_classes = 7
105
  model = models.densenet121(pretrained=False)
106
  model.classifier = torch.nn.Linear(model.classifier.in_features, num_classes)
107
 
108
  elif model_name == "MobileNet":
109
  model_path = hf_hub_download(repo_id="ahmadalfian/mineral-classifier",
110
  filename="mobileNetV2_finetuned.pth")
111
+ num_classes = 7
112
  model = models.mobilenet_v2(pretrained=False)
113
 
114
  # Muat state_dict, tetapi abaikan classifier lama
 
125
  elif model_name == "SqueezeNet":
126
  model_path = hf_hub_download(repo_id="ahmadalfian/mineral-classifier",
127
  filename="squeezenet1_finetuned.pth")
128
+ num_classes = 7
129
  model = models.squeezenet1_1(pretrained=False)
130
 
131
  # Muat state_dict, tetapi abaikan classifier lama