Spaces:
Runtime error
Runtime error
import asyncio
import warnings
from typing import Dict, Optional, Tuple, Type

import skimage.io
import torch
import torchvision
import torchxrayvision as xrv
from langchain_core.callbacks import (
    AsyncCallbackManagerForToolRun,
    CallbackManagerForToolRun,
)
from langchain_core.tools import BaseTool
from pydantic import BaseModel, Field
class ChestXRayInput(BaseModel):
    """Input for chest X-ray analysis tools. Only supports JPG or PNG images."""

    # Single required argument: `default=...` (Ellipsis) marks the field as
    # mandatory, so callers must always supply a path.
    image_path: str = Field(
        default=...,
        description="Path to the radiology image file, only supports JPG or PNG images",
    )
class ChestXRayClassifierTool(BaseTool):
    """Tool that classifies chest X-ray images for multiple pathologies.

    This tool uses a pre-trained DenseNet model to analyze chest X-ray images and
    predict the likelihood of various pathologies. The model can classify the
    following 18 conditions:

    Atelectasis, Cardiomegaly, Consolidation, Edema, Effusion, Emphysema,
    Enlarged Cardiomediastinum, Fibrosis, Fracture, Hernia, Infiltration,
    Lung Lesion, Lung Opacity, Mass, Nodule, Pleural Thickening, Pneumonia,
    Pneumothorax

    The output values represent the probability (from 0 to 1) of each condition
    being present in the image. A higher value indicates a higher likelihood of
    the condition being present.
    """

    name: str = "chest_xray_classifier"
    description: str = (
        "A tool that analyzes chest X-ray images and classifies them for 18 different pathologies. "
        "Input should be the path to a chest X-ray image file. "
        "Output is a dictionary of pathologies and their predicted probabilities (0 to 1). "
        "Pathologies include: Atelectasis, Cardiomegaly, Consolidation, Edema, Effusion, Emphysema, "
        "Enlarged Cardiomediastinum, Fibrosis, Fracture, Hernia, Infiltration, Lung Lesion, "
        "Lung Opacity, Mass, Nodule, Pleural Thickening, Pneumonia, and Pneumothorax. "
        "Higher values indicate a higher likelihood of the condition being present."
    )
    args_schema: Type[BaseModel] = ChestXRayInput
    # Populated in __init__; None only on the class-level default.
    model: Optional[xrv.models.DenseNet] = None
    # Class-level default is the string "cuda"; __init__ replaces it with the
    # resolved torch.device the model actually lives on.
    device: Optional[str] = "cuda"
    transform: Optional[torchvision.transforms.Compose] = None

    def __init__(self, model_name: str = "densenet121-res224-all", device: Optional[str] = "cuda"):
        """Load the DenseNet weights and move the model to the target device.

        Args:
            model_name: Name of the TorchXRayVision DenseNet weights to load.
            device: Device string such as "cuda", "cuda:0" or "cpu". If None or
                empty, CUDA is used when available and CPU otherwise. If a CUDA
                device is requested but CUDA is unavailable, the tool falls back
                to CPU and emits a RuntimeWarning instead of failing.
        """
        super().__init__()
        self.model = xrv.models.DenseNet(weights=model_name)
        self.model.eval()

        cuda_available = torch.cuda.is_available()
        if not device:
            # No explicit request: pick the best device that actually exists.
            resolved = torch.device("cuda" if cuda_available else "cpu")
        else:
            resolved = torch.device(device)
            if resolved.type == "cuda" and not cuda_available:
                # The default argument is "cuda", so a hard failure here would
                # make the tool unusable on CPU-only hosts.
                warnings.warn(
                    f"CUDA device '{device}' requested but CUDA is not available; "
                    "falling back to CPU.",
                    RuntimeWarning,
                    stacklevel=2,
                )
                resolved = torch.device("cpu")

        # Always store a torch.device so self.device has a single runtime type.
        self.device = resolved
        self.model = self.model.to(self.device)
        self.transform = torchvision.transforms.Compose([xrv.datasets.XRayCenterCrop()])

    def _process_image(self, image_path: str) -> torch.Tensor:
        """Load an image from disk and prepare it as a model input tensor.

        The image is read with skimage, normalized with
        ``xrv.datasets.normalize`` (assuming a maximum pixel value of 255),
        reduced to a single channel, passed through ``self.transform`` and moved
        to ``self.device``.

        Args:
            image_path: The file path to the chest X-ray image.

        Returns:
            A tensor with a leading batch dimension of 1, located on
            ``self.device``.

        Raises:
            Exception: Whatever the underlying loaders raise, e.g. when the file
                is missing or cannot be decoded. Nothing is caught here.
        """
        img = skimage.io.imread(image_path)
        img = xrv.datasets.normalize(img, 255)

        # Multi-channel images (e.g. RGB): keep only the first channel.
        if len(img.shape) > 2:
            img = img[:, :, 0]

        # Add the channel axis expected by the transform: (H, W) -> (1, H, W).
        img = img[None, :, :]
        img = self.transform(img)

        # Add the batch axis and move to the model's device.
        img = torch.from_numpy(img).unsqueeze(0)
        return img.to(self.device)

    def _run(
        self,
        image_path: str,
        run_manager: Optional[CallbackManagerForToolRun] = None,
    ) -> Tuple[Dict[str, float], Dict]:
        """Classify the chest X-ray image for multiple pathologies.

        Args:
            image_path: The path to the chest X-ray image file.
            run_manager: The callback manager for the tool run (currently unused).

        Returns:
            A ``(output, metadata)`` tuple. On success ``output`` maps each
            pathology name to a Python float and ``metadata["analysis_status"]``
            is ``"completed"``. On failure ``output`` is ``{"error": <message>}``
            and ``metadata["analysis_status"]`` is ``"failed"``.

        Note:
            This method does not raise: any exception during loading or
            inference is caught and reported through the returned dictionaries.
        """
        try:
            img = self._process_image(image_path)

            with torch.inference_mode():
                preds = self.model(img).cpu()[0]

            # tolist() yields plain Python floats (not numpy.float32), matching
            # the declared Dict[str, float] and keeping the result serializable.
            output = dict(zip(xrv.datasets.default_pathologies, preds.tolist()))
            metadata = {
                "image_path": image_path,
                "analysis_status": "completed",
                "note": "Probabilities range from 0 to 1, with higher values indicating higher likelihood of the condition.",
            }
            return output, metadata
        except Exception as e:
            # Deliberate tool-boundary catch-all: the caller receives the error
            # as data rather than an exception.
            return {"error": str(e)}, {
                "image_path": image_path,
                "analysis_status": "failed",
            }

    async def _arun(
        self,
        image_path: str,
        run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
    ) -> Tuple[Dict[str, float], Dict]:
        """Asynchronously classify the chest X-ray image for multiple pathologies.

        Model inference is synchronous, so the work is delegated to ``_run`` on
        the event loop's default executor. This keeps the event loop responsive
        while the image is loaded and classified.

        Args:
            image_path: The path to the chest X-ray image file.
            run_manager: The async callback manager for the tool run (currently
                unused; it is not forwarded to ``_run``).

        Returns:
            The same ``(output, metadata)`` tuple as ``_run``, including the
            error-dictionary form on failure.
        """
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(None, self._run, image_path)