"""Gradio front-end for the Bio-Medical-MultiModal-Llama-3-8B-V1 model.

Loads the model 4-bit quantized, then serves a simple image + question
interface. The output is a JSON status dict.
"""
import os
import traceback

import gradio as gr
import torch
from PIL import Image  # NOTE(review): currently unused in this file
from torchvision.transforms import ToTensor  # NOTE(review): currently unused in this file
from transformers import AutoModel, AutoTokenizer, BitsAndBytesConfig

MODEL_ID = "ContactDoctor/Bio-Medical-MultiModal-Llama-3-8B-V1"

DEFAULT_QUESTION = (
    "Give the modality, organ, analysis, abnormalities (if any), "
    "treatment (if abnormalities are present)?"
)

# Read the Hugging Face token from the environment. os.getenv returns None
# when HF_TOKEN is unset; calling .strip() on None would crash at import.
# Fall back to None so from_pretrained can use any locally cached login.
api_token = (os.getenv("HF_TOKEN") or "").strip() or None

# 4-bit NF4 quantization with double quantization; matmuls run in fp16.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.float16,
)

# Initialize model and tokenizer. trust_remote_code is required because the
# repository ships its own modeling code (which provides `model.chat`).
# NOTE(review): flash_attention_2 requires the flash-attn package and a
# supported GPU; loading fails without them.
model = AutoModel.from_pretrained(
    MODEL_ID,
    quantization_config=bnb_config,
    device_map="auto",
    torch_dtype=torch.float16,
    trust_remote_code=True,
    attn_implementation="flash_attention_2",
    token=api_token,
)
tokenizer = AutoTokenizer.from_pretrained(
    MODEL_ID,
    trust_remote_code=True,
    token=api_token,
)


def analyze_input(image, question):
    """Ask the model a question, optionally about an uploaded image.

    Args:
        image: PIL image from the ``gr.Image(type="pil")`` input, or ``None``
            when the user did not upload one.
        question: Text from the question textbox.

    Returns:
        ``{"status": "success", "response": <generated text>}`` on success.
        ``{"status": "error", "message": <reason>}`` if the question is empty
        or if anything raises during inference.
    """
    # Guard against an empty or whitespace-only prompt before touching the model.
    if not question or not question.strip():
        return {"status": "error", "message": "Please enter a question."}

    try:
        # Only include the image in the message content when one was
        # provided. Otherwise a None element would be passed to the model.
        content = [question]
        if image is not None:
            image = image.convert("RGB")
            content.insert(0, image)

        msgs = [{"role": "user", "content": content}]

        # model.chat comes from the repository's remote code; with
        # stream=True the loop below consumes it as an iterable of text chunks.
        response_stream = model.chat(
            image=image,
            msgs=msgs,
            tokenizer=tokenizer,
            sampling=True,
            temperature=0.95,
            stream=True,
        )

        # Collect chunks in a list and join once, instead of quadratic +=.
        # Each chunk is also echoed to stdout as it arrives.
        chunks = []
        for new_text in response_stream:
            chunks.append(new_text)
            print(new_text, flush=True, end="")
        print()  # terminate the streamed console line

        return {"status": "success", "response": "".join(chunks)}
    except Exception as e:
        # Top-level UI boundary: log the full trace, return the message to the UI.
        print(f"Error occurred: {traceback.format_exc()}")
        return {"status": "error", "message": str(e)}


# Create Gradio interface
demo = gr.Interface(
    fn=analyze_input,
    inputs=[
        gr.Image(type="pil", label="Upload Medical Image"),
        gr.Textbox(
            label="Medical Question",
            placeholder=DEFAULT_QUESTION,
            value=DEFAULT_QUESTION,
        ),
    ],
    outputs=gr.JSON(label="Analysis"),
    title="Medical Image Analysis Assistant",
    description=(
        "Upload a medical image and ask questions about it. "
        "The AI will analyze the image and provide detailed responses."
    ),
)

# Launch the Gradio app
if __name__ == "__main__":
    # NOTE(review): share=True publishes a public gradio.live URL, and
    # 0.0.0.0 binds all interfaces. Confirm this is intended for medical images.
    demo.launch(
        share=True,
        server_name="0.0.0.0",
        server_port=7860,
    )