File size: 2,861 Bytes
5760d9c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
import gradio as gr
import torch
from torchvision import transforms
from PIL import Image
import numpy as np
import cv2
from timm import create_model
from ultralytics import YOLO
import json

# Load the class-index -> class-name mapping used by the classifier.
# JSON text is UTF-8 by specification, so decode it explicitly rather than
# relying on the platform's default encoding.
with open('class_names.json', 'r', encoding='utf-8') as json_file:
    class_mapping = json.load(json_file)

def load_classification_model(model_path):
    """Build the ResNet-18 classifier and load its weights from *model_path*.

    The network is created through timm's ``create_model`` with one output
    per entry in ``class_mapping``; the checkpoint is loaded as a state dict
    mapped onto the CPU.

    Args:
        model_path: Path of the checkpoint passed to ``torch.load``.

    Returns:
        The model with the weights loaded, switched to evaluation mode.
    """
    num_classes = len(class_mapping)
    net = create_model('resnet18', pretrained=False, num_classes=num_classes)

    cpu = torch.device('cpu')
    state_dict = torch.load(model_path, map_location=cpu)
    net.load_state_dict(state_dict)

    net.eval()
    return net

# Module-level classifier used by classify_image(); the weights file is read
# once, when this module is imported.
classification_model = load_classification_model("res18_nabird555_acc596.pth")

# Built once at import time instead of on every call: the same resize /
# to-tensor / normalise steps are applied to every image.
_PREPROCESS_PIPELINE = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    # Standard ImageNet channel mean / std values.
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])


def preprocess_image(image):
    """Turn a PIL image into a normalised, batched input tensor.

    Args:
        image: PIL image in any mode. Non-RGB images (e.g. RGBA, L, P) are
            converted to RGB first, because the normalisation step is defined
            for exactly three channels.

    Returns:
        A float tensor of shape ``(1, 3, 224, 224)``.
    """
    if image.mode != "RGB":
        image = image.convert("RGB")
    return _PREPROCESS_PIPELINE(image).unsqueeze(0)

def classify_image(image):
    """Return the class name predicted for a single image.

    The image is preprocessed with ``preprocess_image``, run through the
    module-level ``classification_model`` with gradients disabled, and the
    index of the highest score is looked up (as a string key) in
    ``class_mapping``.

    Args:
        image: Image accepted by ``preprocess_image``.

    Returns:
        The ``class_mapping`` entry for the top-scoring class index.
    """
    batch = preprocess_image(image)

    with torch.no_grad():
        logits = classification_model(batch)

    best_idx = logits.argmax(dim=1).item()
    return class_mapping[str(best_idx)]

# Load the detection model (module-level, so it is created once at import time).
detection_model = YOLO("nabird_det_ep3.pt")

def detect_and_classify(image: Image.Image):
    """Detect birds in an image, classify each detection and annotate the image.

    Args:
        image: PIL image to analyse. ``None`` is tolerated and treated as
            "nothing to process". Non-RGB images are converted to RGB.

    Returns:
        A ``(annotated_image, classifications)`` tuple. ``annotated_image`` is
        an RGB PIL image with a green box and class label drawn for every
        detection (``None`` when no image was given). ``classifications`` is a
        list of ``{"bbox": [x1, y1, x2, y2], "class": name}`` dicts, one per
        detection, with integer pixel coordinates.
    """
    # Nothing to process -- presumably what the UI handler receives when it is
    # triggered without an uploaded image (TODO confirm against Gradio docs).
    if image is None:
        return None, []

    # The RGB->BGR conversion below needs exactly three channels, so
    # normalise palette / grayscale / alpha images up front.
    if image.mode != "RGB":
        image = image.convert("RGB")

    image_np = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)

    # Detect birds
    results = detection_model.predict(image_np, save=False)
    classifications = []

    for result in results:
        for box in result.boxes:
            # box.xyxy[0] is [x1, y1, x2, y2]; truncate to integer pixels.
            x1, y1, x2, y2 = map(int, box.xyxy[0].tolist())

            # A box that collapses to zero area after truncation yields an
            # empty crop that cannot be resized or classified; skip it.
            if x2 <= x1 or y2 <= y1:
                continue

            # Crop the bird region from the (RGB) PIL image and classify it
            class_name = classify_image(image.crop((x1, y1, x2, y2)))
            classifications.append({
                "bbox": [x1, y1, x2, y2],
                "class": class_name
            })

            # Draw the bounding box and label on the BGR working copy.
            cv2.rectangle(image_np, (x1, y1), (x2, y2), color=(0, 255, 0), thickness=2)
            # Put the label above the box, or just inside it when the box is
            # too close to the top edge for the text to stay on the canvas.
            label_y = y1 - 10 if y1 >= 20 else y1 + 15
            cv2.putText(image_np, class_name, (x1, label_y), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 2)

    # Convert back to RGB for the returned PIL image
    detected_image = Image.fromarray(cv2.cvtColor(image_np, cv2.COLOR_BGR2RGB))

    return detected_image, classifications

# Gradio interface: takes one PIL image and shows the annotated image together
# with the per-detection classifications as JSON.
interface = gr.Interface(
    fn=detect_and_classify,
    inputs=gr.Image(type="pil"),
    outputs=[
        gr.Image(type="pil", label="Detected Image"),
        gr.JSON(label="Classifications")
    ],
    title="Bird Detection and Recognition",
    description="Upload an image to detect birds and classify their species."
)

# Start the web UI only when this file is executed as a script.
if __name__ == "__main__":
    interface.launch()