ahmadalfian commited on
Commit
8539ecf
·
verified ·
1 Parent(s): c361455

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +48 -32
app.py CHANGED
@@ -97,38 +97,54 @@ def model_description():
97
  def prediction():
98
 
99
  def load_model(model_name):
100
- num_classes = 7 # Sesuaikan dengan jumlah kelas mineral
101
-
102
- if model_name == "DenseNet":
103
- model_path = hf_hub_download(repo_id="ahmadalfian/mineral-classifier",
104
- filename="densenet_finetuned.pth")
105
- model = models.densenet121(pretrained=False)
106
- model.classifier = torch.nn.Linear(model.classifier.in_features, num_classes)
107
-
108
- elif model_name == "MobileNet":
109
- model_path = hf_hub_download(repo_id="ahmadalfian/mineral-classifier",
110
- filename="mobileNetV2_finetuned.pth")
111
- model = models.mobilenet_v2(pretrained=False)
112
- model.classifier = torch.nn.Linear(model.last_channel, num_classes) # Perbaikan classifier
113
-
114
- elif model_name == "SqueezeNet":
115
- model_path = hf_hub_download(repo_id="ahmadalfian/mineral-classifier",
116
- filename="squeezenet1_finetuned.pth")
117
- model = models.squeezenet1_1(pretrained=False)
118
- model.classifier = torch.nn.Sequential(
119
- torch.nn.Dropout(p=0.5),
120
- torch.nn.Conv2d(512, num_classes, kernel_size=(1, 1), stride=(1, 1)),
121
- torch.nn.ReLU(),
122
- torch.nn.AdaptiveAvgPool2d((1, 1))
123
- )
124
- else:
125
- raise ValueError("Model not supported.")
126
-
127
- # Load model weights
128
- model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))
129
- model.eval()
130
-
131
- return model
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
132
 
133
 
134
  def process_image(image):
 
97
  def prediction():
98
 
99
  def load_model(model_name):
100
+ num_classes = 7 # Sesuaikan dengan jumlah kelas mineral yang digunakan
101
+
102
+ if model_name == "DenseNet":
103
+ model_path = hf_hub_download(repo_id="ahmadalfian/mineral-classifier",
104
+ filename="densenet_finetuned.pth")
105
+ model = models.densenet121(pretrained=False)
106
+ model.classifier = torch.nn.Linear(model.classifier.in_features, num_classes)
107
+
108
+ elif model_name == "MobileNet":
109
+ model_path = hf_hub_download(repo_id="ahmadalfian/mineral-classifier",
110
+ filename="mobileNetV2_finetuned.pth")
111
+ model = models.mobilenet_v2(pretrained=False)
112
+
113
+ # Muat state_dict, tetapi abaikan classifier lama
114
+ state_dict = torch.load(model_path, map_location=torch.device('cpu'))
115
+ new_state_dict = {k: v for k, v in state_dict.items() if "classifier.1" not in k} # Hilangkan classifier lama
116
+ model.load_state_dict(new_state_dict, strict=False)
117
+
118
+ # Definisikan ulang classifier dengan jumlah kelas yang benar
119
+ model.classifier = torch.nn.Sequential(
120
+ torch.nn.Dropout(0.2),
121
+ torch.nn.Linear(model.last_channel, num_classes)
122
+ )
123
+
124
+ elif model_name == "SqueezeNet":
125
+ model_path = hf_hub_download(repo_id="ahmadalfian/mineral-classifier",
126
+ filename="squeezenet1_finetuned.pth")
127
+ model = models.squeezenet1_1(pretrained=False)
128
+
129
+ # Muat state_dict, tetapi abaikan classifier lama
130
+ state_dict = torch.load(model_path, map_location=torch.device('cpu'))
131
+ new_state_dict = {k: v for k, v in state_dict.items() if "classifier.1" not in k} # Hilangkan classifier lama
132
+ model.load_state_dict(new_state_dict, strict=False)
133
+
134
+ # Definisikan ulang classifier dengan jumlah kelas yang sesuai
135
+ model.classifier = torch.nn.Sequential(
136
+ torch.nn.Dropout(p=0.5),
137
+ torch.nn.Conv2d(512, num_classes, kernel_size=(1, 1), stride=(1, 1)),
138
+ torch.nn.ReLU(),
139
+ torch.nn.AdaptiveAvgPool2d((1, 1))
140
+ )
141
+
142
+ else:
143
+ raise ValueError("Model not supported.")
144
+
145
+ model.eval()
146
+ return model
147
+
148
 
149
 
150
  def process_image(image):