File size: 16,414 Bytes
12d9ea9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
import logging
from typing import Dict, List, Tuple, Optional, Any

class ObjectGroupProcessor:
    """
    物件組處理器 - 專門處理物件分組、排序和子句生成的邏輯
    負責物件按類別分組、重複物件檢測移除、物件組優先級排序以及描述子句的生成
    """

    def __init__(self, confidence_threshold_for_description: float = 0.25,
                 spatial_handler: Optional[Any] = None,
                 text_optimizer: Optional[Any] = None):
        """
        初始化物件組處理器

        Args:
            confidence_threshold_for_description: 用於描述的置信度閾值
            spatial_handler: 空間位置處理器實例
            text_optimizer: 文本優化器實例
        """
        self.logger = logging.getLogger(self.__class__.__name__)
        self.confidence_threshold_for_description = confidence_threshold_for_description
        self.spatial_handler = spatial_handler
        self.text_optimizer = text_optimizer

    def group_objects_by_class(self, confident_objects: List[Dict],
                              object_statistics: Optional[Dict]) -> Dict[str, List[Dict]]:
        """
        按類別分組物件

        Args:
            confident_objects: 置信度過濾後的物件
            object_statistics: 物件統計信息

        Returns:
            Dict[str, List[Dict]]: 按類別分組的物件
        """
        objects_by_class = {}

        if object_statistics:
            # 使用預計算的統計信息,採用動態的信心度
            for class_name, stats in object_statistics.items():
                count = stats.get("count", 0)
                avg_confidence = stats.get("avg_confidence", 0)

                # 動態調整置信度閾值
                dynamic_threshold = self.confidence_threshold_for_description
                if class_name in ["potted plant", "vase", "clock", "book"]:
                    dynamic_threshold = max(0.15, self.confidence_threshold_for_description * 0.6)
                elif count >= 3:
                    dynamic_threshold = max(0.2, self.confidence_threshold_for_description * 0.8)

                if count > 0 and avg_confidence >= dynamic_threshold:
                    matching_objects = [obj for obj in confident_objects if obj.get("class_name") == class_name]
                    if not matching_objects:
                        matching_objects = [obj for obj in confident_objects
                                          if obj.get("class_name") == class_name and obj.get("confidence", 0) >= dynamic_threshold]

                    if matching_objects:
                        actual_count = min(stats["count"], len(matching_objects))
                        objects_by_class[class_name] = matching_objects[:actual_count]

                        # Debug logging for specific classes
                        if class_name in ["car", "traffic light", "person", "handbag"]:
                            print(f"DEBUG: Before spatial deduplication:")
                            print(f"DEBUG: {class_name}: {len(objects_by_class[class_name])} objects before dedup")
        else:
            # 備用邏輯,同樣使用動態閾值
            for obj in confident_objects:
                name = obj.get("class_name", "unknown object")
                if name == "unknown object" or not name:
                    continue
                if name not in objects_by_class:
                    objects_by_class[name] = []
                objects_by_class[name].append(obj)

        return objects_by_class

    def remove_duplicate_objects(self, objects_by_class: Dict[str, List[Dict]]) -> Dict[str, List[Dict]]:
        """
        移除重複物件

        Args:
            objects_by_class: 按類別分組的物件

        Returns:
            Dict[str, List[Dict]]: 去重後的物件
        """
        deduplicated_objects_by_class = {}
        processed_positions = []

        for class_name, group_of_objects in objects_by_class.items():
            unique_objects = []

            for obj in group_of_objects:
                obj_position = obj.get("normalized_center", [0.5, 0.5])
                is_duplicate = False

                for processed_pos in processed_positions:
                    position_distance = abs(obj_position[0] - processed_pos[0]) + abs(obj_position[1] - processed_pos[1])
                    if position_distance < 0.15:
                        is_duplicate = True
                        break

                if not is_duplicate:
                    unique_objects.append(obj)
                    processed_positions.append(obj_position)

            if unique_objects:
                deduplicated_objects_by_class[class_name] = unique_objects

        # Debug logging after deduplication
        for class_name in ["car", "traffic light", "person", "handbag"]:
            if class_name in deduplicated_objects_by_class:
                print(f"DEBUG: After spatial deduplication:")
                print(f"DEBUG: {class_name}: {len(deduplicated_objects_by_class[class_name])} objects after dedup")

        return deduplicated_objects_by_class

    def sort_object_groups(self, objects_by_class: Dict[str, List[Dict]]) -> List[Tuple[str, List[Dict]]]:
        """
        排序物件組

        Args:
            objects_by_class: 按類別分組的物件

        Returns:
            List[Tuple[str, List[Dict]]]: 排序後的物件組
        """
        def sort_key_object_groups(item_tuple: Tuple[str, List[Dict]]):
            class_name_key, obj_group_list = item_tuple
            priority = 3
            count = len(obj_group_list)

            # 確保類別名稱已標準化
            normalized_class_name = self._normalize_object_class_name(class_name_key)

            # 動態優先級
            if normalized_class_name == "person":
                priority = 0
            elif normalized_class_name in ["dining table", "chair", "sofa", "bed"]:
                priority = 1
            elif normalized_class_name in ["car", "bus", "truck", "traffic light"]:
                priority = 2
            elif count >= 3:
                priority = max(1, priority - 1)
            elif normalized_class_name in ["potted plant", "vase", "clock", "book"] and count >= 2:
                priority = 2

            avg_area = sum(o.get("normalized_area", 0.0) for o in obj_group_list) / len(obj_group_list) if obj_group_list else 0
            quantity_bonus = min(count / 5.0, 1.0)

            return (priority, -len(obj_group_list), -avg_area, -quantity_bonus)

        return sorted(objects_by_class.items(), key=sort_key_object_groups)

    def generate_object_clauses(self, sorted_object_groups: List[Tuple[str, List[Dict]]],
                               object_statistics: Optional[Dict],
                               scene_type: str,
                               image_width: Optional[int],
                               image_height: Optional[int],
                               region_analyzer: Optional[Any] = None) -> List[str]:
        """
        生成物件描述子句

        Args:
            sorted_object_groups: 排序後的物件組
            object_statistics: 物件統計信息
            scene_type: 場景類型
            image_width: 圖像寬度
            image_height: 圖像高度
            region_analyzer: 區域分析器實例

        Returns:
            List[str]: 物件描述子句列表
        """
        object_clauses = []

        for class_name, group_of_objects in sorted_object_groups:
            count = len(group_of_objects)

            # Debug logging for final count
            if class_name in ["car", "traffic light", "person", "handbag"]:
                print(f"DEBUG: Final count for {class_name}: {count}")

            if count == 0:
                continue

            # 標準化class name
            normalized_class_name = self._normalize_object_class_name(class_name)

            # 使用統計信息確保準確的數量描述
            if object_statistics and class_name in object_statistics:
                actual_count = object_statistics[class_name]["count"]
                formatted_name_with_exact_count = self._format_object_count_description(
                    normalized_class_name,
                    actual_count,
                    scene_type=scene_type
                )
            else:
                formatted_name_with_exact_count = self._format_object_count_description(
                    normalized_class_name,
                    count,
                    scene_type=scene_type
                )

            if formatted_name_with_exact_count == "no specific objects clearly identified" or not formatted_name_with_exact_count:
                continue

            # 確定群組的集體位置
            location_description_suffix = self._generate_location_description(
                group_of_objects, count, image_width, image_height, region_analyzer
            )

            # 首字母大寫
            formatted_name_capitalized = formatted_name_with_exact_count[0].upper() + formatted_name_with_exact_count[1:]
            object_clauses.append(f"{formatted_name_capitalized} {location_description_suffix}")

        return object_clauses

    def format_object_clauses(self, object_clauses: List[str]) -> str:
        """
        格式化物件描述子句

        Args:
            object_clauses: 物件描述子句列表

        Returns:
            str: 格式化後的描述
        """
        if not object_clauses:
            return "No common objects were confidently identified for detailed description."

        # 處理第一個子句
        first_clause = object_clauses.pop(0)
        result = first_clause + "."

        # 處理剩餘子句
        if object_clauses:
            result += " The scene features:"
            joined_object_clauses = ". ".join(object_clauses)
            if joined_object_clauses and not joined_object_clauses.endswith("."):
                joined_object_clauses += "."
            result += " " + joined_object_clauses

        return result

    def _generate_location_description(self, group_of_objects: List[Dict], count: int,
                                     image_width: Optional[int], image_height: Optional[int],
                                     region_analyzer: Optional[Any] = None) -> str:
        """
        生成位置描述

        Args:
            group_of_objects: 物件組
            count: 物件數量
            image_width: 圖像寬度
            image_height: 圖像高度
            region_analyzer: 區域分析器實例

        Returns:
            str: 位置描述
        """
        if count == 1:
            if self.spatial_handler:
                spatial_desc = self.spatial_handler.generate_spatial_description(
                    group_of_objects[0], image_width, image_height, region_analyzer
                )
            else:
                spatial_desc = self._get_spatial_description_phrase(group_of_objects[0].get("region", ""))

            if spatial_desc:
                return f"is {spatial_desc}"
            else:
                distinct_regions = sorted(list(set(obj.get("region", "") for obj in group_of_objects if obj.get("region"))))
                valid_regions = [r for r in distinct_regions if r and r != "unknown" and r.strip()]
                if not valid_regions:
                    return "is positioned in the scene"
                elif len(valid_regions) == 1:
                    spatial_desc = self._get_spatial_description_phrase(valid_regions[0])
                    return f"is primarily {spatial_desc}" if spatial_desc else "is positioned in the scene"
                elif len(valid_regions) == 2:
                    clean_region1 = valid_regions[0].replace('_', ' ')
                    clean_region2 = valid_regions[1].replace('_', ' ')
                    return f"is mainly across the {clean_region1} and {clean_region2} areas"
                else:
                    return "is distributed in various parts of the scene"
        else:
            distinct_regions = sorted(list(set(obj.get("region", "") for obj in group_of_objects if obj.get("region"))))
            valid_regions = [r for r in distinct_regions if r and r != "unknown" and r.strip()]
            if not valid_regions:
                return "are visible in the scene"
            elif len(valid_regions) == 1:
                clean_region = valid_regions[0].replace('_', ' ')
                return f"are primarily in the {clean_region} area"
            elif len(valid_regions) == 2:
                clean_region1 = valid_regions[0].replace('_', ' ')
                clean_region2 = valid_regions[1].replace('_', ' ')
                return f"are mainly across the {clean_region1} and {clean_region2} areas"
            else:
                return "are distributed in various parts of the scene"

    def _get_spatial_description_phrase(self, region: str) -> str:
        """
        獲取空間描述短語的備用方法

        Args:
            region: 區域字符串

        Returns:
            str: 空間描述短語
        """
        if not region or region == "unknown":
            return ""

        clean_region = region.replace('_', ' ').strip().lower()

        region_map = {
            "top left": "in the upper left area",
            "top center": "in the upper area",
            "top right": "in the upper right area",
            "middle left": "on the left side",
            "middle center": "in the center",
            "center": "in the center",
            "middle right": "on the right side",
            "bottom left": "in the lower left area",
            "bottom center": "in the lower area",
            "bottom right": "in the lower right area"
        }

        return region_map.get(clean_region, "")

    def _normalize_object_class_name(self, class_name: str) -> str:
        """
        標準化物件類別名稱

        Args:
            class_name: 原始類別名稱

        Returns:
            str: 標準化後的類別名稱
        """
        if self.text_optimizer:
            return self.text_optimizer.normalize_object_class_name(class_name)
        else:
            # 備用標準化邏輯
            if not class_name or not isinstance(class_name, str):
                return "object"

            # 簡單的標準化處理
            normalized = class_name.replace('_', ' ').strip().lower()
            return normalized

    def _format_object_count_description(self, class_name: str, count: int,
                                       scene_type: Optional[str] = None,
                                       detected_objects: Optional[List[Dict]] = None,
                                       avg_confidence: float = 0.0) -> str:
        """
        格式化物件數量描述

        Args:
            class_name: 標準化後的類別名稱
            count: 物件數量
            scene_type: 場景類型
            detected_objects: 該類型的所有檢測物件
            avg_confidence: 平均檢測置信度

        Returns:
            str: 完整的格式化數量描述
        """
        if self.text_optimizer:
            return self.text_optimizer.format_object_count_description(
                class_name, count, scene_type, detected_objects, avg_confidence
            )
        else:
            # 備用格式化邏輯
            if count <= 0:
                return ""
            elif count == 1:
                article = "an" if class_name[0].lower() in 'aeiou' else "a"
                return f"{article} {class_name}"
            else:
                # 簡單的複數處理
                plural_form = class_name + "s" if not class_name.endswith("s") else class_name

                number_words = {
                    2: "two", 3: "three", 4: "four", 5: "five", 6: "six",
                    7: "seven", 8: "eight", 9: "nine", 10: "ten",
                    11: "eleven", 12: "twelve"
                }

                if count in number_words:
                    return f"{number_words[count]} {plural_form}"
                elif count <= 20:
                    return f"several {plural_form}"
                else:
                    return f"numerous {plural_form}"