ahmadalfian commited on
Commit
2a956ed
·
verified ·
1 Parent(s): edfe582

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +29 -29
app.py CHANGED
@@ -97,24 +97,23 @@ def model_description():
97
  def prediction():
98
 
99
  def load_model(model_name):
 
 
100
  if model_name == "DenseNet":
101
  model_path = hf_hub_download(repo_id="ahmadalfian/mineral-classifier",
102
- filename="densenet_finetuned.pth")
103
- num_classes = 7
104
  model = models.densenet121(pretrained=False)
105
  model.classifier = torch.nn.Linear(model.classifier.in_features, num_classes)
106
-
107
  elif model_name == "MobileNet":
108
  model_path = hf_hub_download(repo_id="ahmadalfian/mineral-classifier",
109
- filename="mobileNetV2_finetuned.pth")
110
- num_classes = 7
111
  model = models.mobilenet_v2(pretrained=False)
112
- model.classifier = torch.nn.Linear(model.classifier[1].in_features, num_classes)
113
-
114
  elif model_name == "SqueezeNet":
115
  model_path = hf_hub_download(repo_id="ahmadalfian/mineral-classifier",
116
- filename="squeezenet1_finetuned.pth")
117
- num_classes = 7
118
  model = models.squeezenet1_1(pretrained=False)
119
  model.classifier = torch.nn.Sequential(
120
  torch.nn.Dropout(p=0.5),
@@ -122,47 +121,48 @@ def prediction():
122
  torch.nn.ReLU(),
123
  torch.nn.AdaptiveAvgPool2d((1, 1))
124
  )
125
-
126
  else:
127
  raise ValueError("Model not supported.")
128
-
129
- # Load model weights
130
  model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))
131
  model.eval()
132
 
133
  return model
134
-
 
135
  def process_image(image):
 
136
  if image.mode == 'RGBA':
137
  image = image.convert('RGB')
138
- print("Image converted from RGBA to RGB.")
139
-
140
  preprocess = transforms.Compose([
141
  transforms.Resize((224, 224)),
142
  transforms.ToTensor(),
143
  transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
144
  ])
145
-
146
  img_tensor = preprocess(image)
147
- print(f"Image tensor shape: {img_tensor.shape}")
148
- return img_tensor.unsqueeze(0)
149
-
150
  def classify_image(model, image):
151
- img_tensor = process_image(image)
152
-
153
- print(f"Image tensor shape after unsqueeze: {img_tensor.shape}")
154
-
 
155
  model.eval()
156
-
157
  with torch.no_grad():
158
  outputs = model(img_tensor)
159
-
160
- print(f"Model output shape: {outputs.shape}")
161
-
162
  probabilities = torch.nn.functional.softmax(outputs, dim=1)
163
-
 
164
  confidence, predicted = torch.max(probabilities, 1)
165
-
166
  return predicted.item(), confidence.item()
167
 
168
 
 
97
  def prediction():
98
 
99
  def load_model(model_name):
100
+ num_classes = 7 # Pastikan sesuai dengan jumlah kelas yang digunakan saat training
101
+
102
  if model_name == "DenseNet":
103
  model_path = hf_hub_download(repo_id="ahmadalfian/mineral-classifier",
104
+ filename="densenet_finetuned.pth")
 
105
  model = models.densenet121(pretrained=False)
106
  model.classifier = torch.nn.Linear(model.classifier.in_features, num_classes)
107
+
108
  elif model_name == "MobileNet":
109
  model_path = hf_hub_download(repo_id="ahmadalfian/mineral-classifier",
110
+ filename="mobileNetV2_finetuned.pth")
 
111
  model = models.mobilenet_v2(pretrained=False)
112
+ model.classifier = torch.nn.Linear(model.classifier[0].in_features, num_classes) # Fix in_features
113
+
114
  elif model_name == "SqueezeNet":
115
  model_path = hf_hub_download(repo_id="ahmadalfian/mineral-classifier",
116
+ filename="squeezenet1_finetuned.pth")
 
117
  model = models.squeezenet1_1(pretrained=False)
118
  model.classifier = torch.nn.Sequential(
119
  torch.nn.Dropout(p=0.5),
 
121
  torch.nn.ReLU(),
122
  torch.nn.AdaptiveAvgPool2d((1, 1))
123
  )
 
124
  else:
125
  raise ValueError("Model not supported.")
126
+
127
+ # Load model weights
128
  model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))
129
  model.eval()
130
 
131
  return model
132
+
133
+
134
  def process_image(image):
135
+ """Konversi gambar dan lakukan preprocessing sebelum masuk ke model"""
136
  if image.mode == 'RGBA':
137
  image = image.convert('RGB')
138
+
 
139
  preprocess = transforms.Compose([
140
  transforms.Resize((224, 224)),
141
  transforms.ToTensor(),
142
  transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
143
  ])
144
+
145
  img_tensor = preprocess(image)
146
+ return img_tensor.unsqueeze(0) # Tambahkan dimensi batch
147
+
148
+
149
  def classify_image(model, image):
150
+ """Lakukan prediksi menggunakan model"""
151
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
152
+ img_tensor = process_image(image).to(device) # Pastikan berada di perangkat yang sesuai
153
+
154
+ model.to(device)
155
  model.eval()
156
+
157
  with torch.no_grad():
158
  outputs = model(img_tensor)
159
+
160
+ # Konversi hasil ke probabilitas
 
161
  probabilities = torch.nn.functional.softmax(outputs, dim=1)
162
+
163
+ # Ambil prediksi dengan confidence tertinggi
164
  confidence, predicted = torch.max(probabilities, 1)
165
+
166
  return predicted.item(), confidence.item()
167
 
168