# VisionScout / clip_analyzer.py
# DawnC's picture
# Upload 27 files
# 3172319 verified
import torch
import clip
import numpy as np
from PIL import Image
from typing import Dict, List, Tuple, Any, Optional, Union
from clip_prompts import (
SCENE_TYPE_PROMPTS,
CULTURAL_SCENE_PROMPTS,
COMPARATIVE_PROMPTS,
LIGHTING_CONDITION_PROMPTS,
SPECIALIZED_SCENE_PROMPTS,
VIEWPOINT_PROMPTS,
OBJECT_COMBINATION_PROMPTS,
ACTIVITY_PROMPTS
)
class CLIPAnalyzer:
    """
    Zero-shot scene understanding built on a CLIP model.

    The analyzer loads a CLIP model once, tokenizes the prompt sets imported
    from ``clip_prompts`` and scores images against them: scene type,
    lighting condition, viewpoint, object combinations and activities, plus
    optional per-scene "cultural" and "specialized" prompt sets.
    """

    # A scene type must score strictly above this softmax probability before
    # its cultural / specialized prompt set is evaluated in analyze_image().
    DETAIL_SCORE_THRESHOLD = 0.2

    def __init__(self, model_name: str = "ViT-B/32", device: Optional[str] = None):
        """
        Load the CLIP model and tokenize every prompt set.

        Args:
            model_name: CLIP model name, e.g. "ViT-B/32", "ViT-B/16", "ViT-L/14".
            device: Device to run on. When None, "cuda" is used if
                ``torch.cuda.is_available()`` and "cpu" otherwise.

        Raises:
            Exception: whatever ``clip.load`` raises is re-raised after the
                error has been printed.
        """
        # Pick the device automatically when the caller did not choose one.
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = device

        print(f"Loading CLIP model {model_name} on {self.device}...")
        try:
            self.model, self.preprocess = clip.load(model_name, device=self.device)
            print("CLIP model loaded successfully.")
        except Exception as e:
            print(f"Error loading CLIP model: {e}")
            raise

        self.scene_type_prompts = SCENE_TYPE_PROMPTS
        self.cultural_scene_prompts = CULTURAL_SCENE_PROMPTS
        # NOTE: stored for callers; this class never tokenizes or scores it.
        self.comparative_prompts = COMPARATIVE_PROMPTS
        self.lighting_condition_prompts = LIGHTING_CONDITION_PROMPTS
        self.specialized_scene_prompts = SPECIALIZED_SCENE_PROMPTS
        self.viewpoint_prompts = VIEWPOINT_PROMPTS
        self.object_combination_prompts = OBJECT_COMBINATION_PROMPTS
        self.activity_prompts = ACTIVITY_PROMPTS

        # Convert the prompt texts to CLIP token tensors.
        self._prepare_text_prompts()

    def _prepare_text_prompts(self):
        """Tokenize every prompt set and keep the token tensors on ``self.device``."""
        def tokenize(texts):
            return clip.tokenize(texts).to(self.device)

        # Flat prompt sets: one prompt string per label. Token rows follow the
        # dict's iteration order, which the scoring methods rely on.
        self.scene_type_tokens = tokenize(list(self.scene_type_prompts.values()))
        self.lighting_tokens = tokenize(list(self.lighting_condition_prompts.values()))
        self.viewpoint_tokens = tokenize(list(self.viewpoint_prompts.values()))
        self.object_combination_tokens = tokenize(list(self.object_combination_prompts.values()))
        self.activity_tokens = tokenize(list(self.activity_prompts.values()))

        # Per-scene prompt sets: several prompts for each scene type.
        self.cultural_tokens_dict = {
            scene_type: tokenize(prompts)
            for scene_type, prompts in self.cultural_scene_prompts.items()
        }
        self.specialized_tokens_dict = {
            scene_type: tokenize(prompts)
            for scene_type, prompts in self.specialized_scene_prompts.items()
        }

    # ------------------------------------------------------------------
    # Private helpers shared by the public and per-category methods
    # ------------------------------------------------------------------
    @staticmethod
    def _to_pil(image):
        """Return ``image`` as a PIL image; numpy arrays are converted.

        Raises:
            ValueError: if ``image`` is neither a PIL image nor a numpy array.
        """
        if isinstance(image, Image.Image):
            return image
        if isinstance(image, np.ndarray):
            return Image.fromarray(image)
        raise ValueError("Unsupported image format. Expected PIL Image or numpy array.")

    def _encode_image(self, image) -> torch.Tensor:
        """Preprocess and encode ``image``; return L2-normalized features with a leading batch axis of 1."""
        image_input = self.preprocess(self._to_pil(image)).unsqueeze(0).to(self.device)
        with torch.no_grad():
            image_features = self.model.encode_image(image_input)
            return image_features / image_features.norm(dim=-1, keepdim=True)

    def _encode_text(self, tokens: torch.Tensor) -> torch.Tensor:
        """Encode token rows and return L2-normalized text features."""
        with torch.no_grad():
            text_features = self.model.encode_text(tokens)
            return text_features / text_features.norm(dim=-1, keepdim=True)

    def _scaled_similarity(self, image_features: torch.Tensor, tokens: torch.Tensor) -> torch.Tensor:
        """Return ``100 * image_features @ text_features.T`` for the given token rows."""
        with torch.no_grad():
            text_features = self._encode_text(tokens)
            # Externally supplied embeddings may be float32/float64 or live on
            # another device while the model runs in another precision; align
            # them so the matmul does not fail on a dtype/device mismatch.
            image_features = image_features.to(
                device=text_features.device, dtype=text_features.dtype
            )
            return (100.0 * image_features) @ text_features.T

    def _softmax_scores(self, image_features: torch.Tensor, tokens: torch.Tensor, labels) -> Dict[str, float]:
        """Softmax the scaled similarities and map them onto ``labels`` by position."""
        with torch.no_grad():
            probabilities = self._scaled_similarity(image_features, tokens).softmax(dim=-1)
        # Always move to CPU first: comparing the device to the literal "cuda"
        # misses "cuda:0", torch.device objects and other accelerators.
        values = probabilities[0].cpu().tolist()
        return {label: float(values[i]) for i, label in enumerate(labels)}

    def _rank_prompts(self, image_features: torch.Tensor, tokens: torch.Tensor, prompts) -> Dict[str, Any]:
        """Rank ``prompts`` by scaled similarity (no softmax), best match first."""
        values = self._scaled_similarity(image_features, tokens)[0].cpu().tolist()
        ranked = sorted(
            ((prompt, float(values[i])) for i, prompt in enumerate(prompts)),
            key=lambda pair: pair[1],
            reverse=True,
        )
        return {
            "best_description": ranked[0][0],
            # Raw 100 * cosine similarity, not a probability.
            "confidence": ranked[0][1],
            "all_matches": ranked,
        }

    # ------------------------------------------------------------------
    # Public analysis entry point
    # ------------------------------------------------------------------
    def analyze_image(self, image, include_cultural_analysis: bool = True) -> Dict[str, Any]:
        """
        Analyze an image and predict scene type, lighting and related cues.

        Args:
            image: Input image (PIL Image or numpy array).
            include_cultural_analysis: Whether to run the detailed cultural
                prompt sets for scene types scoring above
                ``DETAIL_SCORE_THRESHOLD``.

        Returns:
            Dict with keys "scene_scores", "top_scene", "lighting_condition",
            "embedding", "viewpoint", "object_combinations" (top 3) and
            "activities" (top 3); "cultural_analysis" / "specialized_analysis"
            are added only when non-empty. On any failure the error is
            printed with its traceback and ``{"error": <message>}`` is
            returned instead of raising.
        """
        try:
            image_features = self._encode_image(image)

            scene_scores = self._analyze_scene_type(image_features)
            lighting_scores = self._analyze_lighting_condition(image_features)

            threshold = self.DETAIL_SCORE_THRESHOLD

            # Detailed cultural analysis, only for sufficiently likely scenes.
            cultural_analysis = {}
            if include_cultural_analysis:
                for scene_type in self.cultural_scene_prompts:
                    if scene_type in scene_scores and scene_scores[scene_type] > threshold:
                        cultural_analysis[scene_type] = self._analyze_cultural_scene(
                            image_features, scene_type
                        )

            # Specialized analysis always runs, under the same threshold.
            specialized_analysis = {}
            for scene_type in self.specialized_scene_prompts:
                if scene_type in scene_scores and scene_scores[scene_type] > threshold:
                    specialized_analysis[scene_type] = self._analyze_specialized_scene(
                        image_features, scene_type
                    )

            viewpoint_scores = self._analyze_viewpoint(image_features)
            object_combination_scores = self._analyze_object_combinations(image_features)
            activity_scores = self._analyze_activities(image_features)

            def by_score(item):
                return item[1]

            result = {
                "scene_scores": scene_scores,
                "top_scene": max(scene_scores.items(), key=by_score),
                "lighting_condition": max(lighting_scores.items(), key=by_score),
                "embedding": image_features[0].cpu().tolist(),
                "viewpoint": max(viewpoint_scores.items(), key=by_score),
                "object_combinations": sorted(
                    object_combination_scores.items(), key=by_score, reverse=True
                )[:3],
                "activities": sorted(
                    activity_scores.items(), key=by_score, reverse=True
                )[:3],
            }

            if cultural_analysis:
                result["cultural_analysis"] = cultural_analysis
            if specialized_analysis:
                result["specialized_analysis"] = specialized_analysis

            return result
        except Exception as e:
            # Deliberate best-effort boundary: report and return an error dict.
            print(f"Error analyzing image with CLIP: {e}")
            import traceback
            traceback.print_exc()
            return {"error": str(e)}

    # ------------------------------------------------------------------
    # Per-category scoring
    # ------------------------------------------------------------------
    def _analyze_scene_type(self, image_features: torch.Tensor) -> Dict[str, float]:
        """Return a softmax score per scene type for the given image features."""
        return self._softmax_scores(
            image_features, self.scene_type_tokens, self.scene_type_prompts
        )

    def _analyze_lighting_condition(self, image_features: torch.Tensor) -> Dict[str, float]:
        """Return a softmax score per lighting condition."""
        return self._softmax_scores(
            image_features, self.lighting_tokens, self.lighting_condition_prompts
        )

    def _analyze_cultural_scene(self, image_features: torch.Tensor, scene_type: str) -> Dict[str, Any]:
        """
        Rank the cultural prompts of ``scene_type`` against the image.

        Returns:
            Dict with "best_description", "confidence" (raw 100 * cosine
            similarity, not a probability) and "all_matches" (list of
            (prompt, score) sorted best first), or ``{"error": ...}`` when no
            prompts are available for ``scene_type``.
        """
        prompts = self.cultural_scene_prompts.get(scene_type)
        if scene_type not in self.cultural_tokens_dict or not prompts:
            return {"error": f"No cultural analysis available for {scene_type}"}
        return self._rank_prompts(
            image_features, self.cultural_tokens_dict[scene_type], prompts
        )

    def _analyze_specialized_scene(self, image_features: torch.Tensor, scene_type: str) -> Dict[str, Any]:
        """
        Rank the specialized prompts of ``scene_type`` against the image.

        Returns:
            Same structure as ``_analyze_cultural_scene``, or
            ``{"error": ...}`` when no prompts are available for ``scene_type``.
        """
        prompts = self.specialized_scene_prompts.get(scene_type)
        if scene_type not in self.specialized_tokens_dict or not prompts:
            return {"error": f"No specialized analysis available for {scene_type}"}
        return self._rank_prompts(
            image_features, self.specialized_tokens_dict[scene_type], prompts
        )

    def _analyze_viewpoint(self, image_features: torch.Tensor) -> Dict[str, float]:
        """Return a softmax score per camera viewpoint."""
        return self._softmax_scores(
            image_features, self.viewpoint_tokens, self.viewpoint_prompts
        )

    def _analyze_object_combinations(self, image_features: torch.Tensor) -> Dict[str, float]:
        """Return a softmax score per object combination."""
        return self._softmax_scores(
            image_features, self.object_combination_tokens, self.object_combination_prompts
        )

    def _analyze_activities(self, image_features: torch.Tensor) -> Dict[str, float]:
        """Return a softmax score per activity."""
        return self._softmax_scores(
            image_features, self.activity_tokens, self.activity_prompts
        )

    # ------------------------------------------------------------------
    # Embedding utilities
    # ------------------------------------------------------------------
    def get_image_embedding(self, image) -> np.ndarray:
        """
        Return the L2-normalized CLIP embedding of an image.

        Args:
            image: PIL Image or numpy array.

        Returns:
            np.ndarray: 1-D CLIP feature vector of the image.

        Raises:
            ValueError: if ``image`` is neither a PIL image nor a numpy array.
        """
        return self._encode_image(image)[0].cpu().numpy()

    def text_to_embedding(self, text: str) -> np.ndarray:
        """
        Return the L2-normalized CLIP embedding of a text.

        Args:
            text: Input text.

        Returns:
            np.ndarray: 1-D CLIP feature vector of the text.
        """
        text_token = clip.tokenize([text]).to(self.device)
        return self._encode_text(text_token)[0].cpu().numpy()

    def calculate_similarity(self, image, text_queries: List[str]) -> Dict[str, float]:
        """
        Score an image against several text queries.

        Args:
            image: PIL Image or numpy array. A 1-D numpy array is treated as
                an already computed embedding and is not re-encoded.
            text_queries: Text queries to compare with; a single string is
                treated as a one-element list.

        Returns:
            Dict mapping each query to its softmax score over all queries;
            empty when no queries are given.
        """
        if isinstance(text_queries, str):
            text_queries = [text_queries]
        if not text_queries:
            # Nothing to compare against; skip the encoders entirely.
            return {}

        if isinstance(image, np.ndarray) and image.ndim == 1:
            # Already an embedding vector.
            image_features = torch.tensor(image).unsqueeze(0)
        else:
            # A real image: extract its embedding first.
            image_features = self._encode_image(image)

        text_tokens = clip.tokenize(text_queries).to(self.device)
        return self._softmax_scores(image_features, text_tokens, text_queries)