# utils.py
import os
from typing import Any, Dict, Tuple, Union
from unittest.mock import patch

import torch
import supervision as sv
from PIL import Image
from transformers import AutoModelForCausalLM, AutoProcessor
from transformers.dynamic_module_utils import get_imports

FLORENCE_CHECKPOINT = "microsoft/Florence-2-base"
FLORENCE_OBJECT_DETECTION_TASK = '<OD>'
FLORENCE_DETAILED_CAPTION_TASK = '<MORE_DETAILED_CAPTION>'
FLORENCE_CAPTION_TO_PHRASE_GROUNDING_TASK = '<CAPTION_TO_PHRASE_GROUNDING>'
FLORENCE_OPEN_VOCABULARY_DETECTION_TASK = '<OPEN_VOCABULARY_DETECTION>'
FLORENCE_DENSE_REGION_CAPTION_TASK = '<DENSE_REGION_CAPTION>'


def fixed_get_imports(filename: Union[str, os.PathLike]) -> list[str]:
    """Workaround for https://huggingface.co/microsoft/phi-1_5/discussions/72."""
    if not str(filename).endswith("/modeling_florence2.py"):
        return get_imports(filename)
    imports = get_imports(filename)
    # Drop the optional flash_attn import so Florence-2 loads on machines without it.
    if "flash_attn" in imports:
        imports.remove("flash_attn")
    return imports


def load_florence_model(
    device: torch.device, checkpoint: str = FLORENCE_CHECKPOINT
) -> Tuple[Any, Any]:
    # fp16 on GPU, fp32 on CPU.
    torch_dtype = torch.float16 if torch.device(device).type == "cuda" else torch.float32
    # Patch get_imports so the remote Florence-2 code imports without flash_attn.
    with patch("transformers.dynamic_module_utils.get_imports", fixed_get_imports):
        model = AutoModelForCausalLM.from_pretrained(
            checkpoint, torch_dtype=torch_dtype, trust_remote_code=True
        ).to(device).eval()
        processor = AutoProcessor.from_pretrained(checkpoint, trust_remote_code=True)
    return model, processor


def run_florence_inference(
    model: Any,
    processor: Any,
    device: torch.device,
    image: Image.Image,
    task: str,
    text: str = ""
) -> Tuple[str, Dict]:
    prompt = task + text
    inputs = processor(text=prompt, images=image, return_tensors="pt").to(device)
    generated_ids = model.generate(
        input_ids=inputs["input_ids"],
        # Match the model dtype (fp16 on GPU) to avoid a dtype mismatch.
        pixel_values=inputs["pixel_values"].to(model.dtype),
        max_new_tokens=1024,
        num_beams=3
    )
    generated_text = processor.batch_decode(
        generated_ids, skip_special_tokens=False)[0]
    # Florence-2 post-processing parses the raw text into task-specific structures.
    response = processor.post_process_generation(
        generated_text, task=task, image_size=image.size)
    return generated_text, response
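
# Usage sketch for run_florence_inference (assumption: "dog.jpeg" is a
# stand-in local image path used only for illustration):
#
#   model, processor = load_florence_model(torch.device("cpu"))
#   _, response = run_florence_inference(
#       model, processor, torch.device("cpu"),
#       image=Image.open("dog.jpeg"),
#       task=FLORENCE_DETAILED_CAPTION_TASK)
#   print(response[FLORENCE_DETAILED_CAPTION_TASK])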


def detect_objects_in_image(
    image_input_path: str, texts: list[str], device: torch.device
) -> sv.Detections:
    # Load the image.
    image_input = Image.open(image_input_path)

    # Collect one sv.Detections object per text prompt.
    detections_list = []

    # Run open-vocabulary detection for each text prompt. FLORENCE_MODEL,
    # FLORENCE_PROCESSOR, SAM_IMAGE_MODEL and run_sam_inference are module-level
    # globals expected to be initialized elsewhere in the project.
    for text in texts:
        _, result = run_florence_inference(
            model=FLORENCE_MODEL.to(device),
            processor=FLORENCE_PROCESSOR,
            device=device,
            image=image_input,
            task=FLORENCE_OPEN_VOCABULARY_DETECTION_TASK,
            text=text
        )

        # Build a supervision Detections object from the Florence-2 result.
        detections = sv.Detections.from_lmm(
            lmm=sv.LMM.FLORENCE_2,
            result=result,
            resolution_wh=image_input.size
        )

        # Run SAM inference to add segmentation masks to the boxes.
        detections = run_sam_inference(SAM_IMAGE_MODEL.to(device), image_input, detections)

        # Append the detections for this prompt.
        detections_list.append(detections)

    # Merge the per-prompt detections into a single object.
    detections = sv.Detections.merge(detections_list)

    # Run SAM inference once more on the merged detections.
    detections = run_sam_inference(SAM_IMAGE_MODEL.to(device), image_input, detections)

    return detections
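

# Minimal demo entry point. This is a sketch, not part of the original module:
# it assumes SAM_IMAGE_MODEL and run_sam_inference are provided elsewhere in
# the project, and "dog.jpeg" is a stand-in image path.
if __name__ == "__main__":
    DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    FLORENCE_MODEL, FLORENCE_PROCESSOR = load_florence_model(DEVICE)
    detections = detect_objects_in_image("dog.jpeg", ["dog"], DEVICE)
    print(f"{len(detections)} objects detected")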