ahmadalfian commited on
Commit
acd6bba
·
verified ·
1 Parent(s): d04d540

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +48 -48
app.py CHANGED
@@ -98,54 +98,54 @@ def prediction():
98
 
99
  def load_model(model_name):
100
 
101
- if model_name == "DenseNet":
102
- model_path = hf_hub_download(repo_id="ahmadalfian/mineral-classifier",
103
- filename="densenet_finetuned.pth")
104
- num_classes = 7
105
- model = models.densenet121(pretrained=False)
106
- model.classifier = torch.nn.Linear(model.classifier.in_features, num_classes)
107
-
108
- elif model_name == "MobileNet":
109
- model_path = hf_hub_download(repo_id="ahmadalfian/mineral-classifier",
110
- filename="mobileNetV2_finetuned.pth")
111
- num_classes = 7
112
- model = models.mobilenet_v2(pretrained=False)
113
-
114
- # Muat state_dict, tetapi abaikan classifier lama
115
- state_dict = torch.load(model_path, map_location=torch.device('cpu'))
116
- new_state_dict = {k: v for k, v in state_dict.items() if "classifier.1" not in k} # Hilangkan classifier lama
117
- model.load_state_dict(new_state_dict, strict=False)
118
-
119
- # Definisikan ulang classifier dengan jumlah kelas yang benar
120
- model.classifier = torch.nn.Sequential(
121
- torch.nn.Dropout(0.2),
122
- torch.nn.Linear(model.last_channel, num_classes)
123
- )
124
-
125
- elif model_name == "SqueezeNet":
126
- model_path = hf_hub_download(repo_id="ahmadalfian/mineral-classifier",
127
- filename="squeezenet1_finetuned.pth")
128
- num_classes = 7
129
- model = models.squeezenet1_1(pretrained=False)
130
-
131
- # Muat state_dict, tetapi abaikan classifier lama
132
- state_dict = torch.load(model_path, map_location=torch.device('cpu'))
133
- new_state_dict = {k: v for k, v in state_dict.items() if "classifier.1" not in k} # Hilangkan classifier lama
134
- model.load_state_dict(new_state_dict, strict=False)
135
-
136
- # Definisikan ulang classifier dengan jumlah kelas yang sesuai
137
- model.classifier = torch.nn.Sequential(
138
- torch.nn.Dropout(p=0.5),
139
- torch.nn.Conv2d(512, num_classes, kernel_size=(1, 1), stride=(1, 1)),
140
- torch.nn.ReLU(),
141
- torch.nn.AdaptiveAvgPool2d((1, 1))
142
- )
143
-
144
- else:
145
- raise ValueError("Model not supported.")
146
-
147
- model.eval()
148
- return model
149
 
150
 
151
 
 
98
 
99
  def load_model(model_name):
100
 
101
+ if model_name == "DenseNet":
102
+ model_path = hf_hub_download(repo_id="ahmadalfian/mineral-classifier",
103
+ filename="densenet_finetuned.pth")
104
+ num_classes = 7
105
+ model = models.densenet121(pretrained=False)
106
+ model.classifier = torch.nn.Linear(model.classifier.in_features, num_classes)
107
+
108
+ elif model_name == "MobileNet":
109
+ model_path = hf_hub_download(repo_id="ahmadalfian/mineral-classifier",
110
+ filename="mobileNetV2_finetuned.pth")
111
+ num_classes = 7
112
+ model = models.mobilenet_v2(pretrained=False)
113
+
114
+ # Muat state_dict, tetapi abaikan classifier lama
115
+ state_dict = torch.load(model_path, map_location=torch.device('cpu'))
116
+ new_state_dict = {k: v for k, v in state_dict.items() if "classifier.1" not in k} # Hilangkan classifier lama
117
+ model.load_state_dict(new_state_dict, strict=False)
118
+
119
+ # Definisikan ulang classifier dengan jumlah kelas yang benar
120
+ model.classifier = torch.nn.Sequential(
121
+ torch.nn.Dropout(0.2),
122
+ torch.nn.Linear(model.last_channel, num_classes)
123
+ )
124
+
125
+ elif model_name == "SqueezeNet":
126
+ model_path = hf_hub_download(repo_id="ahmadalfian/mineral-classifier",
127
+ filename="squeezenet1_finetuned.pth")
128
+ num_classes = 7
129
+ model = models.squeezenet1_1(pretrained=False)
130
+
131
+ # Muat state_dict, tetapi abaikan classifier lama
132
+ state_dict = torch.load(model_path, map_location=torch.device('cpu'))
133
+ new_state_dict = {k: v for k, v in state_dict.items() if "classifier.1" not in k} # Hilangkan classifier lama
134
+ model.load_state_dict(new_state_dict, strict=False)
135
+
136
+ # Definisikan ulang classifier dengan jumlah kelas yang sesuai
137
+ model.classifier = torch.nn.Sequential(
138
+ torch.nn.Dropout(p=0.5),
139
+ torch.nn.Conv2d(512, num_classes, kernel_size=(1, 1), stride=(1, 1)),
140
+ torch.nn.ReLU(),
141
+ torch.nn.AdaptiveAvgPool2d((1, 1))
142
+ )
143
+
144
+ else:
145
+ raise ValueError("Model not supported.")
146
+
147
+ model.eval()
148
+ return model
149
 
150
 
151