from typing import Dict, Optional, Tuple, Type
from pathlib import Path
import uuid
import tempfile

import torch
from pydantic import BaseModel, Field
from diffusers import StableDiffusionPipeline
from langchain_core.callbacks import AsyncCallbackManagerForToolRun, CallbackManagerForToolRun
from langchain_core.tools import BaseTool


class ChestXRayGeneratorInput(BaseModel):
    """Input schema for the Chest X-Ray Generator Tool."""

    prompt: str = Field(
        ...,
        description="Description of the medical condition to generate (e.g., 'big left-sided pleural effusion')",
    )
    height: int = Field(
        512,
        description="Height of generated image in pixels",
    )
    width: int = Field(
        512,
        description="Width of generated image in pixels",
    )
    num_inference_steps: int = Field(
        75,
        description="Number of denoising steps (higher = better quality but slower)",
    )
    guidance_scale: float = Field(
        4.0,
        description="How closely to follow the prompt (higher = more faithful but less diverse)",
    )


class ChestXRayGeneratorTool(BaseTool):
    """Tool for generating synthetic chest X-ray images using a fine-tuned Stable Diffusion model.

    Each call runs the loaded ``StableDiffusionPipeline`` on a single prompt,
    saves the first resulting image as a PNG inside ``temp_dir`` and returns
    ``(output, metadata)`` dictionaries. Failures are reported in-band as an
    ``{"error": ...}`` output plus ``"analysis_status": "failed"`` metadata
    instead of being raised.
    """

    name: str = "chest_xray_generator"
    description: str = (
        "Generates synthetic chest X-ray images from text descriptions of medical conditions. "
        "Input: Text description of the medical finding or condition to generate, "
        "along with optional parameters for image size (height, width), "
        "quality (num_inference_steps), and prompt adherence (guidance_scale). "
        "Output: Path to the generated X-ray image and generation metadata."
    )
    args_schema: Type[BaseModel] = ChestXRayGeneratorInput
    # Populated in __init__; None only before construction completes.
    model: Optional[StableDiffusionPipeline] = None
    device: Optional[torch.device] = None
    temp_dir: Optional[Path] = None

    def __init__(
        self,
        model_path: str = "/model-weights/roentgen",
        cache_dir: str = "/model-weights",
        temp_dir: Optional[str] = None,
        device: Optional[str] = "cuda",
    ):
        """Initialize the chest X-ray generator tool.

        Args:
            model_path: Model location passed to ``StableDiffusionPipeline.from_pretrained``.
            cache_dir: Cache directory forwarded to ``from_pretrained``.
            temp_dir: Directory where generated images are written. It is
                created (including any missing parents) if necessary; when
                omitted, a fresh ``tempfile.mkdtemp()`` directory is used.
            device: Torch device string used for inference. When falsy, CUDA
                is used if available, otherwise CPU.
        """
        super().__init__()

        # Always store a real torch.device. The previous fallback stored the
        # bare string "cuda", which contradicted the field type and failed on
        # hosts without a GPU.
        if device:
            self.device = torch.device(device)
        else:
            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        pipeline = StableDiffusionPipeline.from_pretrained(model_path, cache_dir=cache_dir)
        # Cast the weights to float32 before moving them onto the target device.
        self.model = pipeline.to(torch.float32).to(self.device)

        self.temp_dir = Path(temp_dir) if temp_dir else Path(tempfile.mkdtemp())
        # parents=True so a nested, not-yet-existing temp_dir does not raise.
        self.temp_dir.mkdir(parents=True, exist_ok=True)

    def _run(
        self,
        prompt: str,
        num_inference_steps: int = 75,
        guidance_scale: float = 4.0,
        height: int = 512,
        width: int = 512,
        run_manager: Optional[CallbackManagerForToolRun] = None,
    ) -> Tuple[Dict[str, str], Dict]:
        """Generate a chest X-ray image from a text description.

        Args:
            prompt: Text description of the medical condition to generate.
            num_inference_steps: Number of denoising steps.
            guidance_scale: How closely to follow the prompt.
            height: Height of generated image in pixels.
            width: Width of generated image in pixels.
            run_manager: Optional callback manager (unused here).

        Returns:
            Tuple[Dict, Dict]: On success, ``({"image_path": <png path>}, metadata)``,
            where metadata echoes the generation parameters plus the device and
            ``"analysis_status": "completed"``. On any exception,
            ``({"error": <message>}, {...,"analysis_status": "failed"})``;
            nothing is raised to the caller.
        """
        try:
            # The pipeline is called with a one-element prompt batch; only the
            # first returned image is kept.
            generation_output = self.model(
                [prompt],
                num_inference_steps=num_inference_steps,
                height=height,
                width=width,
                guidance_scale=guidance_scale,
            )

            # A short random suffix keeps successive outputs from overwriting each other.
            image_path = self.temp_dir / f"generated_xray_{uuid.uuid4().hex[:8]}.png"
            generation_output.images[0].save(image_path)

            output = {
                "image_path": str(image_path),
            }
            metadata = {
                "prompt": prompt,
                "num_inference_steps": num_inference_steps,
                "guidance_scale": guidance_scale,
                "device": str(self.device),
                "image_size": (height, width),
                "analysis_status": "completed",
            }
            return output, metadata

        except Exception as e:
            # Deliberate best-effort boundary: errors are returned to the agent
            # as data instead of aborting the tool call.
            return (
                {"error": str(e)},
                {
                    "prompt": prompt,
                    "analysis_status": "failed",
                    "error_details": str(e),
                },
            )

    async def _arun(
        self,
        prompt: str,
        num_inference_steps: int = 75,
        guidance_scale: float = 4.0,
        height: int = 512,
        width: int = 512,
        run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
    ) -> Tuple[Dict[str, str], Dict]:
        """Async version of _run.

        NOTE(review): this delegates synchronously, so image generation blocks
        the event loop for its whole duration. Offloading to a thread would
        also allow concurrent use of the shared pipeline; confirm that is safe
        before changing it.
        """
        # Keyword arguments so a future reordering of _run's parameters cannot
        # silently swap values. The async run_manager is not forwarded because
        # _run expects the synchronous manager type.
        return self._run(
            prompt,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
            height=height,
            width=width,
        )