Spaces:
Running
on
Zero
Running
on
Zero
import torch | |
import torch.nn as nn | |
import torchvision.transforms as transforms | |
import numpy as np | |
from PIL import Image | |
from typing import Dict, List, Tuple, Optional, Any | |
import logging | |
class Places365Model: | |
""" | |
Places365 scene classification model wrapper for scene understanding integration. | |
Provides scene classification and scene attribute prediction capabilities. | |
""" | |
def __init__(self, model_name: str = 'resnet50_places365', device: Optional[str] = None): | |
""" | |
Initialize Places365 model with configurable architecture and device. | |
Args: | |
model_name: Model architecture name (默認 resnet50) | |
device: Target device for inference (auto-detected if None) | |
""" | |
self.logger = logging.getLogger(self.__class__.__name__) | |
# Device configuration with fallback logic | |
if device is None: | |
self.device = "cuda" if torch.cuda.is_available() else "cpu" | |
else: | |
self.device = device | |
self.model_name = model_name | |
self.model = None | |
self.scene_classes = [] | |
self.scene_attributes = [] | |
# Model configuration mapping | |
self.model_configs = { | |
'resnet18_places365': { | |
'arch': 'resnet18', | |
'num_classes': 365, | |
'url': 'http://places2.csail.mit.edu/models_places365/resnet18_places365.pth.tar' | |
}, | |
'resnet50_places365': { | |
'arch': 'resnet50', | |
'num_classes': 365, | |
'url': 'http://places2.csail.mit.edu/models_places365/resnet50_places365.pth.tar' | |
}, | |
'densenet161_places365': { | |
'arch': 'densenet161', | |
'num_classes': 365, | |
'url': 'http://places2.csail.mit.edu/models_places365/densenet161_places365.pth.tar' | |
} | |
} | |
self._load_model() | |
self._load_class_names() | |
self._setup_scene_mapping() | |
def _load_model(self): | |
"""載入與初始化 Places365 model""" | |
try: | |
if self.model_name not in self.model_configs: | |
raise ValueError(f"Unsupported model name: {self.model_name}") | |
config = self.model_configs[self.model_name] | |
# Import model architecture | |
if config['arch'].startswith('resnet'): | |
import torchvision.models as models | |
if config['arch'] == 'resnet18': | |
self.model = models.resnet18(num_classes=config['num_classes']) | |
elif config['arch'] == 'resnet50': | |
self.model = models.resnet50(num_classes=config['num_classes']) | |
elif config['arch'] == 'densenet161': | |
import torchvision.models as models | |
self.model = models.densenet161(num_classes=config['num_classes']) | |
# Load pretrained weights | |
checkpoint = torch.hub.load_state_dict_from_url( | |
config['url'], | |
map_location=self.device, | |
progress=True | |
) | |
# Handle different checkpoint formats | |
if 'state_dict' in checkpoint: | |
state_dict = checkpoint['state_dict'] | |
# Remove 'module.' prefix if present | |
state_dict = {k.replace('module.', ''): v for k, v in state_dict.items()} | |
else: | |
state_dict = checkpoint | |
self.model.load_state_dict(state_dict) | |
self.model.to(self.device) | |
self.model.eval() | |
self.logger.info(f"Places365 model {self.model_name} loaded successfully on {self.device}") | |
except Exception as e: | |
self.logger.error(f"Error loading Places365 model: {str(e)}") | |
raise | |
def _load_class_names(self): | |
"""Load Places365 class names and scene attributes.""" | |
try: | |
# Load scene class names (365 categories) | |
import urllib.request | |
class_url = 'https://raw.githubusercontent.com/csailvision/places365/master/categories_places365.txt' | |
class_file = urllib.request.urlopen(class_url) | |
self.scene_classes = [] | |
for line in class_file: | |
class_name = line.decode('utf-8').strip().split(' ')[0][3:] # Remove /x/ prefix | |
self.scene_classes.append(class_name) | |
# Load scene attributes (optional, for enhanced description) | |
attr_url = 'https://raw.githubusercontent.com/csailvision/places365/master/labels_sunattribute.txt' | |
try: | |
attr_file = urllib.request.urlopen(attr_url) | |
self.scene_attributes = [] | |
for line in attr_file: | |
attr_name = line.decode('utf-8').strip() | |
self.scene_attributes.append(attr_name) | |
except: | |
self.logger.warning("Scene attributes not loaded, continuing with basic classification") | |
self.scene_attributes = [] | |
self.logger.info(f"Loaded {len(self.scene_classes)} scene classes and {len(self.scene_attributes)} attributes") | |
except Exception as e: | |
self.logger.error(f"Error loading class names: {str(e)}") | |
# Fallback to basic class names if download fails | |
self.scene_classes = [f"scene_class_{i}" for i in range(365)] | |
self.scene_attributes = [] | |
def _setup_scene_mapping(self): | |
"""Setup mapping from Places365 classes to common scene types.""" | |
# 建立Places365類別到通用場景類型的映射關係 | |
self.scene_type_mapping = { | |
# Indoor scenes | |
'living_room': 'living_room', | |
'bedroom': 'bedroom', | |
'kitchen': 'kitchen', | |
'dining_room': 'dining_area', | |
'bathroom': 'bathroom', | |
'office': 'office_workspace', | |
'conference_room': 'office_workspace', | |
'classroom': 'educational_setting', | |
'library': 'library', | |
'restaurant': 'restaurant', | |
'cafe': 'cafe', | |
'bar': 'bar', | |
'hotel_room': 'hotel_room', | |
'hospital_room': 'medical_facility', | |
'gym': 'gym', | |
'supermarket': 'retail_store', | |
'clothing_store': 'retail_store', | |
# Outdoor urban scenes | |
'street': 'city_street', | |
'crosswalk': 'intersection', | |
'parking_lot': 'parking_lot', | |
'gas_station': 'gas_station', | |
'bus_station': 'bus_stop', | |
'train_station': 'train_station', | |
'airport_terminal': 'airport', | |
'subway_station': 'subway_station', | |
'bridge': 'bridge', | |
'highway': 'highway', | |
'downtown': 'commercial_district', | |
'shopping_mall': 'shopping_mall', | |
# Natural outdoor scenes | |
'park': 'park_area', | |
'beach': 'beach', | |
'forest': 'forest', | |
'mountain': 'mountain', | |
'lake': 'lake', | |
'river': 'river', | |
'ocean': 'ocean', | |
'desert': 'desert', | |
'field': 'field', | |
'garden': 'garden', | |
# Landmark and tourist areas | |
'castle': 'historical_monument', | |
'palace': 'historical_monument', | |
'temple': 'temple', | |
'church': 'church', | |
'mosque': 'mosque', | |
'museum': 'museum', | |
'art_gallery': 'art_gallery', | |
'tower': 'tourist_landmark', | |
'monument': 'historical_monument', | |
# Sports and entertainment | |
'stadium': 'stadium', | |
'basketball_court': 'sports_field', | |
'tennis_court': 'sports_field', | |
'swimming_pool': 'swimming_pool', | |
'playground': 'playground', | |
'amusement_park': 'amusement_park', | |
'theater': 'theater', | |
'concert_hall': 'concert_hall', | |
# Transportation | |
'airplane_cabin': 'airplane_cabin', | |
'train_interior': 'train_interior', | |
'car_interior': 'car_interior', | |
# Construction and industrial | |
'construction_site': 'construction_site', | |
'factory': 'factory', | |
'warehouse': 'warehouse' | |
} | |
# Indoor/outdoor classification helper | |
self.indoor_classes = { | |
'living_room', 'bedroom', 'kitchen', 'dining_room', 'bathroom', 'office', | |
'conference_room', 'classroom', 'library', 'restaurant', 'cafe', 'bar', | |
'hotel_room', 'hospital_room', 'gym', 'supermarket', 'clothing_store', | |
'airplane_cabin', 'train_interior', 'car_interior', 'theater', 'concert_hall', | |
'museum', 'art_gallery', 'shopping_mall' | |
} | |
self.outdoor_classes = { | |
'street', 'crosswalk', 'parking_lot', 'gas_station', 'bus_station', | |
'train_station', 'airport_terminal', 'bridge', 'highway', 'downtown', | |
'park', 'beach', 'forest', 'mountain', 'lake', 'river', 'ocean', | |
'desert', 'field', 'garden', 'stadium', 'basketball_court', 'tennis_court', | |
'swimming_pool', 'playground', 'amusement_park', 'construction_site', | |
'factory', 'warehouse', 'castle', 'palace', 'temple', 'church', 'mosque', | |
'tower', 'monument' | |
} | |
def preprocess(self, image_pil: Image.Image) -> torch.Tensor: | |
""" | |
Preprocess PIL image for Places365 model inference. | |
Args: | |
image_pil: Input PIL image | |
Returns: | |
torch.Tensor: Preprocessed image tensor | |
""" | |
# Places365 standard preprocessing | |
transform = transforms.Compose([ | |
transforms.Resize((256, 256)), | |
transforms.CenterCrop(224), | |
transforms.ToTensor(), | |
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) | |
]) | |
# Convert to RGB if needed | |
if image_pil.mode != 'RGB': | |
image_pil = image_pil.convert('RGB') | |
# Apply preprocessing | |
input_tensor = transform(image_pil).unsqueeze(0) | |
return input_tensor.to(self.device) | |
def predict(self, image_pil: Image.Image) -> Dict[str, Any]: | |
""" | |
Predict scene classification and attributes for input image. | |
Args: | |
image_pil: Input PIL image | |
Returns: | |
Dict containing scene predictions and confidence scores | |
""" | |
try: | |
# Preprocess image | |
input_tensor = self.preprocess(image_pil) | |
# Model inference | |
with torch.no_grad(): | |
outputs = self.model(input_tensor) | |
probabilities = torch.nn.functional.softmax(outputs, dim=1) | |
# 返回最有可能的項目 | |
top_k = min(10, len(self.scene_classes)) # Configurable top-k | |
top_probs, top_indices = torch.topk(probabilities, top_k, dim=1) | |
# Extract results | |
top_probs = top_probs.cpu().numpy()[0] | |
top_indices = top_indices.cpu().numpy()[0] | |
# Build prediction results | |
predictions = [] | |
for i in range(top_k): | |
class_idx = top_indices[i] | |
confidence = float(top_probs[i]) | |
scene_class = self.scene_classes[class_idx] | |
predictions.append({ | |
'class_name': scene_class, | |
'class_index': class_idx, | |
'confidence': confidence | |
}) | |
# Get primary prediction | |
primary_prediction = predictions[0] | |
primary_class = primary_prediction['class_name'] | |
# 確認是 indoor/outdoor | |
is_indoor = self._classify_indoor_outdoor(primary_class) | |
# Map to common scene type | |
mapped_scene_type = self._map_places365_to_scene_types(primary_class) | |
# Determine scene attributes (basic inference based on class) | |
scene_attributes = self._infer_scene_attributes(primary_class) | |
result = { | |
'scene_label': primary_class, | |
'mapped_scene_type': mapped_scene_type, | |
'confidence': primary_prediction['confidence'], | |
'is_indoor': is_indoor, | |
'attributes': scene_attributes, | |
'top_predictions': predictions, | |
'all_probabilities': probabilities.cpu().numpy()[0].tolist() | |
} | |
return result | |
except Exception as e: | |
self.logger.error(f"Error in Places365 prediction: {str(e)}") | |
return { | |
'scene_label': 'unknown', | |
'mapped_scene_type': 'unknown', | |
'confidence': 0.0, | |
'is_indoor': None, | |
'attributes': [], | |
'top_predictions': [], | |
'error': str(e) | |
} | |
def _classify_indoor_outdoor(self, scene_class: str) -> Optional[bool]: | |
""" | |
Classify if scene is indoor or outdoor based on Places365 class. | |
Args: | |
scene_class: Places365 scene class name | |
Returns: | |
bool or None: True for indoor, False for outdoor, None if uncertain | |
""" | |
if scene_class in self.indoor_classes: | |
return True | |
elif scene_class in self.outdoor_classes: | |
return False | |
else: | |
# For ambiguous classes, use heuristics | |
indoor_keywords = ['room', 'office', 'store', 'shop', 'hall', 'interior', 'indoor'] | |
outdoor_keywords = ['street', 'road', 'park', 'field', 'beach', 'mountain', 'outdoor'] | |
scene_lower = scene_class.lower() | |
if any(keyword in scene_lower for keyword in indoor_keywords): | |
return True | |
elif any(keyword in scene_lower for keyword in outdoor_keywords): | |
return False | |
else: | |
return None | |
def _map_places365_to_scene_types(self, places365_class: str) -> str: | |
""" | |
Map Places365 class to common scene type used by the system. | |
Args: | |
places365_class: Places365 scene class name | |
Returns: | |
str: Mapped scene type | |
""" | |
# Direct mapping lookup | |
if places365_class in self.scene_type_mapping: | |
return self.scene_type_mapping[places365_class] | |
# Fuzzy matching for similar classes | |
places365_lower = places365_class.lower() | |
# Indoor fuzzy matching | |
if any(keyword in places365_lower for keyword in ['living', 'bedroom', 'kitchen']): | |
return 'general_indoor_space' | |
elif any(keyword in places365_lower for keyword in ['office', 'conference', 'meeting']): | |
return 'office_workspace' | |
elif any(keyword in places365_lower for keyword in ['dining', 'restaurant', 'cafe']): | |
return 'dining_area' | |
elif any(keyword in places365_lower for keyword in ['store', 'shop', 'market']): | |
return 'retail_store' | |
elif any(keyword in places365_lower for keyword in ['school', 'class', 'library']): | |
return 'educational_setting' | |
# Outdoor fuzzy matching | |
elif any(keyword in places365_lower for keyword in ['street', 'road', 'crosswalk']): | |
return 'city_street' | |
elif any(keyword in places365_lower for keyword in ['park', 'garden', 'plaza']): | |
return 'park_area' | |
elif any(keyword in places365_lower for keyword in ['beach', 'ocean', 'lake']): | |
return 'beach' | |
elif any(keyword in places365_lower for keyword in ['mountain', 'forest', 'desert']): | |
return 'natural_outdoor_area' | |
elif any(keyword in places365_lower for keyword in ['parking', 'garage']): | |
return 'parking_lot' | |
elif any(keyword in places365_lower for keyword in ['station', 'terminal', 'airport']): | |
return 'transportation_hub' | |
# Landmark fuzzy matching | |
elif any(keyword in places365_lower for keyword in ['castle', 'palace', 'monument', 'temple']): | |
return 'historical_monument' | |
elif any(keyword in places365_lower for keyword in ['tower', 'landmark']): | |
return 'tourist_landmark' | |
elif any(keyword in places365_lower for keyword in ['museum', 'gallery']): | |
return 'cultural_venue' | |
# Default fallback based on indoor/outdoor | |
is_indoor = self._classify_indoor_outdoor(places365_class) | |
if is_indoor is True: | |
return 'general_indoor_space' | |
elif is_indoor is False: | |
return 'generic_street_view' | |
else: | |
return 'unknown' | |
def _infer_scene_attributes(self, scene_class: str) -> List[str]: | |
""" | |
Infer basic scene attributes from Places365 class. | |
Args: | |
scene_class: Places365 scene class name | |
Returns: | |
List[str]: Inferred scene attributes | |
""" | |
attributes = [] | |
scene_lower = scene_class.lower() | |
# Lighting attributes | |
if any(keyword in scene_lower for keyword in ['outdoor', 'street', 'park', 'beach']): | |
attributes.append('natural_lighting') | |
elif any(keyword in scene_lower for keyword in ['indoor', 'room', 'office']): | |
attributes.append('artificial_lighting') | |
# Functional attributes | |
if any(keyword in scene_lower for keyword in ['commercial', 'store', 'shop', 'restaurant']): | |
attributes.append('commercial') | |
elif any(keyword in scene_lower for keyword in ['residential', 'home', 'living', 'bedroom']): | |
attributes.append('residential') | |
elif any(keyword in scene_lower for keyword in ['office', 'conference', 'meeting']): | |
attributes.append('workplace') | |
elif any(keyword in scene_lower for keyword in ['recreation', 'park', 'playground', 'stadium']): | |
attributes.append('recreational') | |
elif any(keyword in scene_lower for keyword in ['educational', 'school', 'library', 'classroom']): | |
attributes.append('educational') | |
# Spatial attributes | |
if any(keyword in scene_lower for keyword in ['open', 'field', 'plaza', 'stadium']): | |
attributes.append('open_space') | |
elif any(keyword in scene_lower for keyword in ['enclosed', 'room', 'interior']): | |
attributes.append('enclosed_space') | |
return attributes | |
def get_scene_probabilities(self, image_pil: Image.Image) -> Dict[str, float]: | |
""" | |
Get probability distribution over all scene classes. | |
Args: | |
image_pil: Input PIL image | |
Returns: | |
Dict mapping scene class names to probabilities | |
""" | |
try: | |
input_tensor = self.preprocess(image_pil) | |
with torch.no_grad(): | |
outputs = self.model(input_tensor) | |
probabilities = torch.nn.functional.softmax(outputs, dim=1) | |
probs = probabilities.cpu().numpy()[0] | |
return { | |
self.scene_classes[i]: float(probs[i]) | |
for i in range(len(self.scene_classes)) | |
} | |
except Exception as e: | |
self.logger.error(f"Error getting scene probabilities: {str(e)}") | |
return {} | |