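"""LangChain tool that wraps the StanfordAIMI/CheXagent-2-3b vision-language
model for chest X-ray visual question answering, report generation, and
related analysis tasks."""
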
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple, Type

import torch
import transformers
from langchain_core.callbacks import (
    AsyncCallbackManagerForToolRun,
    CallbackManagerForToolRun,
)
from langchain_core.tools import BaseTool
from pydantic import BaseModel, Field
from transformers import AutoModelForCausalLM, AutoTokenizer


class XRayVQAToolInput(BaseModel):
    """Input schema for the CheXagent tool."""

    image_paths: List[str] = Field(
        ..., description="List of paths to chest X-ray images to analyze"
    )
    prompt: str = Field(
        ..., description="Question or instruction about the chest X-ray images"
    )
    max_new_tokens: int = Field(
        512, description="Maximum number of tokens to generate in the response"
    )


class XRayVQATool(BaseTool):
    """Tool that leverages CheXagent for comprehensive chest X-ray analysis."""

    name: str = "chest_xray_expert"
    description: str = (
        "A versatile tool for analyzing chest X-rays. "
        "Can perform multiple tasks including: visual question answering, report generation, "
        "abnormality detection, comparative analysis, anatomical description, "
        "and clinical interpretation. Input should be paths to X-ray images "
        "and a natural language prompt describing the analysis needed."
    )
    args_schema: Type[BaseModel] = XRayVQAToolInput
    return_direct: bool = True

    # Runtime configuration and model handles. These are declared as fields
    # (BaseTool is a pydantic model) so they can be assigned in __init__.
    cache_dir: Optional[str] = None
    device: Optional[str] = None
    dtype: torch.dtype = torch.bfloat16
    tokenizer: Optional[AutoTokenizer] = None
    model: Optional[AutoModelForCausalLM] = None

    def __init__(
        self,
        model_name: str = "StanfordAIMI/CheXagent-2-3b",
        device: Optional[str] = "cuda",
        dtype: torch.dtype = torch.bfloat16,
        cache_dir: Optional[str] = None,
        **kwargs: Any,
    ) -> None:
        """Initialize the XRayVQATool.

        Args:
            model_name: Name of the CheXagent model to use
            device: Device to run the model on (cuda/cpu)
            dtype: Data type for model weights
            cache_dir: Directory to cache downloaded models
            **kwargs: Additional arguments passed to BaseTool
        """
        super().__init__(**kwargs)

        # Dangerous workaround: the checkpoint's remote code appears to expect
        # an older transformers release, so spoof the version while loading and
        # restore it afterwards (even if loading fails).
        original_transformers_version = transformers.__version__
        transformers.__version__ = "4.40.0"
        try:
            self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
            self.dtype = dtype
            self.cache_dir = cache_dir

            # Load tokenizer and model with the checkpoint's custom remote code
            self.tokenizer = AutoTokenizer.from_pretrained(
                model_name,
                trust_remote_code=True,
                cache_dir=cache_dir,
            )
            self.model = AutoModelForCausalLM.from_pretrained(
                model_name,
                device_map=self.device,
                trust_remote_code=True,
                cache_dir=cache_dir,
            )
            self.model = self.model.to(dtype=self.dtype)
            self.model.eval()
        finally:
            transformers.__version__ = original_transformers_version

    def _generate_response(self, image_paths: List[str], prompt: str, max_new_tokens: int) -> str:
        """Generate a response using the CheXagent model.

        Args:
            image_paths: List of paths to chest X-ray images
            prompt: Question or instruction about the images
            max_new_tokens: Maximum number of tokens to generate

        Returns:
            str: The model's response
        """
        # Build a multimodal query: one image entry per path, followed by the
        # text prompt, using the remote tokenizer's list format.
        query = self.tokenizer.from_list_format(
            [*[{"image": path} for path in image_paths], {"text": prompt}]
        )
        conv = [
            {"from": "system", "value": "You are a helpful assistant."},
            {"from": "human", "value": query},
        ]
        input_ids = self.tokenizer.apply_chat_template(
            conv, add_generation_prompt=True, return_tensors="pt"
        ).to(device=self.device)

        # Run deterministic (greedy) decoding
        with torch.inference_mode():
            output = self.model.generate(
                input_ids,
                do_sample=False,
                num_beams=1,
                temperature=1.0,
                top_p=1.0,
                use_cache=True,
                max_new_tokens=max_new_tokens,
            )[0]

        # Skip the prompt tokens and drop the trailing end-of-sequence token
        response = self.tokenizer.decode(output[input_ids.size(1) : -1])
        return response

    def _run(
        self,
        image_paths: List[str],
        prompt: str,
        max_new_tokens: int = 512,
        run_manager: Optional[CallbackManagerForToolRun] = None,
    ) -> Tuple[Dict[str, Any], Dict]:
        """Execute the chest X-ray analysis.

        Args:
            image_paths: List of paths to chest X-ray images
            prompt: Question or instruction about the images
            max_new_tokens: Maximum number of tokens to generate
            run_manager: Optional callback manager

        Returns:
            Tuple[Dict[str, Any], Dict]: Output dictionary and metadata dictionary
        """
        try:
            # Verify that every image path exists before running inference
            for path in image_paths:
                if not Path(path).is_file():
                    raise FileNotFoundError(f"Image file not found: {path}")

            response = self._generate_response(image_paths, prompt, max_new_tokens)
            output = {
                "response": response,
            }
            metadata = {
                "image_paths": image_paths,
                "prompt": prompt,
                "max_new_tokens": max_new_tokens,
                "analysis_status": "completed",
            }
            return output, metadata
        except Exception as e:
            # Surface the failure in-band rather than raising, so callers can
            # inspect the error in the tool output
            output = {"error": str(e)}
            metadata = {
                "image_paths": image_paths,
                "prompt": prompt,
                "max_new_tokens": max_new_tokens,
                "analysis_status": "failed",
                "error_details": str(e),
            }
            return output, metadata

    async def _arun(
        self,
        image_paths: List[str],
        prompt: str,
        max_new_tokens: int = 512,
        run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
    ) -> Tuple[Dict[str, Any], Dict]:
        """Async version of _run; delegates to the synchronous implementation."""
        return self._run(image_paths, prompt, max_new_tokens)
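

# A minimal usage sketch (not part of the tool itself). The image path and
# prompt below are placeholders; this assumes a machine that can download and
# load the CheXagent checkpoint from the Hugging Face Hub. The tool can also
# be driven through LangChain's standard tool-invocation interface.
if __name__ == "__main__":
    tool = XRayVQATool(device="cuda" if torch.cuda.is_available() else "cpu")
    output, metadata = tool._run(
        image_paths=["/path/to/chest_xray.png"],  # placeholder path
        prompt="Is there evidence of pleural effusion?",
        max_new_tokens=256,
    )
    print(output.get("response") or output.get("error"))
    print(f"Status: {metadata['analysis_status']}")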