import torch
import clip
import numpy as np
from PIL import Image
from typing import Dict, List, Tuple, Any, Optional, Union
from clip_prompts import (
SCENE_TYPE_PROMPTS,
CULTURAL_SCENE_PROMPTS,
COMPARATIVE_PROMPTS,
LIGHTING_CONDITION_PROMPTS,
SPECIALIZED_SCENE_PROMPTS,
VIEWPOINT_PROMPTS,
OBJECT_COMBINATION_PROMPTS,
ACTIVITY_PROMPTS
)
class CLIPAnalyzer:
"""
Use Clip to intergrate scene understanding function
"""
def __init__(self, model_name: str = "ViT-B/16", device: str = None):
"""
初始化 CLIP 分析器。
Args:
model_name: CLIP Model name, 默認 "ViT-B/16"
device: Use GPU if it can use
"""
# 自動選擇設備
if device is None:
self.device = "cuda" if torch.cuda.is_available() else "cpu"
else:
self.device = device
print(f"Loading CLIP model {model_name} on {self.device}...")
try:
self.model, self.preprocess = clip.load(model_name, device=self.device)
print(f"CLIP model loaded successfully.")
except Exception as e:
print(f"Error loading CLIP model: {e}")
raise
self.scene_type_prompts = SCENE_TYPE_PROMPTS
self.cultural_scene_prompts = CULTURAL_SCENE_PROMPTS
self.comparative_prompts = COMPARATIVE_PROMPTS
self.lighting_condition_prompts = LIGHTING_CONDITION_PROMPTS
self.specialized_scene_prompts = SPECIALIZED_SCENE_PROMPTS
self.viewpoint_prompts = VIEWPOINT_PROMPTS
self.object_combination_prompts = OBJECT_COMBINATION_PROMPTS
self.activity_prompts = ACTIVITY_PROMPTS
        # Pre-tokenize all prompt sets into CLIP token tensors
self._prepare_text_prompts()
def _prepare_text_prompts(self):
"""準備所有文本提示的 CLIP 特徵並存儲到 self.text_features_cache 中"""
self.text_features_cache = {}
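        # Cache layout: flat categories store "<name>_keys" (ordered prompt keys) and
        # "<name>_tokens" (the tokenized prompts); cultural and specialized scenes use
        # dicts mapping each scene type to its own token batch.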
        # Base scene types (SCENE_TYPE_PROMPTS)
if hasattr(self, 'scene_type_prompts') and self.scene_type_prompts:
scene_texts = [prompt for scene_type, prompt in self.scene_type_prompts.items()]
if scene_texts:
self.text_features_cache["scene_type_keys"] = list(self.scene_type_prompts.keys())
try:
self.text_features_cache["scene_type_tokens"] = clip.tokenize(scene_texts).to(self.device)
except Exception as e:
print(f"Warning: Error tokenizing scene_type_prompts: {e}")
self.text_features_cache["scene_type_tokens"] = None # 標記錯誤或空
else:
self.text_features_cache["scene_type_keys"] = []
self.text_features_cache["scene_type_tokens"] = None
else:
self.text_features_cache["scene_type_keys"] = []
self.text_features_cache["scene_type_tokens"] = None
        # Cultural scenes (CULTURAL_SCENE_PROMPTS)
        # cultural_tokens_dict stores the tokenized prompts per scene type
cultural_tokens_dict_val = {}
if hasattr(self, 'cultural_scene_prompts') and self.cultural_scene_prompts:
for scene_type, prompts in self.cultural_scene_prompts.items():
if prompts and isinstance(prompts, list) and all(isinstance(p, str) for p in prompts):
try:
cultural_tokens_dict_val[scene_type] = clip.tokenize(prompts).to(self.device)
except Exception as e:
print(f"Warning: Error tokenizing cultural_scene_prompts for {scene_type}: {e}")
                        cultural_tokens_dict_val[scene_type] = None  # mark as failed or empty
                else:
                    cultural_tokens_dict_val[scene_type] = None  # prompts are malformed
self.text_features_cache["cultural_tokens_dict"] = cultural_tokens_dict_val
        # Lighting conditions (LIGHTING_CONDITION_PROMPTS)
if hasattr(self, 'lighting_condition_prompts') and self.lighting_condition_prompts:
lighting_texts = [prompt for cond, prompt in self.lighting_condition_prompts.items()]
if lighting_texts:
self.text_features_cache["lighting_condition_keys"] = list(self.lighting_condition_prompts.keys())
try:
self.text_features_cache["lighting_tokens"] = clip.tokenize(lighting_texts).to(self.device)
except Exception as e:
print(f"Warning: Error tokenizing lighting_condition_prompts: {e}")
self.text_features_cache["lighting_tokens"] = None
else:
self.text_features_cache["lighting_condition_keys"] = []
self.text_features_cache["lighting_tokens"] = None
else:
self.text_features_cache["lighting_condition_keys"] = []
self.text_features_cache["lighting_tokens"] = None
        # Specialized scenes (SPECIALIZED_SCENE_PROMPTS)
specialized_tokens_dict_val = {}
if hasattr(self, 'specialized_scene_prompts') and self.specialized_scene_prompts:
for scene_type, prompts in self.specialized_scene_prompts.items():
if prompts and isinstance(prompts, list) and all(isinstance(p, str) for p in prompts):
try:
specialized_tokens_dict_val[scene_type] = clip.tokenize(prompts).to(self.device)
except Exception as e:
print(f"Warning: Error tokenizing specialized_scene_prompts for {scene_type}: {e}")
specialized_tokens_dict_val[scene_type] = None
else:
specialized_tokens_dict_val[scene_type] = None
self.text_features_cache["specialized_tokens_dict"] = specialized_tokens_dict_val
        # Viewpoints (VIEWPOINT_PROMPTS)
if hasattr(self, 'viewpoint_prompts') and self.viewpoint_prompts:
viewpoint_texts = [prompt for viewpoint, prompt in self.viewpoint_prompts.items()]
if viewpoint_texts:
self.text_features_cache["viewpoint_keys"] = list(self.viewpoint_prompts.keys())
try:
self.text_features_cache["viewpoint_tokens"] = clip.tokenize(viewpoint_texts).to(self.device)
except Exception as e:
print(f"Warning: Error tokenizing viewpoint_prompts: {e}")
self.text_features_cache["viewpoint_tokens"] = None
else:
self.text_features_cache["viewpoint_keys"] = []
self.text_features_cache["viewpoint_tokens"] = None
else:
self.text_features_cache["viewpoint_keys"] = []
self.text_features_cache["viewpoint_tokens"] = None
        # Object combinations (OBJECT_COMBINATION_PROMPTS)
if hasattr(self, 'object_combination_prompts') and self.object_combination_prompts:
object_combination_texts = [prompt for combo, prompt in self.object_combination_prompts.items()]
if object_combination_texts:
self.text_features_cache["object_combination_keys"] = list(self.object_combination_prompts.keys())
try:
self.text_features_cache["object_combination_tokens"] = clip.tokenize(object_combination_texts).to(self.device)
except Exception as e:
print(f"Warning: Error tokenizing object_combination_prompts: {e}")
self.text_features_cache["object_combination_tokens"] = None
else:
self.text_features_cache["object_combination_keys"] = []
self.text_features_cache["object_combination_tokens"] = None
else:
self.text_features_cache["object_combination_keys"] = []
self.text_features_cache["object_combination_tokens"] = None
        # Activities (ACTIVITY_PROMPTS)
if hasattr(self, 'activity_prompts') and self.activity_prompts:
activity_texts = [prompt for activity, prompt in self.activity_prompts.items()]
if activity_texts:
self.text_features_cache["activity_keys"] = list(self.activity_prompts.keys())
try:
self.text_features_cache["activity_tokens"] = clip.tokenize(activity_texts).to(self.device)
except Exception as e:
print(f"Warning: Error tokenizing activity_prompts: {e}")
self.text_features_cache["activity_tokens"] = None
else:
self.text_features_cache["activity_keys"] = []
self.text_features_cache["activity_tokens"] = None
else:
self.text_features_cache["activity_keys"] = []
self.text_features_cache["activity_tokens"] = None
self.scene_type_tokens = self.text_features_cache["scene_type_tokens"]
self.lighting_tokens = self.text_features_cache["lighting_tokens"]
self.viewpoint_tokens = self.text_features_cache["viewpoint_tokens"]
self.object_combination_tokens = self.text_features_cache["object_combination_tokens"]
self.activity_tokens = self.text_features_cache["activity_tokens"]
self.cultural_tokens_dict = self.text_features_cache["cultural_tokens_dict"]
self.specialized_tokens_dict = self.text_features_cache["specialized_tokens_dict"]
print("CLIP text_features_cache prepared.")
def analyze_image(self, image, include_cultural_analysis=True, exclude_categories=None, enable_landmark=True, places365_guidance=None):
"""
        Analyze an image and predict its scene type and lighting condition.
        Args:
            image: input image (PIL Image or numpy array)
            include_cultural_analysis: whether to run detailed cultural-scene analysis
            exclude_categories: list of scene categories to exclude
            enable_landmark: whether landmark detection is enabled
            places365_guidance: optional scene guidance from Places365
        Returns:
            Dict: analysis results with scene-type predictions, lighting condition,
                viewpoint, object combinations, and activities
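        Example (illustrative; "street.jpg" is a placeholder path):
            analyzer = CLIPAnalyzer()
            result = analyzer.analyze_image(Image.open("street.jpg"))
            print(result["top_scene"])   # e.g. ("city_street", 0.42)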
"""
try:
            self.enable_landmark = enable_landmark  # refresh the instance's enable_landmark state
            # Ensure the image is in PIL format
if not isinstance(image, Image.Image):
if isinstance(image, np.ndarray):
image = Image.fromarray(image)
else:
raise ValueError("Unsupported image format. Expected PIL Image or numpy array.")
            # Preprocess the image
            image_input = self.preprocess(image).unsqueeze(0).to(self.device)
            # Encode the image into a feature vector
with torch.no_grad():
image_features = self.model.encode_image(image_input)
image_features = image_features / image_features.norm(dim=-1, keepdim=True)
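                # image_features are now L2-normalized, so later dot products are cosine similarities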
places365_focus_areas = []
places365_scene_context = "" # 用於存儲 Places365 提供的場景描述
if places365_guidance and isinstance(places365_guidance, dict) and places365_guidance.get('confidence', 0) > 0.4:
mapped_scene = places365_guidance.get('mapped_scene_type', '')
scene_label = places365_guidance.get('scene_label', '')
                # is_indoor = places365_guidance.get('is_indoor', None)  # unused; kept commented out
attributes = places365_guidance.get('attributes', [])
places365_scene_context = f"Scene identified by Places365 as {scene_label}" # 更新上下文描述
# Adjust CLIP analysis focus based on Places365 scene type
if mapped_scene in ['kitchen', 'dining_area', 'restaurant']:
places365_focus_areas.extend(['food preparation', 'dining setup', 'kitchen appliances'])
elif mapped_scene in ['office_workspace', 'educational_setting', 'library', 'conference_room']:
places365_focus_areas.extend(['work environment', 'professional setting', 'learning space', 'study area'])
                elif mapped_scene in ['retail_store', 'shopping_mall', 'market', 'supermarket']:  # broadened match set
places365_focus_areas.extend(['commercial space', 'shopping environment', 'retail display', 'goods for sale'])
                elif mapped_scene in ['park_area', 'beach', 'natural_outdoor_area', 'playground', 'sports_field']:  # broadened match set
places365_focus_areas.extend(['outdoor recreation', 'natural environment', 'leisure activity', 'open space'])
                # Add more generic focus areas based on scene attributes
                if isinstance(attributes, list):  # guard against non-list attributes
if 'commercial' in attributes:
places365_focus_areas.append('business activity')
if 'recreational' in attributes:
places365_focus_areas.append('entertainment or leisure')
if 'residential' in attributes:
places365_focus_areas.append('living space')
                # Deduplicate focus areas
places365_focus_areas = list(set(places365_focus_areas))
            if places365_focus_areas:  # only log when guidance actually produced focus areas
print(f"CLIP analysis guided by Places365: {places365_scene_context}, focus areas: {places365_focus_areas}")
            # Score scene types, passing the landmark flag and the Places365 guidance
            scene_scores = self._analyze_scene_type(image_features,
                                                    enable_landmark=self.enable_landmark,
                                                    places365_focus=places365_focus_areas)
            # With landmark detection disabled, make sure landmark-related categories are excluded
            current_exclude_categories = list(exclude_categories) if exclude_categories is not None else []
            if not self.enable_landmark:  # use the refreshed instance attribute
landmark_related_terms = ["landmark", "monument", "tower", "tourist", "attraction", "historical", "famous", "iconic"]
for term in landmark_related_terms:
if term not in current_exclude_categories:
current_exclude_categories.append(term)
if current_exclude_categories:
filtered_scores = {}
for scene, score in scene_scores.items():
                    # Skip scenes whose key (usually English) contains any excluded term
if not any(cat.lower() in scene.lower() for cat in current_exclude_categories):
filtered_scores[scene] = score
                if filtered_scores:
                    total_score = sum(filtered_scores.values())
                    if total_score > 1e-5:  # avoid dividing by zero or a vanishing total
                        scene_scores = {k: v / total_score for k, v in filtered_scores.items()}
                    else:  # total is effectively zero; zero out the remaining scores
                        scene_scores = {k: 0.0 for k in filtered_scores.keys()}
                else:  # filtering removed every scene
                    scene_scores = {k: (0.0 if any(cat.lower() in k.lower() for cat in current_exclude_categories) else v) for k, v in scene_scores.items()}
                    if not any(s > 1e-5 for s in scene_scores.values()):  # still all zero
                        scene_scores = {"unknown": 1.0}  # fall back to a default to avoid an empty dict
lighting_scores = self._analyze_lighting_condition(image_features)
cultural_analysis = {}
            if include_cultural_analysis and self.enable_landmark:  # use the refreshed instance attribute
                for scene_type_cultural_key in self.text_features_cache.get("cultural_tokens_dict", {}).keys():
                    # The cultural key must be a key in scene_scores (i.e. in SCENE_TYPE_PROMPTS) or have a mapping
if scene_type_cultural_key in scene_scores and scene_scores[scene_type_cultural_key] > 0.2:
cultural_analysis[scene_type_cultural_key] = self._analyze_cultural_scene(
image_features, scene_type_cultural_key
)
specialized_analysis = {}
for scene_type_specialized_key in self.text_features_cache.get("specialized_tokens_dict", {}).keys():
if scene_type_specialized_key in scene_scores and scene_scores[scene_type_specialized_key] > 0.2:
specialized_analysis[scene_type_specialized_key] = self._analyze_specialized_scene(
image_features, scene_type_specialized_key
)
viewpoint_scores = self._analyze_viewpoint(image_features)
object_combination_scores = self._analyze_object_combinations(image_features)
activity_scores = self._analyze_activities(image_features)
            if scene_scores:  # guard against an empty score dict
top_scene = max(scene_scores.items(), key=lambda x: x[1])
                # With landmarks disabled, double-check that the top scene is not landmark-related
if not self.enable_landmark and any(cat.lower() in top_scene[0].lower() for cat in current_exclude_categories):
non_excluded_scores = {k:v for k,v in scene_scores.items() if not any(cat.lower() in k.lower() for cat in current_exclude_categories)}
if non_excluded_scores:
top_scene = max(non_excluded_scores.items(), key=lambda x: x[1])
else:
top_scene = ("unknown", 0.0) # 或其他合適的默認值
else:
top_scene = ("unknown", 0.0)
result = {
"scene_scores": scene_scores,
"top_scene": top_scene,
"lighting_condition": max(lighting_scores.items(), key=lambda x: x[1]) if lighting_scores else ("unknown", 0.0),
"embedding": image_features.cpu().numpy().tolist()[0], # 簡化
"viewpoint": max(viewpoint_scores.items(), key=lambda x: x[1]) if viewpoint_scores else ("unknown", 0.0),
"object_combinations": sorted(object_combination_scores.items(), key=lambda x: x[1], reverse=True)[:3] if object_combination_scores else [],
"activities": sorted(activity_scores.items(), key=lambda x: x[1], reverse=True)[:3] if activity_scores else []
}
            if places365_guidance and isinstance(places365_guidance, dict) and places365_focus_areas:  # guidance actually produced focus areas
result["places365_guidance"] = {
"scene_context": places365_scene_context,
"focus_areas": places365_focus_areas, # 現在這個會包含基於 guidance 的內容
"guided_analysis": True,
"original_places365_scene": places365_guidance.get('scene_label', 'N/A'),
"original_places365_confidence": places365_guidance.get('confidence', 0.0)
}
if cultural_analysis and self.enable_landmark:
result["cultural_analysis"] = cultural_analysis
if specialized_analysis:
result["specialized_analysis"] = specialized_analysis
return result
except Exception as e:
print(f"Error analyzing image with CLIP: {e}")
import traceback
traceback.print_exc()
return {"error": str(e), "scene_scores": {}, "top_scene": ("error", 0.0)}
def _analyze_scene_type(self, image_features: torch.Tensor, enable_landmark: bool = True, places365_focus: List[str] = None) -> Dict[str, float]:
"""
分析圖像特徵與各場景類型的相似度,並可選擇性地排除地標相關場景
Args:
image_features: 經過 CLIP 編碼的圖像特徵
enable_landmark: 是否啟用地標識別功能
Returns:
Dict[str, float]: 各場景類型的相似度分數字典
"""
with torch.no_grad():
            # Encode the scene-type text prompts
text_features = self.model.encode_text(self.scene_type_tokens)
text_features = text_features / text_features.norm(dim=-1, keepdim=True)
# Apply Places365 guidance if available
if places365_focus and len(places365_focus) > 0:
# Create enhanced prompts that incorporate Places365 guidance
enhanced_prompts = []
for scene_type in self.scene_type_prompts.keys():
base_prompt = self.scene_type_prompts[scene_type]
# Check if this scene type should be emphasized based on Places365 guidance
scene_lower = scene_type.lower()
should_enhance = False
for focus_area in places365_focus:
if any(keyword in scene_lower for keyword in focus_area.split()):
should_enhance = True
enhanced_prompts.append(f"{base_prompt} with {focus_area}")
break
if not should_enhance:
enhanced_prompts.append(base_prompt)
# Re-tokenize and encode enhanced prompts
enhanced_tokens = clip.tokenize(enhanced_prompts).to(self.device)
text_features = self.model.encode_text(enhanced_tokens)
text_features = text_features / text_features.norm(dim=-1, keepdim=True)
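            # Note: the factor of 100 approximates CLIP's learned logit scale; softmax then
            # turns the cosine similarities into a probability-like distribution over scene types.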
            # Compute similarity scores
            similarity = (100 * image_features @ text_features.T).softmax(dim=-1)
            similarity = similarity.cpu().numpy()[0] if self.device == "cuda" else similarity.numpy()[0]
            # Build the scene score dict
scene_scores = {}
for i, scene_type in enumerate(self.scene_type_prompts.keys()):
                # Skip landmark-related scene types when landmark detection is disabled
                if not enable_landmark and scene_type in ["tourist_landmark", "natural_landmark", "historical_monument"]:
                    scene_scores[scene_type] = 0.0  # zero out landmark scenes
else:
base_score = float(similarity[i])
# Apply Places365 guidance boost if applicable
if places365_focus:
scene_lower = scene_type.lower()
boost_factor = 1.0
for focus_area in places365_focus:
if any(keyword in scene_lower for keyword in focus_area.split()):
boost_factor = 1.15 # 15% boost for matching scenes
break
scene_scores[scene_type] = base_score * boost_factor
else:
scene_scores[scene_type] = base_score
            # With landmarks disabled, renormalize the remaining scene scores
            if not enable_landmark:
                # Gather all non-zero scores
                non_zero_scores = {k: v for k, v in scene_scores.items() if v > 0}
                if non_zero_scores:
                    # Normalize by their sum
total_score = sum(non_zero_scores.values())
if total_score > 0:
for scene_type in non_zero_scores:
scene_scores[scene_type] = non_zero_scores[scene_type] / total_score
return scene_scores
def _analyze_lighting_condition(self, image_features: torch.Tensor) -> Dict[str, float]:
"""分析圖像的光照條件"""
with torch.no_grad():
# 計算光照條件文本特徵
text_features = self.model.encode_text(self.lighting_tokens)
text_features = text_features / text_features.norm(dim=-1, keepdim=True)
# 計算相似度分數
similarity = (100 * image_features @ text_features.T).softmax(dim=-1)
similarity = similarity.cpu().numpy()[0] if self.device == "cuda" else similarity.numpy()[0]
# 建立光照條件分數字典
lighting_scores = {}
for i, lighting_type in enumerate(self.lighting_condition_prompts.keys()):
lighting_scores[lighting_type] = float(similarity[i])
return lighting_scores
def _analyze_cultural_scene(self, image_features: torch.Tensor, scene_type: str) -> Dict[str, Any]:
"""針對特定文化場景進行深入分析"""
if scene_type not in self.cultural_tokens_dict:
return {"error": f"No cultural analysis available for {scene_type}"}
with torch.no_grad():
# 獲取特定文化場景的文本特徵
cultural_tokens = self.cultural_tokens_dict[scene_type]
text_features = self.model.encode_text(cultural_tokens)
text_features = text_features / text_features.norm(dim=-1, keepdim=True)
# 計算相似度分數
similarity = (100 * image_features @ text_features.T)
similarity = similarity.cpu().numpy()[0] if self.device == "cuda" else similarity.numpy()[0]
# 找到最匹配的文化描述
prompts = self.cultural_scene_prompts[scene_type]
scores = [(prompts[i], float(similarity[i])) for i in range(len(prompts))]
scores.sort(key=lambda x: x[1], reverse=True)
return {
"best_description": scores[0][0],
"confidence": scores[0][1],
"all_matches": scores
}
def _analyze_specialized_scene(self, image_features: torch.Tensor, scene_type: str) -> Dict[str, Any]:
"""針對特定專門場景進行深入分析"""
if scene_type not in self.specialized_tokens_dict:
return {"error": f"No specialized analysis available for {scene_type}"}
with torch.no_grad():
# 獲取特定專門場景的文本特徵
specialized_tokens = self.specialized_tokens_dict[scene_type]
text_features = self.model.encode_text(specialized_tokens)
text_features = text_features / text_features.norm(dim=-1, keepdim=True)
# 計算相似度分數
similarity = (100 * image_features @ text_features.T)
similarity = similarity.cpu().numpy()[0] if self.device == "cuda" else similarity.numpy()[0]
# 找到最匹配的專門描述
prompts = self.specialized_scene_prompts[scene_type]
scores = [(prompts[i], float(similarity[i])) for i in range(len(prompts))]
scores.sort(key=lambda x: x[1], reverse=True)
return {
"best_description": scores[0][0],
"confidence": scores[0][1],
"all_matches": scores
}
def _analyze_viewpoint(self, image_features: torch.Tensor) -> Dict[str, float]:
"""分析圖像的拍攝視角"""
with torch.no_grad():
# 計算視角文本特徵
text_features = self.model.encode_text(self.viewpoint_tokens)
text_features = text_features / text_features.norm(dim=-1, keepdim=True)
# 計算相似度分數
similarity = (100 * image_features @ text_features.T).softmax(dim=-1)
similarity = similarity.cpu().numpy()[0] if self.device == "cuda" else similarity.numpy()[0]
# 建立視角分數字典
viewpoint_scores = {}
for i, viewpoint in enumerate(self.viewpoint_prompts.keys()):
viewpoint_scores[viewpoint] = float(similarity[i])
return viewpoint_scores
def _analyze_object_combinations(self, image_features: torch.Tensor) -> Dict[str, float]:
"""分析圖像中的物體組合"""
with torch.no_grad():
# 計算物體組合文本特徵
text_features = self.model.encode_text(self.object_combination_tokens)
text_features = text_features / text_features.norm(dim=-1, keepdim=True)
# 計算相似度分數
similarity = (100 * image_features @ text_features.T).softmax(dim=-1)
similarity = similarity.cpu().numpy()[0] if self.device == "cuda" else similarity.numpy()[0]
# 建立物體組合分數字典
combination_scores = {}
for i, combination in enumerate(self.object_combination_prompts.keys()):
combination_scores[combination] = float(similarity[i])
return combination_scores
def _analyze_activities(self, image_features: torch.Tensor) -> Dict[str, float]:
"""分析圖像中的活動"""
with torch.no_grad():
# 計算活動文本特徵
text_features = self.model.encode_text(self.activity_tokens)
text_features = text_features / text_features.norm(dim=-1, keepdim=True)
# 計算相似度分數
similarity = (100 * image_features @ text_features.T).softmax(dim=-1)
similarity = similarity.cpu().numpy()[0] if self.device == "cuda" else similarity.numpy()[0]
# 建立活動分數字典
activity_scores = {}
for i, activity in enumerate(self.activity_prompts.keys()):
activity_scores[activity] = float(similarity[i])
return activity_scores
def get_image_embedding(self, image) -> np.ndarray:
"""
        Get the CLIP embedding of an image.
        Args:
            image: PIL Image or numpy array
        Returns:
            np.ndarray: the image's L2-normalized CLIP feature vector
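        Example (illustrative; "img" is any PIL image):
            vec = analyzer.get_image_embedding(img)
            vec.shape  # (512,) for ViT-B/16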
"""
        # Ensure the image is in PIL format
if not isinstance(image, Image.Image):
if isinstance(image, np.ndarray):
image = Image.fromarray(image)
else:
raise ValueError("Unsupported image format. Expected PIL Image or numpy array.")
        # Preprocess and encode
image_input = self.preprocess(image).unsqueeze(0).to(self.device)
with torch.no_grad():
image_features = self.model.encode_image(image_input)
image_features = image_features / image_features.norm(dim=-1, keepdim=True)
        # Convert to numpy and return
return image_features.cpu().numpy()[0] if self.device == "cuda" else image_features.numpy()[0]
def text_to_embedding(self, text: str) -> np.ndarray:
"""
        Convert text into a CLIP embedding.
        Args:
            text: input text
        Returns:
            np.ndarray: the text's L2-normalized CLIP feature vector
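        Example (illustrative; both vectors are unit-normalized, so a dot
        product is their cosine similarity):
            txt_vec = analyzer.text_to_embedding("a busy city street")
            img_vec = analyzer.get_image_embedding(img)
            cosine = float(np.dot(txt_vec, img_vec))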
"""
text_token = clip.tokenize([text]).to(self.device)
with torch.no_grad():
text_features = self.model.encode_text(text_token)
text_features = text_features / text_features.norm(dim=-1, keepdim=True)
return text_features.cpu().numpy()[0] if self.device == "cuda" else text_features.numpy()[0]
def calculate_similarity(self, image, text_queries: List[str]) -> Dict[str, float]:
"""
        Compute the similarity between an image and several text queries.
        Args:
            image: PIL Image or numpy array
            text_queries: list of text queries
        Returns:
            Dict: similarity score per query
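        Example (illustrative):
            scores = analyzer.calculate_similarity(img, ["a dog", "a cat"])
            best = max(scores, key=scores.get)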
"""
        # Get the image embedding
        if isinstance(image, np.ndarray) and len(image.shape) == 1:
            # Already an embedding vector
image_features = torch.tensor(image).unsqueeze(0).to(self.device)
        else:
            # A raw image; extract its embedding first
image_features = torch.tensor(self.get_image_embedding(image)).unsqueeze(0).to(self.device)
        # Calculate similarity against the text queries
text_tokens = clip.tokenize(text_queries).to(self.device)
with torch.no_grad():
text_features = self.model.encode_text(text_tokens)
text_features = text_features / text_features.norm(dim=-1, keepdim=True)
similarity = (100.0 * image_features @ text_features.T).softmax(dim=-1)
similarity = similarity.cpu().numpy()[0] if self.device == "cuda" else similarity.numpy()[0]
        # Collect the results
result = {}
for i, query in enumerate(text_queries):
result[query] = float(similarity[i])
return result
def get_clip_instance(self):
"""
        Return the initialized CLIP model so other modules can reuse it.
        Returns:
            tuple: (model instance, preprocess function, device name)
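        Example (illustrative):
            model, preprocess, device = analyzer.get_clip_instance()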
"""
return self.model, self.preprocess, self.device
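

if __name__ == "__main__":
    # Minimal usage sketch (assumes a local test image; "test.jpg" is a placeholder path).
    analyzer = CLIPAnalyzer()
    img = Image.open("test.jpg").convert("RGB")
    result = analyzer.analyze_image(img, enable_landmark=False)
    print("Top scene:", result.get("top_scene"))
    print("Lighting:", result.get("lighting_condition"))
    print("Activities:", result.get("activities"))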