heuue commited on
Commit
b816225
·
verified ·
1 Parent(s): dd66d52

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +52 -0
app.py ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ from torchvision import transforms
4
+ from PIL import Image
5
+ from timm import create_model
6
+
7
+ # 定义类别名称
8
+ CLASSES = [
9
+ "Class 1", "Class 2", "Class 3", # 替换为你的类别名称
10
+ ]
11
+
12
+ # 加载模型
13
+ def load_model(model_path):
14
+ model = create_model('resnet18', pretrained=False, num_classes=len(CLASSES))
15
+ model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))
16
+ model.eval()
17
+ return model
18
+
19
+ model = load_model("model.pth")
20
+
21
+ # 定义图像预处理
22
+ def preprocess_image(image):
23
+ transform = transforms.Compose([
24
+ transforms.Resize((224, 224)),
25
+ transforms.ToTensor(),
26
+ transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
27
+ ])
28
+ return transform(image).unsqueeze(0)
29
+
30
+ # 定义推理函数
31
+ def classify_image(image):
32
+ image = preprocess_image(image)
33
+ with torch.no_grad():
34
+ outputs = model(image)
35
+ probabilities = torch.nn.functional.softmax(outputs[0], dim=0)
36
+ confidences = {CLASSES[i]: float(probabilities[i]) for i in range(len(CLASSES))}
37
+ return confidences
38
+
39
+ # 创建 Gradio 接口
40
+ title = "Bird Species Classifier"
41
+ description = "Upload an image of a bird, and the model will predict its species."
42
+
43
+ interface = gr.Interface(
44
+ fn=classify_image,
45
+ inputs=gr.Image(type="pil"),
46
+ outputs=gr.Label(num_top_classes=3),
47
+ title=title,
48
+ description=description,
49
+ )
50
+
51
+ if __name__ == "__main__":
52
+ interface.launch()