import torch
import clip
import numpy as np
from PIL import Image
from typing import Dict, List, Tuple, Any, Optional, Union
from clip_prompts import (
    SCENE_TYPE_PROMPTS,
    CULTURAL_SCENE_PROMPTS,
    COMPARATIVE_PROMPTS,
    LIGHTING_CONDITION_PROMPTS,
    SPECIALIZED_SCENE_PROMPTS,
    VIEWPOINT_PROMPTS,
    OBJECT_COMBINATION_PROMPTS,
    ACTIVITY_PROMPTS
)

class CLIPAnalyzer:
    """
    Scene understanding built on CLIP zero-shot image-text matching.
    """

    def __init__(self, model_name: str = "ViT-B/16", device: Optional[str] = None):
        """
        Initialize the CLIP analyzer.
        Args:
            model_name: CLIP model name, defaults to "ViT-B/16".
            device: Device to run on; defaults to GPU when available.
        """
        # Select the device automatically when none is given.
        if device is None:
            self.device = "cuda" if torch.cuda.is_available() else "cpu"
        else:
            self.device = device
        print(f"Loading CLIP model {model_name} on {self.device}...")
        try:
            self.model, self.preprocess = clip.load(model_name, device=self.device)
            print("CLIP model loaded successfully.")
        except Exception as e:
            print(f"Error loading CLIP model: {e}")
            raise
        self.scene_type_prompts = SCENE_TYPE_PROMPTS
        self.cultural_scene_prompts = CULTURAL_SCENE_PROMPTS
        self.comparative_prompts = COMPARATIVE_PROMPTS
        self.lighting_condition_prompts = LIGHTING_CONDITION_PROMPTS
        self.specialized_scene_prompts = SPECIALIZED_SCENE_PROMPTS
        self.viewpoint_prompts = VIEWPOINT_PROMPTS
        self.object_combination_prompts = OBJECT_COMBINATION_PROMPTS
        self.activity_prompts = ACTIVITY_PROMPTS
        # Pre-tokenize all text prompts into CLIP's input format.
        self._prepare_text_prompts()

    def _prepare_text_prompts(self):
        """Prepare CLIP token tensors for all text prompts and store them in self.text_features_cache."""
        self.text_features_cache = {}

        def tokenize_flat(prompts_dict, keys_name, tokens_name):
            """Tokenize a {key: prompt} dict, caching the key order and the token tensor."""
            if prompts_dict:
                texts = list(prompts_dict.values())
                self.text_features_cache[keys_name] = list(prompts_dict.keys())
                try:
                    self.text_features_cache[tokens_name] = clip.tokenize(texts).to(self.device)
                except Exception as e:
                    print(f"Warning: Error tokenizing {keys_name}: {e}")
                    self.text_features_cache[tokens_name] = None  # mark as failed/empty
            else:
                self.text_features_cache[keys_name] = []
                self.text_features_cache[tokens_name] = None

        def tokenize_nested(prompts_dict, dict_name):
            """Tokenize a {scene_type: [prompt, ...]} dict, caching one token tensor per scene type."""
            tokens = {}
            for scene_type, prompts in (prompts_dict or {}).items():
                if prompts and isinstance(prompts, list) and all(isinstance(p, str) for p in prompts):
                    try:
                        tokens[scene_type] = clip.tokenize(prompts).to(self.device)
                    except Exception as e:
                        print(f"Warning: Error tokenizing {dict_name} for {scene_type}: {e}")
                        tokens[scene_type] = None  # mark as failed/empty
                else:
                    tokens[scene_type] = None  # malformed prompt list
            self.text_features_cache[dict_name] = tokens

        # Flat prompt dictionaries: one prompt string per key.
        tokenize_flat(self.scene_type_prompts, "scene_type_keys", "scene_type_tokens")
        tokenize_flat(self.lighting_condition_prompts, "lighting_condition_keys", "lighting_tokens")
        tokenize_flat(self.viewpoint_prompts, "viewpoint_keys", "viewpoint_tokens")
        tokenize_flat(self.object_combination_prompts, "object_combination_keys", "object_combination_tokens")
        tokenize_flat(self.activity_prompts, "activity_keys", "activity_tokens")

        # Nested prompt dictionaries: a list of prompts per scene type.
        tokenize_nested(self.cultural_scene_prompts, "cultural_tokens_dict")
        tokenize_nested(self.specialized_scene_prompts, "specialized_tokens_dict")

        # Convenience aliases used by the per-aspect analysis methods.
        self.scene_type_tokens = self.text_features_cache["scene_type_tokens"]
        self.lighting_tokens = self.text_features_cache["lighting_tokens"]
        self.viewpoint_tokens = self.text_features_cache["viewpoint_tokens"]
        self.object_combination_tokens = self.text_features_cache["object_combination_tokens"]
        self.activity_tokens = self.text_features_cache["activity_tokens"]
        self.cultural_tokens_dict = self.text_features_cache["cultural_tokens_dict"]
        self.specialized_tokens_dict = self.text_features_cache["specialized_tokens_dict"]
        print("CLIP text_features_cache prepared.")

    def analyze_image(self, image, include_cultural_analysis=True, exclude_categories=None, enable_landmark=True, places365_guidance=None):
        """
        Analyze an image, predicting its scene type and lighting condition.
        Args:
            image: Input image (PIL Image or numpy array).
            include_cultural_analysis: Whether to run the detailed cultural-scene analysis.
            exclude_categories: List of categories to exclude.
            enable_landmark: Whether landmark detection is enabled.
            places365_guidance: Optional scene guidance provided by Places365.
        Returns:
            Dict: Analysis results, including scene-type predictions and lighting condition.
        """
        try:
            self.enable_landmark = enable_landmark  # update the instance's enable_landmark state
            # Ensure the image is in PIL format.
            if not isinstance(image, Image.Image):
                if isinstance(image, np.ndarray):
                    image = Image.fromarray(image)
                else:
                    raise ValueError("Unsupported image format. Expected PIL Image or numpy array.")
            # Preprocess the image.
            image_input = self.preprocess(image).unsqueeze(0).to(self.device)
            # Extract normalized image features.
            with torch.no_grad():
                image_features = self.model.encode_image(image_input)
                image_features = image_features / image_features.norm(dim=-1, keepdim=True)
            places365_focus_areas = []
            places365_scene_context = ""  # scene description derived from Places365
            if places365_guidance and isinstance(places365_guidance, dict) and places365_guidance.get('confidence', 0) > 0.4:
                mapped_scene = places365_guidance.get('mapped_scene_type', '')
                scene_label = places365_guidance.get('scene_label', '')
                # is_indoor = places365_guidance.get('is_indoor', None)  # unused, kept commented out
                attributes = places365_guidance.get('attributes', [])
                places365_scene_context = f"Scene identified by Places365 as {scene_label}"
                # Adjust the CLIP analysis focus based on the Places365 scene type.
                if mapped_scene in ['kitchen', 'dining_area', 'restaurant']:
                    places365_focus_areas.extend(['food preparation', 'dining setup', 'kitchen appliances'])
                elif mapped_scene in ['office_workspace', 'educational_setting', 'library', 'conference_room']:
                    places365_focus_areas.extend(['work environment', 'professional setting', 'learning space', 'study area'])
                elif mapped_scene in ['retail_store', 'shopping_mall', 'market', 'supermarket']:
                    places365_focus_areas.extend(['commercial space', 'shopping environment', 'retail display', 'goods for sale'])
                elif mapped_scene in ['park_area', 'beach', 'natural_outdoor_area', 'playground', 'sports_field']:
                    places365_focus_areas.extend(['outdoor recreation', 'natural environment', 'leisure activity', 'open space'])
                # Add more generic focus areas based on scene attributes.
                if isinstance(attributes, list):
                    if 'commercial' in attributes:
                        places365_focus_areas.append('business activity')
                    if 'recreational' in attributes:
                        places365_focus_areas.append('entertainment or leisure')
                    if 'residential' in attributes:
                        places365_focus_areas.append('living space')
                # Deduplicate.
                places365_focus_areas = list(set(places365_focus_areas))
                if places365_focus_areas:  # only print when focus areas actually exist
                    print(f"CLIP analysis guided by Places365: {places365_scene_context}, focus areas: {places365_focus_areas}")
            # Analyze the scene type, passing along enable_landmark and the Places365 guidance.
            scene_scores = self._analyze_scene_type(image_features,
                                                    enable_landmark=self.enable_landmark,
                                                    places365_focus=places365_focus_areas)
            # When landmark detection is disabled, make sure landmark-related categories are excluded.
            current_exclude_categories = list(exclude_categories) if exclude_categories is not None else []
            if not self.enable_landmark:
                landmark_related_terms = ["landmark", "monument", "tower", "tourist", "attraction", "historical", "famous", "iconic"]
                for term in landmark_related_terms:
                    if term not in current_exclude_categories:
                        current_exclude_categories.append(term)
            if current_exclude_categories:
                filtered_scores = {}
                for scene, score in scene_scores.items():
                    # Keep a scene only if its key (usually English) contains none of the excluded terms.
                    if not any(cat.lower() in scene.lower() for cat in current_exclude_categories):
                        filtered_scores[scene] = score
                if filtered_scores:
                    total_score = sum(filtered_scores.values())
                    if total_score > 1e-5:  # avoid dividing by zero or a vanishingly small sum
                        scene_scores = {k: v / total_score for k, v in filtered_scores.items()}
                    else:  # if the total is near zero, zero out the remaining scores
                        scene_scores = {k: 0.0 for k in filtered_scores.keys()}
                else:  # no scenes survived the filter
                    scene_scores = {k: (0.0 if any(cat.lower() in k.lower() for cat in current_exclude_categories) else v) for k, v in scene_scores.items()}
                    if not any(s > 1e-5 for s in scene_scores.values()):  # still all zero
                        scene_scores = {"unknown": 1.0}  # default value to avoid an empty dict
            lighting_scores = self._analyze_lighting_condition(image_features)
            cultural_analysis = {}
            if include_cultural_analysis and self.enable_landmark:
                for scene_type_cultural_key in self.text_features_cache.get("cultural_tokens_dict", {}).keys():
                    # The cultural key must match a key in scene_scores (i.e. a SCENE_TYPE_PROMPTS key or a mapped one).
                    if scene_type_cultural_key in scene_scores and scene_scores[scene_type_cultural_key] > 0.2:
                        cultural_analysis[scene_type_cultural_key] = self._analyze_cultural_scene(
                            image_features, scene_type_cultural_key
                        )
            specialized_analysis = {}
            for scene_type_specialized_key in self.text_features_cache.get("specialized_tokens_dict", {}).keys():
                if scene_type_specialized_key in scene_scores and scene_scores[scene_type_specialized_key] > 0.2:
                    specialized_analysis[scene_type_specialized_key] = self._analyze_specialized_scene(
                        image_features, scene_type_specialized_key
                    )
            viewpoint_scores = self._analyze_viewpoint(image_features)
            object_combination_scores = self._analyze_object_combinations(image_features)
            activity_scores = self._analyze_activities(image_features)
            if scene_scores:  # guard against an empty scene_scores dict
                top_scene = max(scene_scores.items(), key=lambda x: x[1])
                # With landmarks disabled, double-check that the top scene is not landmark-related.
                if not self.enable_landmark and any(cat.lower() in top_scene[0].lower() for cat in current_exclude_categories):
                    non_excluded_scores = {k: v for k, v in scene_scores.items() if not any(cat.lower() in k.lower() for cat in current_exclude_categories)}
                    if non_excluded_scores:
                        top_scene = max(non_excluded_scores.items(), key=lambda x: x[1])
                    else:
                        top_scene = ("unknown", 0.0)  # fallback default
            else:
                top_scene = ("unknown", 0.0)
            result = {
                "scene_scores": scene_scores,
                "top_scene": top_scene,
                "lighting_condition": max(lighting_scores.items(), key=lambda x: x[1]) if lighting_scores else ("unknown", 0.0),
                "embedding": image_features.cpu().numpy().tolist()[0],
                "viewpoint": max(viewpoint_scores.items(), key=lambda x: x[1]) if viewpoint_scores else ("unknown", 0.0),
                "object_combinations": sorted(object_combination_scores.items(), key=lambda x: x[1], reverse=True)[:3] if object_combination_scores else [],
                "activities": sorted(activity_scores.items(), key=lambda x: x[1], reverse=True)[:3] if activity_scores else []
            }
            if places365_guidance and isinstance(places365_guidance, dict) and places365_focus_areas:  # only when focus areas were populated
                result["places365_guidance"] = {
                    "scene_context": places365_scene_context,
                    "focus_areas": places365_focus_areas,
                    "guided_analysis": True,
                    "original_places365_scene": places365_guidance.get('scene_label', 'N/A'),
                    "original_places365_confidence": places365_guidance.get('confidence', 0.0)
                }
            if cultural_analysis and self.enable_landmark:
                result["cultural_analysis"] = cultural_analysis
            if specialized_analysis:
                result["specialized_analysis"] = specialized_analysis
            return result
        except Exception as e:
            print(f"Error analyzing image with CLIP: {e}")
            import traceback
            traceback.print_exc()
            return {"error": str(e), "scene_scores": {}, "top_scene": ("error", 0.0)}

    def _analyze_scene_type(self, image_features: torch.Tensor, enable_landmark: bool = True, places365_focus: Optional[List[str]] = None) -> Dict[str, float]:
        """
        Compute the similarity between the image features and each scene type,
        optionally excluding landmark-related scenes.
        Args:
            image_features: CLIP-encoded image features.
            enable_landmark: Whether landmark recognition is enabled.
            places365_focus: Optional focus areas derived from Places365 guidance.
        Returns:
            Dict[str, float]: Similarity score for each scene type.
        """
        with torch.no_grad():
            # Encode and normalize the scene-type text features.
            text_features = self.model.encode_text(self.scene_type_tokens)
            text_features = text_features / text_features.norm(dim=-1, keepdim=True)
            # Apply Places365 guidance if available.
            if places365_focus and len(places365_focus) > 0:
                # Create enhanced prompts that incorporate the Places365 guidance.
                enhanced_prompts = []
                for scene_type in self.scene_type_prompts.keys():
                    base_prompt = self.scene_type_prompts[scene_type]
                    # Check whether this scene type should be emphasized based on the guidance.
                    scene_lower = scene_type.lower()
                    should_enhance = False
                    for focus_area in places365_focus:
                        if any(keyword in scene_lower for keyword in focus_area.split()):
                            should_enhance = True
                            enhanced_prompts.append(f"{base_prompt} with {focus_area}")
                            break
                    if not should_enhance:
                        enhanced_prompts.append(base_prompt)
                # Re-tokenize and encode the enhanced prompts.
                enhanced_tokens = clip.tokenize(enhanced_prompts).to(self.device)
                text_features = self.model.encode_text(enhanced_tokens)
                text_features = text_features / text_features.norm(dim=-1, keepdim=True)
            # Compute similarity scores.
            similarity = (100 * image_features @ text_features.T).softmax(dim=-1)
            similarity = similarity.cpu().numpy()[0]
            # Build the scene-score dictionary.
            scene_scores = {}
            for i, scene_type in enumerate(self.scene_type_prompts.keys()):
                # Skip landmark-related scene types when landmark detection is disabled.
                if not enable_landmark and scene_type in ["tourist_landmark", "natural_landmark", "historical_monument"]:
                    scene_scores[scene_type] = 0.0  # zero out landmark scenes
                else:
                    base_score = float(similarity[i])
                    # Apply a Places365 guidance boost where applicable.
                    if places365_focus:
                        scene_lower = scene_type.lower()
                        boost_factor = 1.0
                        for focus_area in places365_focus:
                            if any(keyword in scene_lower for keyword in focus_area.split()):
                                boost_factor = 1.15  # 15% boost for matching scenes
                                break
                        scene_scores[scene_type] = base_score * boost_factor
                    else:
                        scene_scores[scene_type] = base_score
            # With landmarks disabled, renormalize the remaining scene scores.
            if not enable_landmark:
                # Collect all non-zero scores.
                non_zero_scores = {k: v for k, v in scene_scores.items() if v > 0}
                if non_zero_scores:
                    # Normalize by the total.
                    total_score = sum(non_zero_scores.values())
                    if total_score > 0:
                        for scene_type in non_zero_scores:
                            scene_scores[scene_type] = non_zero_scores[scene_type] / total_score
            return scene_scores
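
    # The 100x multiplier above mirrors CLIP's learned temperature (the logit
    # scale, clamped at 100 during training in the original implementation), so
    # the softmax over prompts behaves like CLIP's zero-shot classification
    # distribution rather than raw cosine similarities.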

    def _analyze_lighting_condition(self, image_features: torch.Tensor) -> Dict[str, float]:
        """Analyze the lighting condition of the image."""
        with torch.no_grad():
            # Encode and normalize the lighting-condition text features.
            text_features = self.model.encode_text(self.lighting_tokens)
            text_features = text_features / text_features.norm(dim=-1, keepdim=True)
            # Compute similarity scores.
            similarity = (100 * image_features @ text_features.T).softmax(dim=-1)
            similarity = similarity.cpu().numpy()[0]
            # Build the lighting-condition score dictionary.
            lighting_scores = {}
            for i, lighting_type in enumerate(self.lighting_condition_prompts.keys()):
                lighting_scores[lighting_type] = float(similarity[i])
            return lighting_scores

    def _analyze_cultural_scene(self, image_features: torch.Tensor, scene_type: str) -> Dict[str, Any]:
        """Run a deeper analysis for a specific cultural scene."""
        if scene_type not in self.cultural_tokens_dict:
            return {"error": f"No cultural analysis available for {scene_type}"}
        with torch.no_grad():
            # Encode and normalize the text features for this cultural scene.
            cultural_tokens = self.cultural_tokens_dict[scene_type]
            text_features = self.model.encode_text(cultural_tokens)
            text_features = text_features / text_features.norm(dim=-1, keepdim=True)
            # Compute raw similarity scores (no softmax: prompts are ranked, not normalized).
            similarity = (100 * image_features @ text_features.T)
            similarity = similarity.cpu().numpy()[0]
            # Find the best-matching cultural description.
            prompts = self.cultural_scene_prompts[scene_type]
            scores = [(prompts[i], float(similarity[i])) for i in range(len(prompts))]
            scores.sort(key=lambda x: x[1], reverse=True)
            return {
                "best_description": scores[0][0],
                "confidence": scores[0][1],
                "all_matches": scores
            }

    def _analyze_specialized_scene(self, image_features: torch.Tensor, scene_type: str) -> Dict[str, Any]:
        """Run a deeper analysis for a specific specialized scene."""
        if scene_type not in self.specialized_tokens_dict:
            return {"error": f"No specialized analysis available for {scene_type}"}
        with torch.no_grad():
            # Encode and normalize the text features for this specialized scene.
            specialized_tokens = self.specialized_tokens_dict[scene_type]
            text_features = self.model.encode_text(specialized_tokens)
            text_features = text_features / text_features.norm(dim=-1, keepdim=True)
            # Compute raw similarity scores (no softmax: prompts are ranked, not normalized).
            similarity = (100 * image_features @ text_features.T)
            similarity = similarity.cpu().numpy()[0]
            # Find the best-matching specialized description.
            prompts = self.specialized_scene_prompts[scene_type]
            scores = [(prompts[i], float(similarity[i])) for i in range(len(prompts))]
            scores.sort(key=lambda x: x[1], reverse=True)
            return {
                "best_description": scores[0][0],
                "confidence": scores[0][1],
                "all_matches": scores
            }

    def _analyze_viewpoint(self, image_features: torch.Tensor) -> Dict[str, float]:
        """Analyze the viewpoint from which the image was taken."""
        with torch.no_grad():
            # Encode and normalize the viewpoint text features.
            text_features = self.model.encode_text(self.viewpoint_tokens)
            text_features = text_features / text_features.norm(dim=-1, keepdim=True)
            # Compute similarity scores.
            similarity = (100 * image_features @ text_features.T).softmax(dim=-1)
            similarity = similarity.cpu().numpy()[0]
            # Build the viewpoint score dictionary.
            viewpoint_scores = {}
            for i, viewpoint in enumerate(self.viewpoint_prompts.keys()):
                viewpoint_scores[viewpoint] = float(similarity[i])
            return viewpoint_scores

    def _analyze_object_combinations(self, image_features: torch.Tensor) -> Dict[str, float]:
        """Analyze object combinations present in the image."""
        with torch.no_grad():
            # Encode and normalize the object-combination text features.
            text_features = self.model.encode_text(self.object_combination_tokens)
            text_features = text_features / text_features.norm(dim=-1, keepdim=True)
            # Compute similarity scores.
            similarity = (100 * image_features @ text_features.T).softmax(dim=-1)
            similarity = similarity.cpu().numpy()[0]
            # Build the object-combination score dictionary.
            combination_scores = {}
            for i, combination in enumerate(self.object_combination_prompts.keys()):
                combination_scores[combination] = float(similarity[i])
            return combination_scores

    def _analyze_activities(self, image_features: torch.Tensor) -> Dict[str, float]:
        """Analyze activities depicted in the image."""
        with torch.no_grad():
            # Encode and normalize the activity text features.
            text_features = self.model.encode_text(self.activity_tokens)
            text_features = text_features / text_features.norm(dim=-1, keepdim=True)
            # Compute similarity scores.
            similarity = (100 * image_features @ text_features.T).softmax(dim=-1)
            similarity = similarity.cpu().numpy()[0]
            # Build the activity score dictionary.
            activity_scores = {}
            for i, activity in enumerate(self.activity_prompts.keys()):
                activity_scores[activity] = float(similarity[i])
            return activity_scores

    def get_image_embedding(self, image) -> np.ndarray:
        """
        Get the CLIP embedding of an image.
        Args:
            image: PIL Image or numpy array.
        Returns:
            np.ndarray: The image's CLIP feature vector.
        """
        # Ensure the image is in PIL format.
        if not isinstance(image, Image.Image):
            if isinstance(image, np.ndarray):
                image = Image.fromarray(image)
            else:
                raise ValueError("Unsupported image format. Expected PIL Image or numpy array.")
        # Preprocess and encode.
        image_input = self.preprocess(image).unsqueeze(0).to(self.device)
        with torch.no_grad():
            image_features = self.model.encode_image(image_input)
            image_features = image_features / image_features.norm(dim=-1, keepdim=True)
        # Convert to numpy and return.
        return image_features.cpu().numpy()[0]
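
    # Because the returned vectors are L2-normalized, cosine similarity between
    # two images reduces to a plain dot product, e.g. (illustrative):
    #   sim = float(np.dot(analyzer.get_image_embedding(img_a),
    #                      analyzer.get_image_embedding(img_b)))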

    def text_to_embedding(self, text: str) -> np.ndarray:
        """
        Convert text into a CLIP embedding.
        Args:
            text: Input text.
        Returns:
            np.ndarray: The text's CLIP feature vector.
        """
        text_token = clip.tokenize([text]).to(self.device)
        with torch.no_grad():
            text_features = self.model.encode_text(text_token)
            text_features = text_features / text_features.norm(dim=-1, keepdim=True)
        return text_features.cpu().numpy()[0]

    def calculate_similarity(self, image, text_queries: List[str]) -> Dict[str, float]:
        """
        Compute the similarity between an image and multiple text queries.
        Args:
            image: PIL Image or numpy array.
            text_queries: List of text queries.
        Returns:
            Dict: Similarity score for each query.
        """
        # Obtain the image embedding.
        if isinstance(image, np.ndarray) and len(image.shape) == 1:
            # Already an embedding vector.
            image_features = torch.tensor(image).unsqueeze(0).to(self.device)
        else:
            # A raw image: extract the embedding first.
            image_features = torch.tensor(self.get_image_embedding(image)).unsqueeze(0).to(self.device)
        # Calculate similarity.
        text_tokens = clip.tokenize(text_queries).to(self.device)
        with torch.no_grad():
            text_features = self.model.encode_text(text_tokens)
            text_features = text_features / text_features.norm(dim=-1, keepdim=True)
            similarity = (100.0 * image_features @ text_features.T).softmax(dim=-1)
            similarity = similarity.cpu().numpy()[0]
        # Assemble results.
        result = {}
        for i, query in enumerate(text_queries):
            result[query] = float(similarity[i])
        return result

    def get_clip_instance(self):
        """
        Get the initialized CLIP model instance so other modules can reuse it.
        Returns:
            tuple: (model instance, preprocess function, device name)
        """
        return self.model, self.preprocess, self.device
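

# Minimal usage sketch. Assumptions: the clip_prompts module is importable and a
# local file "street.jpg" exists; the filename and printed keys are illustrative,
# not part of the API contract.
if __name__ == "__main__":
    analyzer = CLIPAnalyzer()  # defaults to ViT-B/16, GPU when available

    image = Image.open("street.jpg")  # hypothetical example image
    analysis = analyzer.analyze_image(image)
    print("Top scene:", analysis["top_scene"])
    print("Lighting:", analysis["lighting_condition"])

    # Ad-hoc zero-shot queries against the same image.
    scores = analyzer.calculate_similarity(
        image, ["a busy intersection", "a quiet residential street"]
    )
    print("Query scores:", scores)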