import torch
import clip
import numpy as np
from PIL import Image
from typing import Dict, List, Tuple, Any, Optional, Union
from clip_prompts import (
SCENE_TYPE_PROMPTS,
CULTURAL_SCENE_PROMPTS,
COMPARATIVE_PROMPTS,
LIGHTING_CONDITION_PROMPTS,
SPECIALIZED_SCENE_PROMPTS,
VIEWPOINT_PROMPTS,
OBJECT_COMBINATION_PROMPTS,
ACTIVITY_PROMPTS
)
class CLIPAnalyzer:
"""
Use Clip to intergrate scene understanding function
"""
def __init__(self, model_name: str = "ViT-L/14", device: str = None):
"""
初始化 CLIP 分析器。
Args:
model_name: CLIP Model name, 默認 "ViT-L/14"
device: Use GPU if it can use
"""
        # Auto-select the device when none is given
if device is None:
self.device = "cuda" if torch.cuda.is_available() else "cpu"
else:
self.device = device
print(f"Loading CLIP model {model_name} on {self.device}...")
try:
self.model, self.preprocess = clip.load(model_name, device=self.device)
print(f"CLIP model loaded successfully.")
except Exception as e:
print(f"Error loading CLIP model: {e}")
raise
self.scene_type_prompts = SCENE_TYPE_PROMPTS
self.cultural_scene_prompts = CULTURAL_SCENE_PROMPTS
self.comparative_prompts = COMPARATIVE_PROMPTS
self.lighting_condition_prompts = LIGHTING_CONDITION_PROMPTS
self.specialized_scene_prompts = SPECIALIZED_SCENE_PROMPTS
self.viewpoint_prompts = VIEWPOINT_PROMPTS
self.object_combination_prompts = OBJECT_COMBINATION_PROMPTS
self.activity_prompts = ACTIVITY_PROMPTS
        # Pre-tokenize every prompt set into CLIP token tensors
self._prepare_text_prompts()
def _prepare_text_prompts(self):
"""準備所有文本提示的 CLIP 特徵並存儲到 self.text_features_cache 中"""
self.text_features_cache = {}
        # Base scene types (SCENE_TYPE_PROMPTS)
if hasattr(self, 'scene_type_prompts') and self.scene_type_prompts:
            scene_texts = list(self.scene_type_prompts.values())
if scene_texts:
self.text_features_cache["scene_type_keys"] = list(self.scene_type_prompts.keys())
try:
self.text_features_cache["scene_type_tokens"] = clip.tokenize(scene_texts).to(self.device)
except Exception as e:
print(f"Warning: Error tokenizing scene_type_prompts: {e}")
self.text_features_cache["scene_type_tokens"] = None # 標記錯誤或空
else:
self.text_features_cache["scene_type_keys"] = []
self.text_features_cache["scene_type_tokens"] = None
else:
self.text_features_cache["scene_type_keys"] = []
self.text_features_cache["scene_type_tokens"] = None
        # Cultural scenes (CULTURAL_SCENE_PROMPTS)
        # cultural_tokens_dict holds per-scene batches of tokenized prompts
cultural_tokens_dict_val = {}
if hasattr(self, 'cultural_scene_prompts') and self.cultural_scene_prompts:
for scene_type, prompts in self.cultural_scene_prompts.items():
if prompts and isinstance(prompts, list) and all(isinstance(p, str) for p in prompts):
try:
cultural_tokens_dict_val[scene_type] = clip.tokenize(prompts).to(self.device)
except Exception as e:
print(f"Warning: Error tokenizing cultural_scene_prompts for {scene_type}: {e}")
cultural_tokens_dict_val[scene_type] = None # 標記錯誤或空
else:
cultural_tokens_dict_val[scene_type] = None # prompts 不合規
self.text_features_cache["cultural_tokens_dict"] = cultural_tokens_dict_val
        # Lighting conditions (LIGHTING_CONDITION_PROMPTS)
if hasattr(self, 'lighting_condition_prompts') and self.lighting_condition_prompts:
            lighting_texts = list(self.lighting_condition_prompts.values())
if lighting_texts:
self.text_features_cache["lighting_condition_keys"] = list(self.lighting_condition_prompts.keys())
try:
self.text_features_cache["lighting_tokens"] = clip.tokenize(lighting_texts).to(self.device)
except Exception as e:
print(f"Warning: Error tokenizing lighting_condition_prompts: {e}")
self.text_features_cache["lighting_tokens"] = None
else:
self.text_features_cache["lighting_condition_keys"] = []
self.text_features_cache["lighting_tokens"] = None
else:
self.text_features_cache["lighting_condition_keys"] = []
self.text_features_cache["lighting_tokens"] = None
        # Specialized scenes (SPECIALIZED_SCENE_PROMPTS)
specialized_tokens_dict_val = {}
if hasattr(self, 'specialized_scene_prompts') and self.specialized_scene_prompts:
for scene_type, prompts in self.specialized_scene_prompts.items():
if prompts and isinstance(prompts, list) and all(isinstance(p, str) for p in prompts):
try:
specialized_tokens_dict_val[scene_type] = clip.tokenize(prompts).to(self.device)
except Exception as e:
print(f"Warning: Error tokenizing specialized_scene_prompts for {scene_type}: {e}")
specialized_tokens_dict_val[scene_type] = None
else:
specialized_tokens_dict_val[scene_type] = None
self.text_features_cache["specialized_tokens_dict"] = specialized_tokens_dict_val
        # Viewpoints (VIEWPOINT_PROMPTS)
if hasattr(self, 'viewpoint_prompts') and self.viewpoint_prompts:
            viewpoint_texts = list(self.viewpoint_prompts.values())
if viewpoint_texts:
self.text_features_cache["viewpoint_keys"] = list(self.viewpoint_prompts.keys())
try:
self.text_features_cache["viewpoint_tokens"] = clip.tokenize(viewpoint_texts).to(self.device)
except Exception as e:
print(f"Warning: Error tokenizing viewpoint_prompts: {e}")
self.text_features_cache["viewpoint_tokens"] = None
else:
self.text_features_cache["viewpoint_keys"] = []
self.text_features_cache["viewpoint_tokens"] = None
else:
self.text_features_cache["viewpoint_keys"] = []
self.text_features_cache["viewpoint_tokens"] = None
        # Object combinations (OBJECT_COMBINATION_PROMPTS)
if hasattr(self, 'object_combination_prompts') and self.object_combination_prompts:
            object_combination_texts = list(self.object_combination_prompts.values())
if object_combination_texts:
self.text_features_cache["object_combination_keys"] = list(self.object_combination_prompts.keys())
try:
self.text_features_cache["object_combination_tokens"] = clip.tokenize(object_combination_texts).to(self.device)
except Exception as e:
print(f"Warning: Error tokenizing object_combination_prompts: {e}")
self.text_features_cache["object_combination_tokens"] = None
else:
self.text_features_cache["object_combination_keys"] = []
self.text_features_cache["object_combination_tokens"] = None
else:
self.text_features_cache["object_combination_keys"] = []
self.text_features_cache["object_combination_tokens"] = None
        # Activities (ACTIVITY_PROMPTS)
if hasattr(self, 'activity_prompts') and self.activity_prompts:
            activity_texts = list(self.activity_prompts.values())
if activity_texts:
self.text_features_cache["activity_keys"] = list(self.activity_prompts.keys())
try:
self.text_features_cache["activity_tokens"] = clip.tokenize(activity_texts).to(self.device)
except Exception as e:
print(f"Warning: Error tokenizing activity_prompts: {e}")
self.text_features_cache["activity_tokens"] = None
else:
self.text_features_cache["activity_keys"] = []
self.text_features_cache["activity_tokens"] = None
else:
self.text_features_cache["activity_keys"] = []
self.text_features_cache["activity_tokens"] = None
self.scene_type_tokens = self.text_features_cache["scene_type_tokens"]
self.lighting_tokens = self.text_features_cache["lighting_tokens"]
self.viewpoint_tokens = self.text_features_cache["viewpoint_tokens"]
self.object_combination_tokens = self.text_features_cache["object_combination_tokens"]
self.activity_tokens = self.text_features_cache["activity_tokens"]
self.cultural_tokens_dict = self.text_features_cache["cultural_tokens_dict"]
self.specialized_tokens_dict = self.text_features_cache["specialized_tokens_dict"]
print("CLIP text_features_cache prepared.")
def analyze_image(self, image, include_cultural_analysis=True, exclude_categories=None, enable_landmark=True, places365_guidance=None):
"""
分析圖像,預測場景類型和光照條件。
Args:
image: 輸入圖像 (PIL Image 或 numpy array)
include_cultural_analysis: 是否包含文化場景的詳細分析
exclude_categories: 要排除的類別列表
enable_landmark: 是否啟用地標檢測功能
places365_guidance: Places365 提供的場景指導信息 (可選)
Returns:
Dict: 包含場景類型預測和光照條件的分析結果
"""
        try:
            self.enable_landmark = enable_landmark  # remember the landmark toggle on the instance
            # Ensure the image is a PIL Image
            if not isinstance(image, Image.Image):
                if isinstance(image, np.ndarray):
                    image = Image.fromarray(image)
                else:
                    raise ValueError("Unsupported image format. Expected PIL Image or numpy array.")
            # Preprocess the image
            image_input = self.preprocess(image).unsqueeze(0).to(self.device)
            # Encode and L2-normalize the image features
            with torch.no_grad():
                image_features = self.model.encode_image(image_input)
                image_features = image_features / image_features.norm(dim=-1, keepdim=True)
            places365_focus_areas = []
            places365_scene_context = ""  # human-readable description of the Places365 guidance
if places365_guidance and isinstance(places365_guidance, dict) and places365_guidance.get('confidence', 0) > 0.4:
mapped_scene = places365_guidance.get('mapped_scene_type', '')
scene_label = places365_guidance.get('scene_label', '')
                # is_indoor = places365_guidance.get('is_indoor', None)  # currently unused
                attributes = places365_guidance.get('attributes', [])
                places365_scene_context = f"Scene identified by Places365 as {scene_label}"
# Adjust CLIP analysis focus based on Places365 scene type
if mapped_scene in ['kitchen', 'dining_area', 'restaurant']:
places365_focus_areas.extend(['food preparation', 'dining setup', 'kitchen appliances'])
elif mapped_scene in ['office_workspace', 'educational_setting', 'library', 'conference_room']:
places365_focus_areas.extend(['work environment', 'professional setting', 'learning space', 'study area'])
                elif mapped_scene in ['retail_store', 'shopping_mall', 'market', 'supermarket']:  # broader match
                    places365_focus_areas.extend(['commercial space', 'shopping environment', 'retail display', 'goods for sale'])
                elif mapped_scene in ['park_area', 'beach', 'natural_outdoor_area', 'playground', 'sports_field']:  # broader match
                    places365_focus_areas.extend(['outdoor recreation', 'natural environment', 'leisure activity', 'open space'])
                # Add more generic focus areas based on the scene attributes
                if isinstance(attributes, list):  # guard: attributes should be a list
if 'commercial' in attributes:
places365_focus_areas.append('business activity')
if 'recreational' in attributes:
places365_focus_areas.append('entertainment or leisure')
if 'residential' in attributes:
places365_focus_areas.append('living space')
                # De-duplicate the collected focus areas
                places365_focus_areas = list(set(places365_focus_areas))
            if places365_focus_areas:  # only log when guidance actually produced focus areas
                print(f"CLIP analysis guided by Places365: {places365_scene_context}, focus areas: {places365_focus_areas}")
            # Score scene types, passing the landmark toggle and the Places365 guidance
            scene_scores = self._analyze_scene_type(image_features,
                                                    enable_landmark=self.enable_landmark,
                                                    places365_focus=places365_focus_areas)
            # With landmark detection disabled, also exclude landmark-related categories
            current_exclude_categories = list(exclude_categories) if exclude_categories is not None else []
            if not self.enable_landmark:  # use the updated instance attribute
landmark_related_terms = ["landmark", "monument", "tower", "tourist", "attraction", "historical", "famous", "iconic"]
for term in landmark_related_terms:
if term not in current_exclude_categories:
current_exclude_categories.append(term)
if current_exclude_categories:
filtered_scores = {}
for scene, score in scene_scores.items():
                    # Keep the scene only if its key (usually English) contains no excluded term
                    if not any(cat.lower() in scene.lower() for cat in current_exclude_categories):
filtered_scores[scene] = score
if filtered_scores:
total_score = sum(filtered_scores.values())
                    if total_score > 1e-5:  # guard against dividing by (near) zero
                        scene_scores = {k: v / total_score for k, v in filtered_scores.items()}
                    else:  # total is near zero: zero out the remaining scores
                        scene_scores = {k: 0.0 for k in filtered_scores.keys()}
                else:  # every scene was filtered out
                    scene_scores = {k: (0.0 if any(cat.lower() in k.lower() for cat in current_exclude_categories) else v) for k, v in scene_scores.items()}
                    if not any(s > 1e-5 for s in scene_scores.values()):  # still all zero
                        scene_scores = {"unknown": 1.0}  # fall back to a default so the dict is never empty
lighting_scores = self._analyze_lighting_condition(image_features)
            cultural_analysis = {}
            if include_cultural_analysis and self.enable_landmark:  # use the updated instance attribute
                for scene_type_cultural_key in self.text_features_cache.get("cultural_tokens_dict", {}).keys():
                    # Assumes cultural keys also appear in scene_scores (i.e. in SCENE_TYPE_PROMPTS) or are mapped to them
                    if scene_type_cultural_key in scene_scores and scene_scores[scene_type_cultural_key] > 0.2:
cultural_analysis[scene_type_cultural_key] = self._analyze_cultural_scene(
image_features, scene_type_cultural_key
)
specialized_analysis = {}
for scene_type_specialized_key in self.text_features_cache.get("specialized_tokens_dict", {}).keys():
if scene_type_specialized_key in scene_scores and scene_scores[scene_type_specialized_key] > 0.2:
specialized_analysis[scene_type_specialized_key] = self._analyze_specialized_scene(
image_features, scene_type_specialized_key
)
viewpoint_scores = self._analyze_viewpoint(image_features)
object_combination_scores = self._analyze_object_combinations(image_features)
activity_scores = self._analyze_activities(image_features)
            if scene_scores:  # make sure scene_scores is not empty
                top_scene = max(scene_scores.items(), key=lambda x: x[1])
                # With landmarks disabled, double-check that the top scene is not landmark-related
                if not self.enable_landmark and any(cat.lower() in top_scene[0].lower() for cat in current_exclude_categories):
                    non_excluded_scores = {k: v for k, v in scene_scores.items() if not any(cat.lower() in k.lower() for cat in current_exclude_categories)}
                    if non_excluded_scores:
                        top_scene = max(non_excluded_scores.items(), key=lambda x: x[1])
                    else:
                        top_scene = ("unknown", 0.0)  # fallback default
            else:
                top_scene = ("unknown", 0.0)
result = {
"scene_scores": scene_scores,
"top_scene": top_scene,
"lighting_condition": max(lighting_scores.items(), key=lambda x: x[1]) if lighting_scores else ("unknown", 0.0),
"embedding": image_features.cpu().numpy().tolist()[0], # 簡化
"viewpoint": max(viewpoint_scores.items(), key=lambda x: x[1]) if viewpoint_scores else ("unknown", 0.0),
"object_combinations": sorted(object_combination_scores.items(), key=lambda x: x[1], reverse=True)[:3] if object_combination_scores else [],
"activities": sorted(activity_scores.items(), key=lambda x: x[1], reverse=True)[:3] if activity_scores else []
}
            if places365_guidance and isinstance(places365_guidance, dict) and places365_focus_areas:  # only when focus areas were derived
                result["places365_guidance"] = {
                    "scene_context": places365_scene_context,
                    "focus_areas": places365_focus_areas,
                    "guided_analysis": True,
                    "original_places365_scene": places365_guidance.get('scene_label', 'N/A'),
                    "original_places365_confidence": places365_guidance.get('confidence', 0.0)
                }
if cultural_analysis and self.enable_landmark:
result["cultural_analysis"] = cultural_analysis
if specialized_analysis:
result["specialized_analysis"] = specialized_analysis
return result
except Exception as e:
print(f"Error analyzing image with CLIP: {e}")
import traceback
traceback.print_exc()
return {"error": str(e), "scene_scores": {}, "top_scene": ("error", 0.0)}
def _analyze_scene_type(self, image_features: torch.Tensor, enable_landmark: bool = True, places365_focus: List[str] = None) -> Dict[str, float]:
"""
分析圖像特徵與各場景類型的相似度,並可選擇性地排除地標相關場景
Args:
image_features: 經過 CLIP 編碼的圖像特徵
enable_landmark: 是否啟用地標識別功能
Returns:
Dict[str, float]: 各場景類型的相似度分數字典
"""
with torch.no_grad():
            # Encode the scene-type text prompts
text_features = self.model.encode_text(self.scene_type_tokens)
text_features = text_features / text_features.norm(dim=-1, keepdim=True)
# Apply Places365 guidance if available
if places365_focus and len(places365_focus) > 0:
# Create enhanced prompts that incorporate Places365 guidance
enhanced_prompts = []
for scene_type in self.scene_type_prompts.keys():
base_prompt = self.scene_type_prompts[scene_type]
# Check if this scene type should be emphasized based on Places365 guidance
scene_lower = scene_type.lower()
should_enhance = False
for focus_area in places365_focus:
if any(keyword in scene_lower for keyword in focus_area.split()):
should_enhance = True
enhanced_prompts.append(f"{base_prompt} with {focus_area}")
break
if not should_enhance:
enhanced_prompts.append(base_prompt)
# Re-tokenize and encode enhanced prompts
enhanced_tokens = clip.tokenize(enhanced_prompts).to(self.device)
text_features = self.model.encode_text(enhanced_tokens)
text_features = text_features / text_features.norm(dim=-1, keepdim=True)
            # Cosine similarities scaled by CLIP's logit scale (100), softmax-normalized over prompts
            similarity = (100 * image_features @ text_features.T).softmax(dim=-1)
            similarity = similarity.cpu().numpy()[0]  # .cpu() is a no-op for CPU tensors
            # Build the scene-score dictionary
            scene_scores = {}
            for i, scene_type in enumerate(self.scene_type_prompts.keys()):
                # With landmark recognition disabled, zero out landmark-related scene types
                if not enable_landmark and scene_type in ["tourist_landmark", "natural_landmark", "historical_monument"]:
                    scene_scores[scene_type] = 0.0
else:
base_score = float(similarity[i])
# Apply Places365 guidance boost if applicable
if places365_focus:
scene_lower = scene_type.lower()
boost_factor = 1.0
for focus_area in places365_focus:
if any(keyword in scene_lower for keyword in focus_area.split()):
boost_factor = 1.15 # 15% boost for matching scenes
break
scene_scores[scene_type] = base_score * boost_factor
else:
scene_scores[scene_type] = base_score
            # With landmarks disabled, renormalize the remaining scene scores
            if not enable_landmark:
                # Collect all non-zero scores
                non_zero_scores = {k: v for k, v in scene_scores.items() if v > 0}
                if non_zero_scores:
                    # Normalize so the remaining scores sum to 1
                    total_score = sum(non_zero_scores.values())
                    if total_score > 0:
                        for scene_type in non_zero_scores:
                            scene_scores[scene_type] = non_zero_scores[scene_type] / total_score
            return scene_scores
    def _analyze_lighting_condition(self, image_features: torch.Tensor) -> Dict[str, float]:
        """Score lighting conditions for the image."""
        with torch.no_grad():
            # Encode the lighting-condition text prompts
            text_features = self.model.encode_text(self.lighting_tokens)
            text_features = text_features / text_features.norm(dim=-1, keepdim=True)
            # Similarity scores, softmax-normalized over prompts
            similarity = (100 * image_features @ text_features.T).softmax(dim=-1)
            similarity = similarity.cpu().numpy()[0]
            # Build the lighting-score dictionary
            lighting_scores = {}
            for i, lighting_type in enumerate(self.lighting_condition_prompts.keys()):
                lighting_scores[lighting_type] = float(similarity[i])
            return lighting_scores
    def _analyze_cultural_scene(self, image_features: torch.Tensor, scene_type: str) -> Dict[str, Any]:
        """Run a finer-grained analysis for a specific cultural scene type."""
        if scene_type not in self.cultural_tokens_dict or self.cultural_tokens_dict[scene_type] is None:
            return {"error": f"No cultural analysis available for {scene_type}"}
        with torch.no_grad():
            # Encode the prompts for this cultural scene type
            cultural_tokens = self.cultural_tokens_dict[scene_type]
            text_features = self.model.encode_text(cultural_tokens)
            text_features = text_features / text_features.norm(dim=-1, keepdim=True)
            # Raw (un-normalized) similarity scores
            similarity = (100 * image_features @ text_features.T)
            similarity = similarity.cpu().numpy()[0]
            # Find the best-matching cultural description
            prompts = self.cultural_scene_prompts[scene_type]
            scores = [(prompts[i], float(similarity[i])) for i in range(len(prompts))]
            scores.sort(key=lambda x: x[1], reverse=True)
            return {
                "best_description": scores[0][0],
                "confidence": scores[0][1],
                "all_matches": scores
            }
    def _analyze_specialized_scene(self, image_features: torch.Tensor, scene_type: str) -> Dict[str, Any]:
        """Run a finer-grained analysis for a specific specialized scene type."""
        if scene_type not in self.specialized_tokens_dict or self.specialized_tokens_dict[scene_type] is None:
            return {"error": f"No specialized analysis available for {scene_type}"}
        with torch.no_grad():
            # Encode the prompts for this specialized scene type
            specialized_tokens = self.specialized_tokens_dict[scene_type]
            text_features = self.model.encode_text(specialized_tokens)
            text_features = text_features / text_features.norm(dim=-1, keepdim=True)
            # Raw (un-normalized) similarity scores
            similarity = (100 * image_features @ text_features.T)
            similarity = similarity.cpu().numpy()[0]
            # Find the best-matching specialized description
            prompts = self.specialized_scene_prompts[scene_type]
            scores = [(prompts[i], float(similarity[i])) for i in range(len(prompts))]
            scores.sort(key=lambda x: x[1], reverse=True)
            return {
                "best_description": scores[0][0],
                "confidence": scores[0][1],
                "all_matches": scores
            }
    def _analyze_viewpoint(self, image_features: torch.Tensor) -> Dict[str, float]:
        """Score camera viewpoints for the image."""
        with torch.no_grad():
            # Encode the viewpoint text prompts
            text_features = self.model.encode_text(self.viewpoint_tokens)
            text_features = text_features / text_features.norm(dim=-1, keepdim=True)
            # Similarity scores, softmax-normalized over prompts
            similarity = (100 * image_features @ text_features.T).softmax(dim=-1)
            similarity = similarity.cpu().numpy()[0]
            # Build the viewpoint-score dictionary
            viewpoint_scores = {}
            for i, viewpoint in enumerate(self.viewpoint_prompts.keys()):
                viewpoint_scores[viewpoint] = float(similarity[i])
            return viewpoint_scores
    def _analyze_object_combinations(self, image_features: torch.Tensor) -> Dict[str, float]:
        """Score object combinations present in the image."""
        with torch.no_grad():
            # Encode the object-combination text prompts
            text_features = self.model.encode_text(self.object_combination_tokens)
            text_features = text_features / text_features.norm(dim=-1, keepdim=True)
            # Similarity scores, softmax-normalized over prompts
            similarity = (100 * image_features @ text_features.T).softmax(dim=-1)
            similarity = similarity.cpu().numpy()[0]
            # Build the combination-score dictionary
            combination_scores = {}
            for i, combination in enumerate(self.object_combination_prompts.keys()):
                combination_scores[combination] = float(similarity[i])
            return combination_scores
    def _analyze_activities(self, image_features: torch.Tensor) -> Dict[str, float]:
        """Score activities taking place in the image."""
        with torch.no_grad():
            # Encode the activity text prompts
            text_features = self.model.encode_text(self.activity_tokens)
            text_features = text_features / text_features.norm(dim=-1, keepdim=True)
            # Similarity scores, softmax-normalized over prompts
            similarity = (100 * image_features @ text_features.T).softmax(dim=-1)
            similarity = similarity.cpu().numpy()[0]
            # Build the activity-score dictionary
            activity_scores = {}
            for i, activity in enumerate(self.activity_prompts.keys()):
                activity_scores[activity] = float(similarity[i])
            return activity_scores
def get_image_embedding(self, image) -> np.ndarray:
"""
獲取圖像的 CLIP 嵌入表示
Args:
image: PIL Image 或 numpy array
Returns:
np.ndarray: 圖像的 CLIP 特徵向量
"""
        # Ensure the image is a PIL Image
        if not isinstance(image, Image.Image):
            if isinstance(image, np.ndarray):
                image = Image.fromarray(image)
            else:
                raise ValueError("Unsupported image format. Expected PIL Image or numpy array.")
        # Preprocess and encode
        image_input = self.preprocess(image).unsqueeze(0).to(self.device)
        with torch.no_grad():
            image_features = self.model.encode_image(image_input)
            image_features = image_features / image_features.norm(dim=-1, keepdim=True)
        # Convert to numpy and return (.cpu() is a no-op for CPU tensors)
        return image_features.cpu().numpy()[0]
def text_to_embedding(self, text: str) -> np.ndarray:
"""
將文本轉換為 CLIP 嵌入表示
Args:
text: 輸入文本
Returns:
np.ndarray: 文本的 CLIP 特徵向量
"""
text_token = clip.tokenize([text]).to(self.device)
with torch.no_grad():
text_features = self.model.encode_text(text_token)
text_features = text_features / text_features.norm(dim=-1, keepdim=True)
return text_features.cpu().numpy()[0] if self.device == "cuda" else text_features.numpy()[0]
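    # Because both embeddings are L2-normalized, a plain dot product is already
    # their cosine similarity. A minimal sketch (`analyzer` and `image` are
    # assumed to exist; the query string is illustrative):
    #
    #     img_vec = analyzer.get_image_embedding(image)
    #     txt_vec = analyzer.text_to_embedding("a busy city street")
    #     cosine = float(np.dot(img_vec, txt_vec))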
def calculate_similarity(self, image, text_queries: List[str]) -> Dict[str, float]:
"""
計算圖像與多個文本查詢的相似度
Args:
image: PIL Image 或 numpy array
text_queries: 文本查詢列表
Returns:
Dict: 每個查詢的相似度分數
"""
# 獲取圖像嵌入
if isinstance(image, np.ndarray) and len(image.shape) == 1:
# 已經是嵌入向量
image_features = torch.tensor(image).unsqueeze(0).to(self.device)
else:
# 是圖像,需要提取嵌入
image_features = torch.tensor(self.get_image_embedding(image)).unsqueeze(0).to(self.device)
# calulate similarity
text_tokens = clip.tokenize(text_queries).to(self.device)
with torch.no_grad():
text_features = self.model.encode_text(text_tokens)
text_features = text_features / text_features.norm(dim=-1, keepdim=True)
similarity = (100.0 * image_features @ text_features.T).softmax(dim=-1)
similarity = similarity.cpu().numpy()[0] if self.device == "cuda" else similarity.numpy()[0]
# display results
result = {}
for i, query in enumerate(text_queries):
result[query] = float(similarity[i])
return result
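    # calculate_similarity also accepts a precomputed 1-D embedding (as returned
    # by get_image_embedding), which avoids re-encoding the image for repeated
    # queries. A sketch (`analyzer` and `image` are assumed to exist):
    #
    #     emb = analyzer.get_image_embedding(image)
    #     scores = analyzer.calculate_similarity(emb, ["a kitchen", "a beach"])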
def get_clip_instance(self):
"""
獲取初始化好的CLIP模型實例,便於其他模組重用
Returns:
tuple: (模型實例, 預處理函數, 設備名稱)
"""
return self.model, self.preprocess, self.device
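

# A minimal smoke test, kept as a sketch: it builds a synthetic gray image so it
# runs without any local files. Scores on such an image are not meaningful; this
# only verifies that the model loads and the pipeline runs end to end.
if __name__ == "__main__":
    analyzer = CLIPAnalyzer()
    demo_image = Image.new("RGB", (224, 224), color="gray")  # synthetic placeholder input
    analysis = analyzer.analyze_image(demo_image)
    print("Top scene:", analysis.get("top_scene"))
    print("Lighting:", analysis.get("lighting_condition"))
    print("Similarities:", analyzer.calculate_similarity(demo_image, ["an indoor scene", "an outdoor scene"]))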