ahmadalfian commited on
Commit
c361455
·
verified ·
1 Parent(s): 2a956ed

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +30 -20
app.py CHANGED
@@ -97,8 +97,8 @@ def model_description():
97
  def prediction():
98
 
99
  def load_model(model_name):
100
- num_classes = 7 # Pastikan sesuai dengan jumlah kelas yang digunakan saat training
101
-
102
  if model_name == "DenseNet":
103
  model_path = hf_hub_download(repo_id="ahmadalfian/mineral-classifier",
104
  filename="densenet_finetuned.pth")
@@ -109,7 +109,7 @@ def prediction():
109
  model_path = hf_hub_download(repo_id="ahmadalfian/mineral-classifier",
110
  filename="mobileNetV2_finetuned.pth")
111
  model = models.mobilenet_v2(pretrained=False)
112
- model.classifier = torch.nn.Linear(model.classifier[0].in_features, num_classes) # Fix in_features
113
 
114
  elif model_name == "SqueezeNet":
115
  model_path = hf_hub_download(repo_id="ahmadalfian/mineral-classifier",
@@ -164,35 +164,45 @@ def prediction():
164
  confidence, predicted = torch.max(probabilities, 1)
165
 
166
  return predicted.item(), confidence.item()
167
-
168
-
169
- st.markdown("## 🍎 Fruit and Vegetable Classifier")
170
- st.markdown("Upload an image of a fruit or vegetable, and I will classify it for you!")
171
-
 
 
172
  st.sidebar.header("Model Settings")
173
  model_name = st.sidebar.selectbox("Select Model", ("DenseNet", "SqueezeNet", "MobileNet"))
174
-
175
- confidence_threshold = st.sidebar.number_input("Set Confidence Threshold", min_value=0.0, max_value=1.0, value=0.5,
176
- step=0.01)
177
-
178
  uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
179
-
180
  predict_button = st.sidebar.button("Predict")
181
-
 
 
 
182
  if uploaded_file is not None and predict_button:
183
  with st.spinner("Processing the image..."):
184
  image = Image.open(uploaded_file)
185
-
186
  model = load_model(model_name)
187
- label, confidence = classify_image(model, image)
188
-
 
 
 
 
 
 
189
  col1, col2 = st.columns([1, 2])
190
-
191
  with col1:
192
  st.image(image, caption='Uploaded Image', width=250)
193
-
194
  with col2:
195
- st.write(f"**Prediction**: {predicted.capitalize()}")
196
  st.write(f"**Confidence**: {confidence:.4f}")
197
 
198
  def contact():
 
97
  def prediction():
98
 
99
  def load_model(model_name):
100
+ num_classes = 7 # Sesuaikan dengan jumlah kelas mineral
101
+
102
  if model_name == "DenseNet":
103
  model_path = hf_hub_download(repo_id="ahmadalfian/mineral-classifier",
104
  filename="densenet_finetuned.pth")
 
109
  model_path = hf_hub_download(repo_id="ahmadalfian/mineral-classifier",
110
  filename="mobileNetV2_finetuned.pth")
111
  model = models.mobilenet_v2(pretrained=False)
112
+ model.classifier = torch.nn.Linear(model.last_channel, num_classes) # Perbaikan classifier
113
 
114
  elif model_name == "SqueezeNet":
115
  model_path = hf_hub_download(repo_id="ahmadalfian/mineral-classifier",
 
164
  confidence, predicted = torch.max(probabilities, 1)
165
 
166
  return predicted.item(), confidence.item()
167
+
168
+
169
+ # ====== Streamlit UI ======
170
+
171
+ st.markdown("## 🔍 Mineral Classifier")
172
+ st.markdown("Upload an image of a mineral, and I will classify it for you!")
173
+
174
  st.sidebar.header("Model Settings")
175
  model_name = st.sidebar.selectbox("Select Model", ("DenseNet", "SqueezeNet", "MobileNet"))
176
+
177
+ confidence_threshold = st.sidebar.number_input("Set Confidence Threshold", min_value=0.0, max_value=1.0, value=0.5, step=0.01)
178
+
 
179
  uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
180
+
181
  predict_button = st.sidebar.button("Predict")
182
+
183
+ # Daftar label kelas mineral
184
+ class_labels = ["quartz", "pyrite", "muscovite", "malachite", "chrysocolla", "bornite", "biotite"]
185
+
186
  if uploaded_file is not None and predict_button:
187
  with st.spinner("Processing the image..."):
188
  image = Image.open(uploaded_file)
189
+
190
  model = load_model(model_name)
191
+ predicted_index, confidence = classify_image(model, image)
192
+
193
+ # Konversi index prediksi ke nama mineral
194
+ if 0 <= predicted_index < len(class_labels):
195
+ predicted_label = class_labels[predicted_index]
196
+ else:
197
+ predicted_label = "Unknown"
198
+
199
  col1, col2 = st.columns([1, 2])
200
+
201
  with col1:
202
  st.image(image, caption='Uploaded Image', width=250)
203
+
204
  with col2:
205
+ st.write(f"**Prediction**: {predicted_label.capitalize()}") # Perbaiki penggunaan capitalize()
206
  st.write(f"**Confidence**: {confidence:.4f}")
207
 
208
  def contact():