medrax.org / medrax /tools /classification.py
oldcai's picture
Upload folder using huggingface_hub
d7a7846 verified
from typing import Dict, Optional, Tuple, Type
from pydantic import BaseModel, Field
import skimage.io
import torch
import torchvision
import torchxrayvision as xrv
from langchain_core.callbacks import (
AsyncCallbackManagerForToolRun,
CallbackManagerForToolRun,
)
from langchain_core.tools import BaseTool
class ChestXRayInput(BaseModel):
    """Input for chest X-ray analysis tools. Only supports JPG or PNG images."""

    # Required field (Ellipsis default): the tool cannot run without a path.
    # The description text is surfaced in the generated argument schema, so it
    # is written for the caller/LLM rather than for maintainers.
    image_path: str = Field(
        default=...,
        description="Path to the radiology image file, only supports JPG or PNG images",
    )
class ChestXRayClassifierTool(BaseTool):
    """Tool that classifies chest X-ray images for multiple pathologies.

    This tool uses a pre-trained DenseNet model to analyze chest X-ray images and
    predict the likelihood of various pathologies. The model can classify the
    following 18 conditions:

    Atelectasis, Cardiomegaly, Consolidation, Edema, Effusion, Emphysema,
    Enlarged Cardiomediastinum, Fibrosis, Fracture, Hernia, Infiltration,
    Lung Lesion, Lung Opacity, Mass, Nodule, Pleural Thickening, Pneumonia,
    Pneumothorax

    The output values represent the probability (from 0 to 1) of each condition
    being present in the image. A higher value indicates a higher likelihood of
    the condition being present.
    """

    name: str = "chest_xray_classifier"
    description: str = (
        "A tool that analyzes chest X-ray images and classifies them for 18 different pathologies. "
        "Input should be the path to a chest X-ray image file. "
        "Output is a dictionary of pathologies and their predicted probabilities (0 to 1). "
        "Pathologies include: Atelectasis, Cardiomegaly, Consolidation, Edema, Effusion, Emphysema, "
        "Enlarged Cardiomediastinum, Fibrosis, Fracture, Hernia, Infiltration, Lung Lesion, "
        "Lung Opacity, Mass, Nodule, Pleural Thickening, Pneumonia, and Pneumothorax. "
        "Higher values indicate a higher likelihood of the condition being present."
    )
    args_schema: Type[BaseModel] = ChestXRayInput
    # Populated in __init__; the None defaults only satisfy field declaration.
    model: xrv.models.DenseNet = None
    device: Optional[str] = "cuda"
    transform: torchvision.transforms.Compose = None

    def __init__(self, model_name: str = "densenet121-res224-all", device: Optional[str] = "cuda"):
        """Load the classifier weights and move the model to the target device.

        Args:
            model_name (str): TorchXRayVision DenseNet weights identifier.
            device (Optional[str]): Torch device string (e.g. "cuda", "cpu").
                When falsy, CUDA is used if available, otherwise the CPU.
        """
        super().__init__()
        self.model = xrv.models.DenseNet(weights=model_name)
        self.model.eval()

        # Always store a torch.device. With no explicit device, fall back to the
        # CPU on hosts without CUDA instead of failing on an unconditional "cuda".
        if not device:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = torch.device(device)

        self.model = self.model.to(self.device)
        self.transform = torchvision.transforms.Compose([xrv.datasets.XRayCenterCrop()])

    def _process_image(self, image_path: str) -> torch.Tensor:
        """Process the input chest X-ray image for model inference.

        The image is read from disk, normalized, reduced to a single channel,
        passed through ``self.transform`` and returned as a batched tensor on
        ``self.device``.

        Args:
            image_path (str): The file path to the chest X-ray image.

        Returns:
            torch.Tensor: A processed image tensor ready for model inference.

        Raises:
            Exception: Any error raised while reading or normalizing the image
                propagates unchanged; ``_run`` is responsible for catching it.
        """
        img = skimage.io.imread(image_path)
        # 255 is the expected maximum pixel value, i.e. 8-bit input is assumed.
        img = xrv.datasets.normalize(img, 255)

        # Multi-channel images (e.g. RGB): keep only the first channel.
        if len(img.shape) > 2:
            img = img[:, :, 0]

        # Add a leading channel axis before applying the transform.
        img = img[None, :, :]
        img = self.transform(img)

        # Add the batch axis and move to the same device as the model.
        img = torch.from_numpy(img).unsqueeze(0)
        img = img.to(self.device)
        return img

    def _run(
        self,
        image_path: str,
        run_manager: Optional[CallbackManagerForToolRun] = None,
    ) -> Tuple[Dict[str, float], Dict]:
        """Classify the chest X-ray image for multiple pathologies.

        Args:
            image_path (str): The path to the chest X-ray image file.
            run_manager (Optional[CallbackManagerForToolRun]): The callback manager
                for the tool run. Currently unused.

        Returns:
            Tuple[Dict[str, float], Dict]: A tuple containing the classification
            results (pathologies and their probabilities from 0 to 1) and
            additional metadata. This method does not raise: if processing or
            inference fails, the first element is ``{"error": <message>}`` and the
            metadata's ``analysis_status`` is ``"failed"``.
        """
        try:
            img = self._process_image(image_path)

            with torch.inference_mode():
                preds = self.model(img).cpu()[0]

            # tolist() yields built-in floats. numpy.float32 scalars (from
            # preds.numpy()) are not `float` instances and are rejected by
            # json.dumps, which breaks serialization of the tool result.
            output = dict(zip(xrv.datasets.default_pathologies, preds.tolist()))
            metadata = {
                "image_path": image_path,
                "analysis_status": "completed",
                "note": "Probabilities range from 0 to 1, with higher values indicating higher likelihood of the condition.",
            }
            return output, metadata
        except Exception as e:
            # Deliberate catch-all at the tool boundary: report the failure as
            # data so the calling agent can react instead of crashing.
            return {"error": str(e)}, {
                "image_path": image_path,
                "analysis_status": "failed",
            }

    async def _arun(
        self,
        image_path: str,
        run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
    ) -> Tuple[Dict[str, float], Dict]:
        """Asynchronously classify the chest X-ray image for multiple pathologies.

        This method currently calls the synchronous version, as the model inference
        is not inherently asynchronous. For true asynchronous behavior, consider
        using a separate thread or process.

        Args:
            image_path (str): The path to the chest X-ray image file.
            run_manager (Optional[AsyncCallbackManagerForToolRun]): The async callback
                manager for the tool run. Currently unused and not forwarded.

        Returns:
            Tuple[Dict[str, float], Dict]: Exactly what ``_run`` returns for the
            same ``image_path``, including the error payload on failure.
        """
        return self._run(image_path)