from typing import Dict, List, Optional, Tuple, Type, Any
from pathlib import Path
import uuid
import tempfile
import traceback

import cv2
import numpy as np
import torch
import torchvision
import torchxrayvision as xrv
import matplotlib.pyplot as plt
import skimage.io
import skimage.measure
import skimage.transform
from pydantic import BaseModel, Field
from langchain_core.callbacks import (
    AsyncCallbackManagerForToolRun,
    CallbackManagerForToolRun,
)
from langchain_core.tools import BaseTool


class ChestXRaySegmentationInput(BaseModel):
    """Input schema for the Chest X-Ray Segmentation Tool."""

    image_path: str = Field(..., description="Path to the chest X-ray image file to be segmented")
    organs: Optional[List[str]] = Field(
        None,
        description="List of organs to segment. If None, all available organs will be segmented. "
        "Available organs: Left/Right Clavicle, Left/Right Scapula, Left/Right Lung, "
        "Left/Right Hilus Pulmonis, Heart, Aorta, Facies Diaphragmatica, "
        "Mediastinum, Weasand, Spine",
    )


class OrganMetrics(BaseModel):
    """Detailed metrics for a segmented organ."""

    # Basic metrics
    area_pixels: int = Field(..., description="Area in pixels")
    area_cm2: float = Field(..., description="Approximate area in cm²")
    centroid: Tuple[float, float] = Field(..., description="(y, x) coordinates of centroid")
    bbox: Tuple[int, int, int, int] = Field(
        ..., description="Bounding box coordinates (min_y, min_x, max_y, max_x)"
    )
    # Size metrics
    width: int = Field(..., description="Width of the organ in pixels")
    height: int = Field(..., description="Height of the organ in pixels")
    aspect_ratio: float = Field(..., description="Height/width ratio")
    # Position metrics
    relative_position: Dict[str, float] = Field(
        ..., description="Position relative to image boundaries (0-1 scale)"
    )
    # Analysis metrics
    mean_intensity: float = Field(..., description="Mean pixel intensity in the organ region")
    std_intensity: float = Field(..., description="Standard deviation of pixel intensity")
    confidence_score: float = Field(..., description="Model confidence score for this organ")


class ChestXRaySegmentationTool(BaseTool):
    """Tool for performing detailed segmentation analysis of chest X-ray images."""

    name: str = "chest_xray_segmentation"
    description: str = (
        "Segments chest X-ray images into specified anatomical structures. "
        "Available organs: Left/Right Clavicle (collar bones), Left/Right Scapula (shoulder blades), "
        "Left/Right Lung, Left/Right Hilus Pulmonis (lung roots), Heart, Aorta, "
        "Facies Diaphragmatica (diaphragm), Mediastinum (central cavity), Weasand (esophagus), "
        "and Spine. Returns a segmentation visualization and comprehensive metrics. "
        "Let the user know that reported areas are only approximate unless the input "
        "image came from DICOM data with known pixel spacing."
    )
    args_schema: Type[BaseModel] = ChestXRaySegmentationInput
    model: Any = None
    device: Optional[str] = "cuda"
    transform: Any = None
    pixel_spacing_mm: float = 0.2
    temp_dir: Path = Path("temp")
    organ_map: Dict[str, int] = None

    def __init__(self, device: Optional[str] = "cuda"):
        """Initialize the segmentation tool with model and temporary directory."""
        super().__init__()
        self.model = xrv.baseline_models.chestx_det.PSPNet()
        self.device = torch.device(device if device else "cuda")
        self.model = self.model.to(self.device)
        self.model.eval()
        self.transform = torchvision.transforms.Compose(
            [xrv.datasets.XRayCenterCrop(), xrv.datasets.XRayResizer(512)]
        )
        self.temp_dir = Path("temp") if Path("temp").exists() else Path(tempfile.mkdtemp())
        self.temp_dir.mkdir(exist_ok=True)
        # Map friendly names to model target indices
        self.organ_map = {
            "Left Clavicle": 0, "Right Clavicle": 1,
            "Left Scapula": 2, "Right Scapula": 3,
            "Left Lung": 4, "Right Lung": 5,
            "Left Hilus Pulmonis": 6, "Right Hilus Pulmonis": 7,
            "Heart": 8, "Aorta": 9,
            "Facies Diaphragmatica": 10, "Mediastinum": 11,
            "Weasand": 12, "Spine": 13,
        }

    def _compute_organ_metrics(
        self, mask: np.ndarray, original_img: np.ndarray, confidence: float
    ) -> Optional[OrganMetrics]:
        """Compute comprehensive metrics for a single organ mask."""
        # Resize mask to match original image if needed
        if mask.shape != original_img.shape:
            mask = skimage.transform.resize(
                mask, original_img.shape, order=0, preserve_range=True, anti_aliasing=False
            )
        props = skimage.measure.regionprops(mask.astype(int))
        if not props:
            return None
        props = props[0]
        area_cm2 = mask.sum() * (self.pixel_spacing_mm / 10) ** 2
        img_height, img_width = mask.shape
        cy, cx = props.centroid
        relative_pos = {
            "top": cy / img_height,
            "left": cx / img_width,
            "center_dist": np.sqrt((cy / img_height - 0.5) ** 2 + (cx / img_width - 0.5) ** 2),
        }
        organ_pixels = original_img[mask > 0]
        mean_intensity = organ_pixels.mean() if len(organ_pixels) > 0 else 0
        std_intensity = organ_pixels.std() if len(organ_pixels) > 0 else 0
        return OrganMetrics(
            area_pixels=int(mask.sum()),
            area_cm2=float(area_cm2),
            centroid=(float(cy), float(cx)),
            bbox=tuple(map(int, props.bbox)),
            width=int(props.bbox[3] - props.bbox[1]),
            height=int(props.bbox[2] - props.bbox[0]),
            aspect_ratio=float(
                (props.bbox[2] - props.bbox[0]) / max(1, props.bbox[3] - props.bbox[1])
            ),
            relative_position=relative_pos,
            mean_intensity=float(mean_intensity),
            std_intensity=float(std_intensity),
            confidence_score=float(confidence),
        )

    def _save_visualization(
        self, original_img: np.ndarray, pred_masks: torch.Tensor, organ_indices: List[int]
    ) -> str:
        """Save visualization of original image with segmentation masks overlaid."""
        # Initialize plot and base image
        plt.figure(figsize=(10, 10))
        plt.imshow(original_img, cmap='gray', extent=[0, original_img.shape[1], original_img.shape[0], 0])
        # Generate color palette for organs
        colors = plt.cm.rainbow(np.linspace(0, 1, len(organ_indices)))
        # Process and overlay each organ mask
        for organ_idx, color in zip(organ_indices, colors):
            mask = pred_masks[0, organ_idx].cpu().numpy()
            if mask.sum() > 0:
                # Resize mask if dimensions don't match
                if mask.shape != original_img.shape:
                    mask = skimage.transform.resize(
                        mask,
                        original_img.shape,
                        order=0,
                        preserve_range=True,
                        anti_aliasing=False,
                    )
                # Slightly erode the mask so the overlay stays inside the predicted organ boundary
                kernel_size = max(1, int(min(mask.shape) * 0.04))
                kernel = np.ones((kernel_size, kernel_size), np.uint8)
                mask = cv2.erode(mask.astype(np.uint8), kernel, iterations=1)
                # Apply semi-transparent colored overlay
                colored_mask = np.zeros((*original_img.shape, 4))
                colored_mask[mask > 0] = (*color[:3], 0.3)
                plt.imshow(colored_mask, extent=[0, original_img.shape[1], original_img.shape[0], 0])
                # Add legend entry for organ
                organ_name = list(self.organ_map.keys())[list(self.organ_map.values()).index(organ_idx)]
                plt.plot([], [], color=color, label=organ_name, linewidth=3)
        # Finalize and save plot
        plt.title("Segmentation Overlay")
        plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
        plt.axis("off")
        save_path = self.temp_dir / f"segmentation_{uuid.uuid4().hex[:8]}.png"
        plt.savefig(save_path, bbox_inches='tight', dpi=300)
        plt.close()
        return str(save_path)

    def _run(
        self,
        image_path: str,
        organs: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForToolRun] = None,
    ) -> Tuple[Dict[str, Any], Dict]:
        """Run segmentation analysis for specified organs."""
        try:
            # Validate and get organ indices
            if organs:
                # Normalize organ names and validate
                organs = [o.strip() for o in organs]
                invalid_organs = [o for o in organs if o not in self.organ_map]
                if invalid_organs:
                    raise ValueError(f"Invalid organs specified: {invalid_organs}")
                organ_indices = [self.organ_map[o] for o in organs]
            else:
                # Use all organs if none specified
                organ_indices = list(self.organ_map.values())
                organs = list(self.organ_map.keys())
            # Load and process image
            original_img = skimage.io.imread(image_path)
            if len(original_img.shape) > 2:
                original_img = original_img[:, :, 0]
            img = xrv.datasets.normalize(original_img, 255)
            img = img[None, ...]
            img = self.transform(img)
            img = torch.from_numpy(img)
            img = img.to(self.device)
            # Generate predictions
            with torch.no_grad():
                pred = self.model(img)
                pred_probs = torch.sigmoid(pred)
                pred_masks = (pred_probs > 0.5).float()
            # Save visualization
            viz_path = self._save_visualization(original_img, pred_masks, organ_indices)
            # Compute metrics for selected organs
            results = {}
            for idx, organ_name in zip(organ_indices, organs):
                mask = pred_masks[0, idx].cpu().numpy()
                if mask.sum() > 0:
                    metrics = self._compute_organ_metrics(
                        mask, original_img, float(pred_probs[0, idx].mean().cpu())
                    )
                    if metrics:
                        results[organ_name] = metrics
            output = {
                "segmentation_image_path": viz_path,
                "metrics": {organ: metrics.dict() for organ, metrics in results.items()},
            }
            metadata = {
                "image_path": image_path,
                "segmentation_image_path": viz_path,
                "original_size": original_img.shape,
                "model_size": tuple(img.shape[-2:]),
                "pixel_spacing_mm": self.pixel_spacing_mm,
                "requested_organs": organs,
                "processed_organs": list(results.keys()),
                "analysis_status": "completed",
            }
            return output, metadata
        except Exception as e:
            error_output = {"error": str(e)}
            error_metadata = {
                "image_path": image_path,
                "analysis_status": "failed",
                "error_traceback": traceback.format_exc(),
            }
            return error_output, error_metadata

    async def _arun(
        self,
        image_path: str,
        organs: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
    ) -> Tuple[Dict[str, Any], Dict]:
        """Async version of _run."""
        return self._run(image_path, organs)
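

if __name__ == "__main__":
    # Minimal usage sketch, not part of the tool itself: shows how the class defined above
    # could be instantiated and invoked directly. The file name "chest_xray.png", the CPU
    # device choice, and the organ selection are illustrative assumptions, not values
    # required or provided by this module.
    tool = ChestXRaySegmentationTool(device="cpu")
    output, metadata = tool._run("chest_xray.png", organs=["Left Lung", "Right Lung", "Heart"])
    if metadata.get("analysis_status") == "completed":
        print(f"Overlay saved to: {output['segmentation_image_path']}")
        for organ, m in output["metrics"].items():
            print(f"{organ}: ~{m['area_cm2']:.1f} cm², confidence {m['confidence_score']:.2f}")
    else:
        print(f"Segmentation failed: {output.get('error')}")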