Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -97,24 +97,23 @@ def model_description():
|
|
97 |
def prediction():
|
98 |
|
99 |
def load_model(model_name):
|
|
|
|
|
100 |
if model_name == "DenseNet":
|
101 |
model_path = hf_hub_download(repo_id="ahmadalfian/mineral-classifier",
|
102 |
-
|
103 |
-
num_classes = 7
|
104 |
model = models.densenet121(pretrained=False)
|
105 |
model.classifier = torch.nn.Linear(model.classifier.in_features, num_classes)
|
106 |
-
|
107 |
elif model_name == "MobileNet":
|
108 |
model_path = hf_hub_download(repo_id="ahmadalfian/mineral-classifier",
|
109 |
-
|
110 |
-
num_classes = 7
|
111 |
model = models.mobilenet_v2(pretrained=False)
|
112 |
-
model.classifier = torch.nn.Linear(model.classifier[
|
113 |
-
|
114 |
elif model_name == "SqueezeNet":
|
115 |
model_path = hf_hub_download(repo_id="ahmadalfian/mineral-classifier",
|
116 |
-
|
117 |
-
num_classes = 7
|
118 |
model = models.squeezenet1_1(pretrained=False)
|
119 |
model.classifier = torch.nn.Sequential(
|
120 |
torch.nn.Dropout(p=0.5),
|
@@ -122,47 +121,48 @@ def prediction():
|
|
122 |
torch.nn.ReLU(),
|
123 |
torch.nn.AdaptiveAvgPool2d((1, 1))
|
124 |
)
|
125 |
-
|
126 |
else:
|
127 |
raise ValueError("Model not supported.")
|
128 |
-
|
129 |
-
|
130 |
model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))
|
131 |
model.eval()
|
132 |
|
133 |
return model
|
134 |
-
|
|
|
135 |
def process_image(image):
|
|
|
136 |
if image.mode == 'RGBA':
|
137 |
image = image.convert('RGB')
|
138 |
-
|
139 |
-
|
140 |
preprocess = transforms.Compose([
|
141 |
transforms.Resize((224, 224)),
|
142 |
transforms.ToTensor(),
|
143 |
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
|
144 |
])
|
145 |
-
|
146 |
img_tensor = preprocess(image)
|
147 |
-
|
148 |
-
|
149 |
-
|
150 |
def classify_image(model, image):
|
151 |
-
|
152 |
-
|
153 |
-
|
154 |
-
|
|
|
155 |
model.eval()
|
156 |
-
|
157 |
with torch.no_grad():
|
158 |
outputs = model(img_tensor)
|
159 |
-
|
160 |
-
|
161 |
-
|
162 |
probabilities = torch.nn.functional.softmax(outputs, dim=1)
|
163 |
-
|
|
|
164 |
confidence, predicted = torch.max(probabilities, 1)
|
165 |
-
|
166 |
return predicted.item(), confidence.item()
|
167 |
|
168 |
|
|
|
97 |
def prediction():
|
98 |
|
99 |
def load_model(model_name):
|
100 |
+
num_classes = 7 # Pastikan sesuai dengan jumlah kelas yang digunakan saat training
|
101 |
+
|
102 |
if model_name == "DenseNet":
|
103 |
model_path = hf_hub_download(repo_id="ahmadalfian/mineral-classifier",
|
104 |
+
filename="densenet_finetuned.pth")
|
|
|
105 |
model = models.densenet121(pretrained=False)
|
106 |
model.classifier = torch.nn.Linear(model.classifier.in_features, num_classes)
|
107 |
+
|
108 |
elif model_name == "MobileNet":
|
109 |
model_path = hf_hub_download(repo_id="ahmadalfian/mineral-classifier",
|
110 |
+
filename="mobileNetV2_finetuned.pth")
|
|
|
111 |
model = models.mobilenet_v2(pretrained=False)
|
112 |
+
model.classifier = torch.nn.Linear(model.classifier[0].in_features, num_classes) # Fix in_features
|
113 |
+
|
114 |
elif model_name == "SqueezeNet":
|
115 |
model_path = hf_hub_download(repo_id="ahmadalfian/mineral-classifier",
|
116 |
+
filename="squeezenet1_finetuned.pth")
|
|
|
117 |
model = models.squeezenet1_1(pretrained=False)
|
118 |
model.classifier = torch.nn.Sequential(
|
119 |
torch.nn.Dropout(p=0.5),
|
|
|
121 |
torch.nn.ReLU(),
|
122 |
torch.nn.AdaptiveAvgPool2d((1, 1))
|
123 |
)
|
|
|
124 |
else:
|
125 |
raise ValueError("Model not supported.")
|
126 |
+
|
127 |
+
# Load model weights
|
128 |
model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))
|
129 |
model.eval()
|
130 |
|
131 |
return model
|
132 |
+
|
133 |
+
|
134 |
def process_image(image):
|
135 |
+
"""Konversi gambar dan lakukan preprocessing sebelum masuk ke model"""
|
136 |
if image.mode == 'RGBA':
|
137 |
image = image.convert('RGB')
|
138 |
+
|
|
|
139 |
preprocess = transforms.Compose([
|
140 |
transforms.Resize((224, 224)),
|
141 |
transforms.ToTensor(),
|
142 |
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
|
143 |
])
|
144 |
+
|
145 |
img_tensor = preprocess(image)
|
146 |
+
return img_tensor.unsqueeze(0) # Tambahkan dimensi batch
|
147 |
+
|
148 |
+
|
149 |
def classify_image(model, image):
|
150 |
+
"""Lakukan prediksi menggunakan model"""
|
151 |
+
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
152 |
+
img_tensor = process_image(image).to(device) # Pastikan berada di perangkat yang sesuai
|
153 |
+
|
154 |
+
model.to(device)
|
155 |
model.eval()
|
156 |
+
|
157 |
with torch.no_grad():
|
158 |
outputs = model(img_tensor)
|
159 |
+
|
160 |
+
# Konversi hasil ke probabilitas
|
|
|
161 |
probabilities = torch.nn.functional.softmax(outputs, dim=1)
|
162 |
+
|
163 |
+
# Ambil prediksi dengan confidence tertinggi
|
164 |
confidence, predicted = torch.max(probabilities, 1)
|
165 |
+
|
166 |
return predicted.item(), confidence.item()
|
167 |
|
168 |
|