Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -97,8 +97,8 @@ def model_description():
|
|
97 |
def prediction():
|
98 |
|
99 |
def load_model(model_name):
|
100 |
-
num_classes = 7 #
|
101 |
-
|
102 |
if model_name == "DenseNet":
|
103 |
model_path = hf_hub_download(repo_id="ahmadalfian/mineral-classifier",
|
104 |
filename="densenet_finetuned.pth")
|
@@ -109,7 +109,7 @@ def prediction():
|
|
109 |
model_path = hf_hub_download(repo_id="ahmadalfian/mineral-classifier",
|
110 |
filename="mobileNetV2_finetuned.pth")
|
111 |
model = models.mobilenet_v2(pretrained=False)
|
112 |
-
model.classifier = torch.nn.Linear(model.
|
113 |
|
114 |
elif model_name == "SqueezeNet":
|
115 |
model_path = hf_hub_download(repo_id="ahmadalfian/mineral-classifier",
|
@@ -164,35 +164,45 @@ def prediction():
|
|
164 |
confidence, predicted = torch.max(probabilities, 1)
|
165 |
|
166 |
return predicted.item(), confidence.item()
|
167 |
-
|
168 |
-
|
169 |
-
|
170 |
-
|
171 |
-
|
|
|
|
|
172 |
st.sidebar.header("Model Settings")
|
173 |
model_name = st.sidebar.selectbox("Select Model", ("DenseNet", "SqueezeNet", "MobileNet"))
|
174 |
-
|
175 |
-
confidence_threshold = st.sidebar.number_input("Set Confidence Threshold", min_value=0.0, max_value=1.0, value=0.5,
|
176 |
-
|
177 |
-
|
178 |
uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
|
179 |
-
|
180 |
predict_button = st.sidebar.button("Predict")
|
181 |
-
|
|
|
|
|
|
|
182 |
if uploaded_file is not None and predict_button:
|
183 |
with st.spinner("Processing the image..."):
|
184 |
image = Image.open(uploaded_file)
|
185 |
-
|
186 |
model = load_model(model_name)
|
187 |
-
|
188 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
189 |
col1, col2 = st.columns([1, 2])
|
190 |
-
|
191 |
with col1:
|
192 |
st.image(image, caption='Uploaded Image', width=250)
|
193 |
-
|
194 |
with col2:
|
195 |
-
st.write(f"**Prediction**: {
|
196 |
st.write(f"**Confidence**: {confidence:.4f}")
|
197 |
|
198 |
def contact():
|
|
|
97 |
def prediction():
|
98 |
|
99 |
def load_model(model_name):
|
100 |
+
num_classes = 7 # Sesuaikan dengan jumlah kelas mineral
|
101 |
+
|
102 |
if model_name == "DenseNet":
|
103 |
model_path = hf_hub_download(repo_id="ahmadalfian/mineral-classifier",
|
104 |
filename="densenet_finetuned.pth")
|
|
|
109 |
model_path = hf_hub_download(repo_id="ahmadalfian/mineral-classifier",
|
110 |
filename="mobileNetV2_finetuned.pth")
|
111 |
model = models.mobilenet_v2(pretrained=False)
|
112 |
+
model.classifier = torch.nn.Linear(model.last_channel, num_classes) # Perbaikan classifier
|
113 |
|
114 |
elif model_name == "SqueezeNet":
|
115 |
model_path = hf_hub_download(repo_id="ahmadalfian/mineral-classifier",
|
|
|
164 |
confidence, predicted = torch.max(probabilities, 1)
|
165 |
|
166 |
return predicted.item(), confidence.item()
|
167 |
+
|
168 |
+
|
169 |
+
# ====== Streamlit UI ======
|
170 |
+
|
171 |
+
st.markdown("## 🔍 Mineral Classifier")
|
172 |
+
st.markdown("Upload an image of a mineral, and I will classify it for you!")
|
173 |
+
|
174 |
st.sidebar.header("Model Settings")
|
175 |
model_name = st.sidebar.selectbox("Select Model", ("DenseNet", "SqueezeNet", "MobileNet"))
|
176 |
+
|
177 |
+
confidence_threshold = st.sidebar.number_input("Set Confidence Threshold", min_value=0.0, max_value=1.0, value=0.5, step=0.01)
|
178 |
+
|
|
|
179 |
uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
|
180 |
+
|
181 |
predict_button = st.sidebar.button("Predict")
|
182 |
+
|
183 |
+
# Daftar label kelas mineral
|
184 |
+
class_labels = ["quartz", "pyrite", "muscovite", "malachite", "chrysocolla", "bornite", "biotite"]
|
185 |
+
|
186 |
if uploaded_file is not None and predict_button:
|
187 |
with st.spinner("Processing the image..."):
|
188 |
image = Image.open(uploaded_file)
|
189 |
+
|
190 |
model = load_model(model_name)
|
191 |
+
predicted_index, confidence = classify_image(model, image)
|
192 |
+
|
193 |
+
# Konversi index prediksi ke nama mineral
|
194 |
+
if 0 <= predicted_index < len(class_labels):
|
195 |
+
predicted_label = class_labels[predicted_index]
|
196 |
+
else:
|
197 |
+
predicted_label = "Unknown"
|
198 |
+
|
199 |
col1, col2 = st.columns([1, 2])
|
200 |
+
|
201 |
with col1:
|
202 |
st.image(image, caption='Uploaded Image', width=250)
|
203 |
+
|
204 |
with col2:
|
205 |
+
st.write(f"**Prediction**: {predicted_label.capitalize()}") # Perbaiki penggunaan capitalize()
|
206 |
st.write(f"**Confidence**: {confidence:.4f}")
|
207 |
|
208 |
def contact():
|