heuue commited on
Commit
8658411
·
verified ·
1 Parent(s): b8cf0af

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +12 -9
app.py CHANGED
@@ -3,15 +3,15 @@ import torch
3
  from torchvision import transforms
4
  from PIL import Image
5
  from timm import create_model
 
 
 
 
6
 
7
- # 定义类别名称
8
- CLASSES = [
9
- "Class 1", "Class 2", "Class 3", # 替换为你的类别名称
10
- ]
11
 
12
  # 加载模型
13
  def load_model(model_path):
14
- model = create_model('resnet18', pretrained=False, num_classes=len(CLASSES))
15
  model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))
16
  model.eval()
17
  return model
@@ -32,9 +32,12 @@ def classify_image(image):
32
  image = preprocess_image(image)
33
  with torch.no_grad():
34
  outputs = model(image)
35
- probabilities = torch.nn.functional.softmax(outputs[0], dim=0)
36
- confidences = {CLASSES[i]: float(probabilities[i]) for i in range(len(CLASSES))}
37
- return confidences
 
 
 
38
 
39
  # 创建 Gradio 接口
40
  title = "Bird Species Classifier"
@@ -43,7 +46,7 @@ description = "Upload an image of a bird, and the model will predict its species
43
  interface = gr.Interface(
44
  fn=classify_image,
45
  inputs=gr.Image(type="pil"),
46
- outputs=gr.Label(num_top_classes=3),
47
  title=title,
48
  description=description,
49
  )
 
3
  from torchvision import transforms
4
  from PIL import Image
5
  from timm import create_model
6
+ import json
7
+
8
+ with open('class_mapping.json', 'r') as json_file:
9
+ class_mapping = json.load(json_file)
10
 
 
 
 
 
11
 
12
  # 加载模型
13
  def load_model(model_path):
14
+ model = create_model('resnet18', pretrained=False, num_classes=len(class_mapping))
15
  model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))
16
  model.eval()
17
  return model
 
32
  image = preprocess_image(image)
33
  with torch.no_grad():
34
  outputs = model(image)
35
+
36
+ _, predicted_class = torch.max(outputs, 1)
37
+ predicted_class_idx = predicted_class.item()
38
+ predicted_class_name = class_mapping[str(predicted_class_idx)]
39
+
40
+ return predicted_class_name
41
 
42
  # 创建 Gradio 接口
43
  title = "Bird Species Classifier"
 
46
  interface = gr.Interface(
47
  fn=classify_image,
48
  inputs=gr.Image(type="pil"),
49
+ outputs="text",
50
  title=title,
51
  description=description,
52
  )