import logging
import traceback
from collections import Counter
from typing import Dict, List, Optional, Tuple

import numpy as np


class ViewpointDetectionError(Exception):
    """Custom exception for errors during viewpoint detection."""


class ViewpointDetector:
    """
    Viewpoint detector - classifies the camera viewpoint of an image from the
    spatial distribution of its detected objects.

    The detector looks at where objects sit in the frame, how much their sizes
    vary and which shapes pedestrians form, and returns one of 'aerial',
    'low_angle', 'elevated' or 'eye_level'. Crowded pedestrian crossings are
    checked first because they short-circuit to 'aerial'.

    Every detection step is best-effort: failures are logged and turned into a
    negative / default answer instead of being raised to the caller.
    """

    # class_id that the pedestrian-based checks treat as a person.
    _PERSON_CLASS_ID = 0

    def __init__(self,
                 aerial_threshold: float = 0.7,
                 aerial_size_variance_threshold: float = 0.15,
                 low_angle_threshold: float = 0.3,
                 vertical_size_ratio_threshold: float = 1.8,
                 elevated_threshold: float = 0.6,
                 elevated_top_threshold: float = 0.3,
                 crosswalk_position_tolerance: float = 0.1,
                 crosswalk_axis_tolerance: float = 0.15,
                 min_people_for_crosswalk: int = 8,
                 min_people_for_aerial: int = 10):
        """
        Initialise the detector and store its thresholds.

        Args:
            aerial_threshold: The metric-based aerial check requires the share
                of objects in "top" regions to exceed this value.
            aerial_size_variance_threshold: The metric-based aerial check
                requires the relative size variance to stay below this value.
            low_angle_threshold: The low-angle check requires the share of
                objects in "top" regions to exceed this value.
                NOTE(review): earlier documentation described this as a
                bottom-distribution threshold, but the code compares it with
                the top ratio - confirm which one is intended.
            vertical_size_ratio_threshold: The low-angle check requires the
                mean height/width ratio to exceed this value.
            elevated_threshold: The elevated check requires the share of
                objects in "bottom" regions to exceed this value.
            elevated_top_threshold: The elevated check requires the share of
                objects in "top" regions to stay below this value.
            crosswalk_position_tolerance: Stored in ``viewpoint_params`` but
                not read by any method of this class.
            crosswalk_axis_tolerance: Maximum distance from the centroid's
                horizontal or vertical axis for a pedestrian to count as
                "on the axis" in the cross-pattern check.
            min_people_for_crosswalk: Minimum number of people required before
                the crosswalk-pattern check runs.
            min_people_for_aerial: Minimum number of people required before
                the pedestrian-distribution aerial check runs.
        """
        self.logger = logging.getLogger(self.__class__.__name__)

        # Viewpoint detection parameters, looked up by name in the checks below.
        self.viewpoint_params = {
            "aerial_threshold": aerial_threshold,
            "aerial_size_variance_threshold": aerial_size_variance_threshold,
            "low_angle_threshold": low_angle_threshold,
            "vertical_size_ratio_threshold": vertical_size_ratio_threshold,
            "elevated_threshold": elevated_threshold,
            "elevated_top_threshold": elevated_top_threshold,
            "crosswalk_position_tolerance": crosswalk_position_tolerance,
            "crosswalk_axis_tolerance": crosswalk_axis_tolerance,
            "min_people_for_crosswalk": min_people_for_crosswalk,
            "min_people_for_aerial": min_people_for_aerial,
        }

        self.logger.info("ViewpointDetector initialized with parameters: %s",
                         self.viewpoint_params)

    def detect_viewpoint(self, detected_objects: List[Dict]) -> str:
        """
        Detect the viewpoint type of an image.

        Args:
            detected_objects: Detected objects as dicts. The checks read the
                keys "class_id", "region", "normalized_center",
                "normalized_area" and "normalized_size" when present.

        Returns:
            str: One of 'aerial', 'low_angle', 'elevated', 'eye_level'.
            'eye_level' is also the fallback for empty input and for any
            unexpected error.
        """
        try:
            if not detected_objects:
                self.logger.warning("No detected objects provided for viewpoint detection")
                return "eye_level"

            self.logger.info("Starting viewpoint detection with %d objects",
                             len(detected_objects))

            # A crosswalk pattern is checked first and is treated as aerial.
            if self._detect_crosswalk_pattern(detected_objects):
                self.logger.info("Crosswalk pattern detected - returning aerial viewpoint")
                return "aerial"

            # Next: aerial viewpoint inferred from how pedestrians are spread.
            if self._detect_aerial_from_pedestrian_distribution(detected_objects):
                self.logger.info("Aerial viewpoint detected from pedestrian distribution")
                return "aerial"

            # Otherwise fall back to the metric-based classification.
            return self._detect_standard_viewpoint(detected_objects)

        except Exception as e:
            # logger.exception appends the active traceback to the record.
            self.logger.exception("Error during viewpoint detection: %s", e)
            return "eye_level"  # default value

    def _detect_crosswalk_pattern(self, detected_objects: List[Dict]) -> bool:
        """
        Detect an intersection / zebra-crossing pattern among pedestrians.

        Args:
            detected_objects: Detected objects.

        Returns:
            bool: True when the pedestrian centres form a cross shape or
            several linear clusters along both axes; False otherwise, when
            there are too few people, or on error.
        """
        try:
            people = [obj for obj in detected_objects
                      if obj.get("class_id") == self._PERSON_CLASS_ID]
            if len(people) < self.viewpoint_params["min_people_for_crosswalk"]:
                return False

            # Only people that carry a centre position can be analysed.
            positions = [obj["normalized_center"] for obj in people
                         if "normalized_center" in obj]
            if len(positions) < 4:
                return False

            if self._detect_cross_pattern(positions):
                self.logger.debug("Cross pattern detected in pedestrian positions")
                return True

            if self._detect_linear_crosswalk_clusters(positions):
                self.logger.debug("Linear crosswalk clusters detected")
                return True

            return False

        except Exception as e:
            self.logger.warning("Error in crosswalk pattern detection: %s", e)
            return False

    def _detect_cross_pattern(self, positions: List[Tuple[float, float]]) -> bool:
        """
        Detect a cross-shaped distribution of positions.

        Args:
            positions: Positions as [(x, y), ...].

        Returns:
            bool: True when both axes span more than 0.5, the two spans are
            within roughly 30% of each other, and at least 60% of the points
            lie within ``crosswalk_axis_tolerance`` of the horizontal or
            vertical line through the centroid. False for empty input or on
            error.
        """
        if not positions:
            return False

        try:
            xs = [pos[0] for pos in positions]
            ys = [pos[1] for pos in positions]

            x_range = max(xs) - min(xs)
            y_range = max(ys) - min(ys)

            # Both directions need a large spread, and the spreads must be similar.
            # The range check also guarantees y_range is non-zero for the division.
            if x_range <= 0.5 or y_range <= 0.5:
                return False
            if not (0.7 < (x_range / y_range) < 1.3):
                return False

            center_x = np.mean(xs)
            center_y = np.mean(ys)
            tolerance = self.viewpoint_params["crosswalk_axis_tolerance"]

            # Count the points that sit near either axis through the centroid.
            near_axis = sum(
                1 for x, y in positions
                if abs(x - center_x) < tolerance or abs(y - center_y) < tolerance
            )

            # Enough points near the axes means a cross-shaped crossing.
            return bool(near_axis / len(positions) >= 0.6)

        except Exception as e:
            self.logger.warning("Error detecting cross pattern: %s", e)
            return False

    def _detect_linear_crosswalk_clusters(self, positions: List[Tuple[float, float]]) -> bool:
        """
        Detect linear clusters along both axes (intersecting zebra crossings).

        Args:
            positions: Positions as [(x, y), ...].

        Returns:
            bool: True when the x coordinates and the y coordinates each form
            at least two clusters; False otherwise or on error.
        """
        try:
            x_clusters = self._detect_linear_clusters([pos[0] for pos in positions])
            y_clusters = self._detect_linear_clusters([pos[1] for pos in positions])

            # Several clusters in both directions suggest crossing walkways.
            return len(x_clusters) >= 2 and len(y_clusters) >= 2

        except Exception as e:
            self.logger.warning("Error detecting linear crosswalk clusters: %s", e)
            return False

    def _detect_linear_clusters(self, coords: List[float],
                                threshold: float = 0.05) -> List[List[float]]:
        """
        Group one-dimensional coordinates into clusters of close neighbours.

        Args:
            coords: One-dimensional coordinates, in any order.
            threshold: Two neighbouring sorted values belong to the same
                cluster when their gap is strictly below this value.

        Returns:
            List[List[float]]: Clusters in ascending order. Only clusters with
            at least two members are kept. Empty on empty input or on error.
        """
        if not coords:
            return []

        try:
            ordered = sorted(coords)
            clusters = []
            current = [ordered[0]]

            for previous, value in zip(ordered, ordered[1:]):
                if value - previous < threshold:
                    current.append(value)
                    continue
                # Gap too large: close the running cluster and start a new one.
                if len(current) >= 2:
                    clusters.append(current)
                current = [value]

            # Do not forget the cluster still open at the end.
            if len(current) >= 2:
                clusters.append(current)

            return clusters

        except Exception as e:
            self.logger.warning("Error in linear cluster detection: %s", e)
            return []

    def _detect_aerial_from_pedestrian_distribution(self, detected_objects: List[Dict]) -> bool:
        """
        Detect an aerial viewpoint from how pedestrians spread over regions.

        Args:
            detected_objects: Detected objects.

        Returns:
            bool: True when there are at least ``min_people_for_aerial``
            people, at least four regions hold two or more of them, and the
            per-region counts are even enough (variance / mean below 0.5).
            False otherwise or on error.
        """
        try:
            people = [obj for obj in detected_objects
                      if obj.get("class_id") == self._PERSON_CLASS_ID]
            if len(people) < self.viewpoint_params["min_people_for_aerial"]:
                return False

            # People per region; objects without a region are pooled as "unknown".
            counts = list(Counter(obj.get("region", "unknown") for obj in people).values())

            # Pedestrians must occupy several regions with more than one person each.
            if sum(1 for count in counts if count >= 2) < 4:
                return False

            # At least four regions exist here, so the mean is >= 1 and the
            # division is safe. The statistic is variance / mean (an index of
            # dispersion), not a standard-deviation based coefficient.
            dispersion = np.var(counts) / np.mean(counts)
            return bool(dispersion < 0.5)

        except Exception as e:
            self.logger.warning("Error in aerial detection from pedestrian distribution: %s", e)
            return False

    def _detect_standard_viewpoint(self, detected_objects: List[Dict]) -> str:
        """
        Classify the viewpoint from aggregate metrics.

        The checks run in the order aerial, low angle, elevated; the first
        match wins.

        Args:
            detected_objects: Detected objects.

        Returns:
            str: The detected viewpoint type; 'eye_level' when nothing matches
            or on error.
        """
        try:
            metrics = self._calculate_viewpoint_metrics(detected_objects)

            if self._is_aerial_viewpoint(metrics):
                return "aerial"
            if self._is_low_angle_viewpoint(metrics):
                return "low_angle"
            if self._is_elevated_viewpoint(metrics):
                return "elevated"
            return "eye_level"

        except Exception as e:
            self.logger.warning("Error in standard viewpoint detection: %s", e)
            return "eye_level"

    def _calculate_viewpoint_metrics(self, detected_objects: List[Dict]) -> Dict:
        """
        Compute the metrics used by the metric-based viewpoint checks.

        A missing or None "region", "normalized_area" or "normalized_size", or
        a "normalized_size" that is not a pair, only excludes that object from
        the affected metric instead of invalidating all metrics.

        Args:
            detected_objects: Detected objects.

        Returns:
            Dict: Keys "top_ratio" and "bottom_ratio" (share of objects whose
            region contains "top" / "bottom"), "size_variance_coefficient"
            (variance of the areas divided by the squared mean area, 0 with
            fewer than two areas), "avg_height_width_ratio" (mean of
            height / width, 1.0 when unavailable) and "total_objects".
            On an unexpected error the neutral values 0 / 0 / 0 / 1.0 are
            returned together with the object count.
        """
        total_objects = len(detected_objects)
        top_count = 0
        bottom_count = 0
        areas = []
        aspect_ratios = []

        try:
            for obj in detected_objects:
                # Count objects in top and bottom regions; None counts as neither.
                region = obj.get("region") or ""
                if "top" in region:
                    top_count += 1
                elif "bottom" in region:
                    bottom_count += 1

                area = obj.get("normalized_area")
                if area is not None:
                    areas.append(area)

                size = obj.get("normalized_size")
                if size is None:
                    continue
                try:
                    width, height = size
                except (TypeError, ValueError):
                    # Not a (width, height) pair: skip the ratio for this object.
                    continue
                if width > 0:
                    aspect_ratios.append(height / width)

            top_ratio = top_count / total_objects if total_objects > 0 else 0
            bottom_ratio = bottom_count / total_objects if total_objects > 0 else 0

            # Relative size variance needs at least two samples and a positive mean.
            size_variance_coefficient = 0
            if len(areas) > 1:
                mean_area = np.mean(areas)
                if mean_area > 0:
                    size_variance_coefficient = float(np.var(areas) / (mean_area ** 2))

            avg_height_width_ratio = float(np.mean(aspect_ratios)) if aspect_ratios else 1.0

            metrics = {
                "top_ratio": top_ratio,
                "bottom_ratio": bottom_ratio,
                "size_variance_coefficient": size_variance_coefficient,
                "avg_height_width_ratio": avg_height_width_ratio,
                "total_objects": total_objects,
            }

            self.logger.debug("Calculated viewpoint metrics: %s", metrics)
            return metrics

        except Exception as e:
            self.logger.error("Error calculating viewpoint metrics: %s", e)
            return {
                "top_ratio": 0,
                "bottom_ratio": 0,
                "size_variance_coefficient": 0,
                "avg_height_width_ratio": 1.0,
                "total_objects": total_objects,
            }

    def _is_aerial_viewpoint(self, metrics: Dict) -> bool:
        """Return True when sizes are uniform, few objects are at the bottom and most are at the top."""
        params = self.viewpoint_params
        return bool(
            metrics["size_variance_coefficient"] < params["aerial_size_variance_threshold"]
            and metrics["bottom_ratio"] < 0.3
            and metrics["top_ratio"] > params["aerial_threshold"]
        )

    def _is_low_angle_viewpoint(self, metrics: Dict) -> bool:
        """Return True when objects are tall relative to their width and enough of them are at the top."""
        params = self.viewpoint_params
        return bool(
            metrics["avg_height_width_ratio"] > params["vertical_size_ratio_threshold"]
            and metrics["top_ratio"] > params["low_angle_threshold"]
        )

    def _is_elevated_viewpoint(self, metrics: Dict) -> bool:
        """Return True when most objects are at the bottom and few are at the top."""
        params = self.viewpoint_params
        return bool(
            metrics["bottom_ratio"] > params["elevated_threshold"]
            and metrics["top_ratio"] < params["elevated_top_threshold"]
        )

    def get_viewpoint_confidence(self, detected_objects: List[Dict]) -> Tuple[str, float]:
        """
        Detect the viewpoint and attach a fixed confidence score.

        Args:
            detected_objects: Detected objects.

        Returns:
            Tuple[str, float]: (viewpoint type, confidence). Confidence is
            0.95 for 'aerial' backed by a crosswalk pattern, 0.8 for other
            'aerial' results, 0.7 for 'eye_level', 0.85 for the remaining
            types, and ('eye_level', 0.3) when detection itself fails.
        """
        try:
            viewpoint = self.detect_viewpoint(detected_objects)

            if viewpoint == "aerial" and self._detect_crosswalk_pattern(detected_objects):
                confidence = 0.95  # a crosswalk pattern is the strongest signal
            elif viewpoint == "aerial":
                confidence = 0.8
            elif viewpoint == "eye_level":
                confidence = 0.7  # also the default answer, hence less certain
            else:
                confidence = 0.85

            self.logger.info("Viewpoint detection result: %s (confidence: %.2f)",
                             viewpoint, confidence)
            return viewpoint, confidence

        except Exception as e:
            # Keep the cause in the log so that a fallback can be diagnosed.
            self.logger.warning("Using fallback viewpoint due to detection error: %s", e)
            return "eye_level", 0.3