Update app.py
Browse files
app.py
CHANGED
@@ -25,7 +25,7 @@ def classify_image(img):
|
|
25 |
img = torch.unsqueeze(img, dim=0)
|
26 |
|
27 |
# read class_indict
|
28 |
-
json_path = '
|
29 |
assert os.path.exists(json_path), "file: '{}' dose not exist.".format(json_path)
|
30 |
|
31 |
with open(json_path, "r") as f:
|
@@ -34,9 +34,8 @@ def classify_image(img):
|
|
34 |
# create model
|
35 |
model = create_model(num_classes=370, has_logits=False).to(device)
|
36 |
# load model weights
|
37 |
-
model_weight_path = "
|
38 |
-
|
39 |
-
#model_weight_path = "F:\mushroom_project\VIT\no_pretrain_weights\best_model.pth"
|
40 |
model.load_state_dict(torch.load(model_weight_path, map_location=device))
|
41 |
model.eval()
|
42 |
with torch.no_grad():
|
|
|
25 |
img = torch.unsqueeze(img, dim=0)
|
26 |
|
27 |
# read class_indict
|
28 |
+
json_path = './class_indices.json'
|
29 |
assert os.path.exists(json_path), "file: '{}' dose not exist.".format(json_path)
|
30 |
|
31 |
with open(json_path, "r") as f:
|
|
|
34 |
# create model
|
35 |
model = create_model(num_classes=370, has_logits=False).to(device)
|
36 |
# load model weights
|
37 |
+
model_weight_path = "./best_model.pth"
|
38 |
+
|
|
|
39 |
model.load_state_dict(torch.load(model_weight_path, map_location=device))
|
40 |
model.eval()
|
41 |
with torch.no_grad():
|