from typing import Any, Dict, Optional, Tuple, Type

import torch
from langchain_core.callbacks import (
    AsyncCallbackManagerForToolRun,
    CallbackManagerForToolRun,
)
from langchain_core.tools import BaseTool
from PIL import Image
from pydantic import BaseModel, Field
from transformers import (
    BertTokenizer,
    GenerationConfig,
    ViTImageProcessor,
    VisionEncoderDecoderModel,
)


class ChestXRayInput(BaseModel):
    """Input for chest X-ray analysis tools. Only supports JPG or PNG images."""

    image_path: str = Field(
        ..., description="Path to the radiology image file, only supports JPG or PNG images"
    )


class ChestXRayReportGeneratorTool(BaseTool):
    """Tool that generates comprehensive chest X-ray reports with both findings and impressions.

    This tool uses two Vision-Encoder-Decoder models (ViT-BERT) trained on CheXpert
    and MIMIC-CXR datasets to generate structured radiology reports. It automatically
    generates both detailed findings and impression summaries for each chest X-ray,
    following standard radiological reporting format.

    The tool uses:
    - Findings model: Generates detailed observations of all visible structures
    - Impression model: Provides concise clinical interpretation and key diagnoses
    """

    name: str = "chest_xray_report_generator"
    description: str = (
        "A tool that analyzes chest X-ray images and generates comprehensive radiology reports "
        "containing both detailed findings and impression summaries. Input should be the path "
        "to a chest X-ray image file. Output is a structured report with both detailed "
        "observations and key clinical conclusions."
    )
    # NOTE(review): `device` holds a string at class level but is replaced with a
    # torch.device in __init__, so the annotation is widened to Any.
    device: Optional[Any] = "cuda"
    args_schema: Type[BaseModel] = ChestXRayInput
    # Model components are populated in __init__; None defaults require Optional.
    findings_model: Optional[VisionEncoderDecoderModel] = None
    impression_model: Optional[VisionEncoderDecoderModel] = None
    findings_tokenizer: Optional[BertTokenizer] = None
    impression_tokenizer: Optional[BertTokenizer] = None
    findings_processor: Optional[ViTImageProcessor] = None
    impression_processor: Optional[ViTImageProcessor] = None
    generation_args: Optional[Dict[str, Any]] = None

    def __init__(self, cache_dir: str = "/model-weights", device: Optional[str] = "cuda"):
        """Initialize the ChestXRayReportGeneratorTool with both findings and impression models.

        Args:
            cache_dir (str): Directory where pretrained model weights are cached.
            device (Optional[str]): Device spec (e.g. "cuda", "cpu"); defaults to "cuda".
        """
        super().__init__()
        # BUG FIX: the previous code left self.device as the *string* "cuda" when
        # `device` was falsy; always wrap in torch.device for a consistent type.
        self.device = torch.device(device if device else "cuda")

        # Initialize findings model
        self.findings_model = VisionEncoderDecoderModel.from_pretrained(
            "IAMJB/chexpert-mimic-cxr-findings-baseline", cache_dir=cache_dir
        ).eval()
        self.findings_tokenizer = BertTokenizer.from_pretrained(
            "IAMJB/chexpert-mimic-cxr-findings-baseline", cache_dir=cache_dir
        )
        self.findings_processor = ViTImageProcessor.from_pretrained(
            "IAMJB/chexpert-mimic-cxr-findings-baseline", cache_dir=cache_dir
        )

        # Initialize impression model
        self.impression_model = VisionEncoderDecoderModel.from_pretrained(
            "IAMJB/chexpert-mimic-cxr-impression-baseline", cache_dir=cache_dir
        ).eval()
        self.impression_tokenizer = BertTokenizer.from_pretrained(
            "IAMJB/chexpert-mimic-cxr-impression-baseline", cache_dir=cache_dir
        )
        self.impression_processor = ViTImageProcessor.from_pretrained(
            "IAMJB/chexpert-mimic-cxr-impression-baseline", cache_dir=cache_dir
        )

        # Move models to device
        self.findings_model = self.findings_model.to(self.device)
        self.impression_model = self.impression_model.to(self.device)

        # Default generation arguments.
        # BUG FIX: GenerationConfig has no "beam_width" parameter; the correct key
        # is "num_beams", so beam search was silently never enabled before.
        self.generation_args = {
            "num_return_sequences": 1,
            "max_length": 128,
            "use_cache": True,
            "num_beams": 2,
        }

    def _process_image(
        self, image_path: str, processor: ViTImageProcessor, model: VisionEncoderDecoderModel
    ) -> torch.Tensor:
        """Process the input image for a specific model.

        Args:
            image_path (str): Path to the input image.
            processor: Image processor for the specific model.
            model: The model to process the image for.

        Returns:
            torch.Tensor: Processed image tensor ready for model input.
        """
        image = Image.open(image_path).convert("RGB")
        pixel_values = processor(image, return_tensors="pt").pixel_values

        # Resize to the encoder's expected resolution if the processor output differs.
        expected_size = model.config.encoder.image_size
        actual_size = pixel_values.shape[-1]
        if expected_size != actual_size:
            pixel_values = torch.nn.functional.interpolate(
                pixel_values,
                size=(expected_size, expected_size),
                mode="bilinear",
                align_corners=False,
            )

        return pixel_values.to(self.device)

    def _generate_report_section(
        self, pixel_values: torch.Tensor, model: VisionEncoderDecoderModel, tokenizer: BertTokenizer
    ) -> str:
        """Generate a report section using the specified model.

        Args:
            pixel_values: Processed image tensor.
            model: The model to use for generation.
            tokenizer: The tokenizer for the model.

        Returns:
            str: Generated text for the report section.
        """
        generation_config = GenerationConfig(
            **{
                **self.generation_args,
                "bos_token_id": model.config.bos_token_id,
                "eos_token_id": model.config.eos_token_id,
                "pad_token_id": model.config.pad_token_id,
                # BERT decoders start generation from the [CLS] token.
                "decoder_start_token_id": tokenizer.cls_token_id,
            }
        )

        generated_ids = model.generate(pixel_values, generation_config=generation_config)
        return tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]

    def _run(
        self,
        image_path: str,
        run_manager: Optional[CallbackManagerForToolRun] = None,
    ) -> Tuple[str, Dict]:
        """Generate a comprehensive chest X-ray report containing both findings and impression.

        Args:
            image_path (str): The path to the chest X-ray image file.
            run_manager (Optional[CallbackManagerForToolRun]): The callback manager.

        Returns:
            Tuple[str, Dict]: A tuple containing the complete report (or an error
                message) and a metadata dict with the analysis status.
        """
        try:
            # Process image for both models
            findings_pixels = self._process_image(
                image_path, self.findings_processor, self.findings_model
            )
            impression_pixels = self._process_image(
                image_path, self.impression_processor, self.impression_model
            )

            # Generate both sections without building autograd graphs.
            with torch.inference_mode():
                findings_text = self._generate_report_section(
                    findings_pixels, self.findings_model, self.findings_tokenizer
                )
                impression_text = self._generate_report_section(
                    impression_pixels, self.impression_model, self.impression_tokenizer
                )

            # Combine into formatted report
            report = (
                "CHEST X-RAY REPORT\n\n"
                f"FINDINGS:\n{findings_text}\n\n"
                f"IMPRESSION:\n{impression_text}"
            )

            metadata = {
                "image_path": image_path,
                "analysis_status": "completed",
                "sections_generated": ["findings", "impression"],
            }

            return report, metadata

        except Exception as e:
            # Best-effort tool contract: surface the failure in-band instead of raising.
            return f"Error generating report: {str(e)}", {
                "image_path": image_path,
                "analysis_status": "failed",
                "error": str(e),
            }

    async def _arun(
        self,
        image_path: str,
        run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
    ) -> Tuple[str, Dict]:
        """Asynchronously generate a comprehensive chest X-ray report.

        NOTE(review): delegates to the synchronous _run, so model inference blocks
        the event loop; consider an executor if true async behavior is needed.
        """
        return self._run(image_path)