astro-seg / modeling_yolo.py
rayh's picture
Add modeling_yolo.py
e3c079a verified
"""YOLO model for Hugging Face Transformers."""
import torch
from pathlib import Path
from typing import Dict, Any, Union
import numpy as np
import logging
from ultralytics import YOLO
logger = logging.getLogger(__name__)
class YOLOSegmentationPipeline:
    """YOLO segmentation pipeline for Hugging Face Hub.

    Wraps an Ultralytics ``YOLO`` model: the model is loaded once when the
    pipeline is constructed, and each call converts the first result returned
    by the model into plain Python lists and numbers under a
    ``'predictions'`` key.
    """

    def __init__(self, model_path: Union[str, Path], **kwargs):
        """Initialize the pipeline and load the model immediately.

        Args:
            model_path: Location of the model file; stored as ``str``.
            **kwargs: Accepted for call-site compatibility; not used here.
        """
        self.model_path = str(model_path)
        self.model = None
        # Prefer CUDA when torch reports it as available, otherwise use CPU.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.load_model()

    def load_model(self):
        """Load the YOLO model, move it to ``self.device`` and set eval mode."""
        logger.info("Loading model from %s", self.model_path)
        self.model = YOLO(self.model_path)
        self.model.to(self.device)

        # Do not call ``self.model.eval()`` on the YOLO wrapper itself:
        # ``torch.nn.Module.eval()`` dispatches to ``self.train(False)``, and
        # the Ultralytics wrapper uses ``train`` as its training entry point,
        # so on some releases that call can start a training run (or fail)
        # instead of switching modes. Put the wrapped network in eval mode
        # directly; non-module backends (where ``.model`` is not an
        # ``nn.Module``) are left as they are.
        network = getattr(self.model, "model", None)
        if isinstance(network, torch.nn.Module):
            network.eval()

        logger.info("Model loaded on %s", self.device)

    def __call__(self, inputs: Dict[str, Any], **kwargs) -> Dict[str, Any]:
        """Run inference on an input image.

        Args:
            inputs: Dictionary containing 'image' (PIL Image).
            **kwargs: Additional inference parameters, forwarded unchanged
                to the model call.

        Returns:
            Dictionary with a 'predictions' key containing the formatted
            detections of the first result returned by the model.

        Raises:
            ValueError: If ``inputs`` has no 'image' entry (or it is None).
        """
        image = inputs.get("image")
        if image is None:
            raise ValueError("Input must contain 'image' key with PIL Image")

        # The model is always fed an RGB image; other modes are converted.
        if image.mode != "RGB":
            image = image.convert("RGB")

        # Gradients are never needed for inference.
        with torch.no_grad():
            results = self.model(image, **kwargs)

        # Only the first result is formatted (a single image was passed in).
        return self._format_results(results[0])

    def _format_results(self, result) -> Dict[str, Any]:
        """Format a YOLO result for the Hugging Face API.

        Args:
            result: A result object exposing optional ``boxes`` (with
                ``xyxy``, ``conf`` and ``cls`` tensors) and optional
                ``masks`` (with a ``data`` tensor indexed per detection).

        Returns:
            ``{'predictions': [...]}`` where each entry holds:
            'box' (list taken from ``boxes.xyxy``), 'score' (float),
            'label' (int class index) and 'mask' (nested list taken from
            ``masks.data``, or None when no mask exists for that detection).
        """
        box_data = getattr(result, "boxes", None)
        if box_data is None:
            # No detections container at all: nothing to report.
            return {"predictions": []}

        boxes = box_data.xyxy.cpu().numpy()
        scores = box_data.conf.cpu().numpy()
        labels = box_data.cls.cpu().numpy().astype(int)

        # Masks are optional; a missing container simply yields ``None``
        # masks, so no placeholder array (and no ``orig_shape``) is required.
        mask_data = getattr(result, "masks", None)
        masks = mask_data.data.cpu().numpy() if mask_data is not None else None
        mask_count = 0 if masks is None else len(masks)

        predictions = [
            {
                "box": box.tolist(),
                "score": float(score),
                "label": int(label),
                # Guard against fewer masks than boxes.
                "mask": masks[i].tolist() if i < mask_count else None,
            }
            for i, (box, score, label) in enumerate(zip(boxes, scores, labels))
        ]
        return {"predictions": predictions}