"""Gradio demo: detect birds with a YOLO model and classify each detection.

A YOLO detector proposes bounding boxes; each box is cropped from the input
image and passed to a ResNet-18 classifier whose class-index-to-name mapping
is read from ``class_names.json``.
"""

import json
from typing import Any, Dict, List, Optional, Tuple

import cv2
import gradio as gr
import numpy as np
import torch
from PIL import Image
from timm import create_model
from torchvision import transforms
from ultralytics import YOLO

# Load the classification model.
# The mapping is indexed with str(class_index) in classify_image().
with open('class_names.json', 'r', encoding='utf-8') as json_file:
    class_mapping = json.load(json_file)


def load_classification_model(model_path):
    """Build a ResNet-18 sized to ``class_mapping`` and load weights from *model_path*.

    The weights are mapped onto the CPU and the model is returned in eval mode.
    """
    model = create_model('resnet18', pretrained=False, num_classes=len(class_mapping))
    # NOTE(security): torch.load unpickles the file; only load checkpoints
    # from a trusted source.
    state_dict = torch.load(model_path, map_location=torch.device('cpu'))
    model.load_state_dict(state_dict)
    model.eval()
    return model


classification_model = load_classification_model("res18_nabird555_acc596.pth")

# Built once at import time instead of on every preprocess_image() call.
_CLASSIFIER_TRANSFORM = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])


def preprocess_image(image):
    """Convert a PIL image into a normalized ``(1, 3, 224, 224)`` tensor.

    The image is converted to RGB first because the normalization statistics
    are defined for exactly three channels.
    """
    return _CLASSIFIER_TRANSFORM(image.convert("RGB")).unsqueeze(0)


def classify_image(image):
    """Return the class name predicted by the classifier for a PIL image."""
    batch = preprocess_image(image)
    with torch.no_grad():
        outputs = classification_model(batch)
    predicted_class_idx = outputs.argmax(dim=1).item()
    return class_mapping[str(predicted_class_idx)]


# Load the detection model.
detection_model = YOLO("nabird_det_ep3.pt")


def _clip_box(xyxy, width, height):
    """Truncate a box to ints and clamp it to the image; return None if it is empty."""
    x1, y1, x2, y2 = map(int, xyxy)
    x1, x2 = max(0, min(x1, width)), max(0, min(x2, width))
    y1, y2 = max(0, min(y1, height)), max(0, min(y2, height))
    if x2 <= x1 or y2 <= y1:
        return None
    return x1, y1, x2, y2


def detect_and_classify(
    image: Optional[Image.Image],
) -> Tuple[Optional[Image.Image], List[Dict[str, Any]]]:
    """Detect birds in *image*, classify each one, and annotate the image.

    Args:
        image: Input PIL image, or ``None`` when nothing was uploaded.

    Returns:
        A pair ``(annotated_image, classifications)``. ``annotated_image`` is
        a copy of the input with a green box and label per detection, and
        ``classifications`` is a list of ``{"bbox": [x1, y1, x2, y2],
        "class": name}`` dicts. For a ``None`` input the result is
        ``(None, [])``.
    """
    if image is None:
        return None, []

    # Normalize to three channels so the RGB->BGR conversion and the
    # classifier both see the layout they expect.
    rgb_image = image.convert("RGB")
    width, height = rgb_image.size
    image_np = cv2.cvtColor(np.array(rgb_image), cv2.COLOR_RGB2BGR)

    # Detect birds.
    results = detection_model.predict(image_np, save=False)

    classifications = []
    for result in results:
        for box in result.boxes:
            # box.xyxy[0] is [x1, y1, x2, y2]; skip boxes with no area, which
            # cannot be resized for the classifier.
            clipped = _clip_box(box.xyxy[0].tolist(), width, height)
            if clipped is None:
                continue
            x1, y1, x2, y2 = clipped

            # Crop from the untouched RGB image (not the annotated array) so
            # earlier drawings never leak into later crops, then classify.
            class_name = classify_image(rgb_image.crop(clipped))
            classifications.append({
                "bbox": [x1, y1, x2, y2],
                "class": class_name
            })

            # Draw the box and label on the BGR working copy. When the box
            # touches the top edge, place the label just inside the box so it
            # stays visible.
            label_y = y1 - 10 if y1 - 10 >= 10 else y1 + 15
            cv2.rectangle(image_np, (x1, y1), (x2, y2), color=(0, 255, 0), thickness=2)
            cv2.putText(image_np, class_name, (x1, label_y),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 2)

    # Convert back to RGB for display.
    detected_image = Image.fromarray(cv2.cvtColor(image_np, cv2.COLOR_BGR2RGB))
    return detected_image, classifications


# Gradio interface.
interface = gr.Interface(
    fn=detect_and_classify,
    inputs=gr.Image(type="pil"),
    outputs=[
        gr.Image(type="pil", label="Detected Image"),
        gr.JSON(label="Classifications")
    ],
    title="Bird Detection and Recognition",
    description="Upload an image to detect birds and classify their species."
)

if __name__ == "__main__":
    interface.launch()