Racso777 commited on
Commit
d998353
·
1 Parent(s): 719e422

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +87 -0
app.py ADDED
@@ -0,0 +1,87 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import json
3
+ import torch
4
+ from PIL import Image
5
+ from torchvision import transforms
6
+ import matplotlib.pyplot as plt
7
+ import gradio as gr
8
+ from io import BytesIO
9
+
10
+ from vit_model import vit_base_patch16_224_in21k as create_model
11
+
12
+ def classify_image(img):
13
+ # Your existing code here, modified to use `img_path` as input
14
+ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
15
+
16
+ data_transform = transforms.Compose(
17
+ [transforms.Resize(256),
18
+ transforms.CenterCrop(224),
19
+ transforms.ToTensor(),
20
+ transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])
21
+
22
+ # [N, C, H, W]
23
+ img = data_transform(img)
24
+ # expand batch dimension
25
+ img = torch.unsqueeze(img, dim=0)
26
+
27
+ # read class_indict
28
+ json_path = 'F:\mushroom_project\VIT\class_indices.json'
29
+ assert os.path.exists(json_path), "file: '{}' dose not exist.".format(json_path)
30
+
31
+ with open(json_path, "r") as f:
32
+ class_indict = json.load(f)
33
+
34
+ # create model
35
+ model = create_model(num_classes=370, has_logits=False).to(device)
36
+ # load model weights
37
+ model_weight_path = "F:\mushroom_project\VIT\pretrain_30_weights\\best_model.pth"
38
+ #load no pretrain model path
39
+ #model_weight_path = "F:\mushroom_project\VIT\no_pretrain_weights\best_model.pth"
40
+ model.load_state_dict(torch.load(model_weight_path, map_location=device))
41
+ model.eval()
42
+ with torch.no_grad():
43
+ # predict class
44
+ output = torch.squeeze(model(img.to(device))).cpu()
45
+ predict = torch.softmax(output, dim=0)
46
+ predict_cla = torch.argmax(predict).numpy()
47
+
48
+ print_res = "class: {} prob: {:.3}".format(class_indict[str(predict_cla)],
49
+ predict[predict_cla].numpy())
50
+
51
+ # Combine the two lists into a list of tuples
52
+ combined_list = list(zip(class_indict, predict))
53
+
54
+ # Sort the combined list by the 'predict' values in descending order
55
+ sorted_combined_list = sorted(combined_list, key=lambda x: x[1], reverse=True)
56
+
57
+ # Determine the position you are currently interested in
58
+ current_position = 5 # Example position
59
+
60
+ # Get the previous five elements from the sorted list
61
+ # Ensure that the index does not go below zero
62
+ start_index = max(current_position - 5, 0)
63
+ previous_five = sorted_combined_list[start_index:current_position]
64
+
65
+ joined_string = ""
66
+ for i in previous_five:
67
+ #print("class: {:10} prob: {:.3}".format(class_indict[str(i[0])], i[1].numpy()))
68
+ joined_string += ("class: {:10} prob: {:.3}".format(class_indict[str(i[0])], i[1].numpy())) + "\n"
69
+
70
+ #print(joined_string)
71
+ plt.title(joined_string)
72
+ plt.tight_layout()
73
+ fig = plt.figure()
74
+ return joined_string
75
+
76
+ # Create a Gradio interface
77
+ iface = gr.Interface(
78
+ fn=classify_image,
79
+ inputs=gr.Image(type='pil'),
80
+ outputs=gr.Textbox(),
81
+ title="Mushrrom Image Classification",
82
+ description="Upload a mushroom image to classify."
83
+ )
84
+
85
+ # Run the Gradio app
86
+ #if __name__ == '__main__':
87
+ iface.launch()