import os
import torch
from transformers import AutoModel, AutoTokenizer, BitsAndBytesConfig
import gradio as gr
from PIL import Image

# First, let's check if flash-attn is installed
try:
    import flash_attn
    FLASH_ATTN_AVAILABLE = True
except ImportError:
    FLASH_ATTN_AVAILABLE = False
    print("Flash Attention is not installed. Using default attention mechanism.")
    print("To install Flash Attention, run: pip install flash-attn --no-build-isolation")

# Get the API token from the environment; fail fast with a clear message if it
# is unset (calling .strip() directly on os.getenv() crashes when the variable
# is missing, since os.getenv returns None)
api_token = os.getenv("HF_TOKEN")
if not api_token:
    raise EnvironmentError("HF_TOKEN is not set; a Hugging Face token is required to download this model.")
api_token = api_token.strip()

# Quantization configuration
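# (4-bit NF4 with double quantization keeps the 8B model's weights at roughly
# a quarter of the ~16 GB a plain float16 load would need -- rough estimate,
# actual usage depends on activation overhead)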
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.float16
)

# Initialize model with conditional Flash Attention
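# (device_map="auto" relies on the `accelerate` package to place layers across
# the available GPU(s)/CPU, so it must be installed alongside bitsandbytes)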
model_args = {
    "quantization_config": bnb_config,
    "device_map": "auto",
    "torch_dtype": torch.float16,
    "trust_remote_code": True,
    "token": api_token
}

# Only add flash attention if available
if FLASH_ATTN_AVAILABLE:
    model_args["attn_implementation"] = "flash_attention_2"

# Initialize model and tokenizer
model = AutoModel.from_pretrained(
    "ContactDoctor/Bio-Medical-MultiModal-Llama-3-8B-V1",
    **model_args
)

tokenizer = AutoTokenizer.from_pretrained(
    "ContactDoctor/Bio-Medical-MultiModal-Llama-3-8B-V1",
    trust_remote_code=True,
    token=api_token
)

def analyze_input(image, question):
    try:
        if image is not None:
            # Convert to RGB so the model always receives a 3-channel image
            image = image.convert('RGB')

        # Prepare messages in the format expected by the model's chat interface;
        # drop the image from the content list when none was uploaded, so we
        # never pass None inside the message content
        content = [image, question] if image is not None else [question]
        msgs = [{'role': 'user', 'content': content}]
        
        # Generate response using the chat method
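        # (model.chat is the MiniCPM-V-style interface shipped with this
        # checkpoint's remote code via trust_remote_code; the keyword signature
        # used here is assumed from the model card's streaming example)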
        response_stream = model.chat(
            image=image,
            msgs=msgs,
            tokenizer=tokenizer,
            sampling=True,
            temperature=0.95,
            stream=True
        )
        
        # Collect the streamed response
        generated_text = ""
        for new_text in response_stream:
            generated_text += new_text
            print(new_text, flush=True, end='')
        
        return {"status": "success", "response": generated_text}
    
    except Exception as e:
        import traceback
        error_trace = traceback.format_exc()
        print(f"Error occurred: {error_trace}")
        return {"status": "error", "message": str(e)}

# Create Gradio interface
demo = gr.Interface(
    fn=analyze_input,
    inputs=[
        gr.Image(type="pil", label="Upload Medical Image"),
        gr.Textbox(
            label="Medical Question",
            placeholder="Ask about the image, e.g. its modality, the organ shown, or any findings",
            value="Describe the modality, organ, and findings. Note any abnormalities and, if present, suggest a treatment."
        )
    ],
    outputs=gr.JSON(label="Analysis"),
    title="Medical Image Analysis Assistant",
    description="Upload a medical image and ask questions about it. The AI will analyze the image and provide detailed responses."
)
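
# Quick smoke test without the UI (hypothetical file name; assumes a local
# image exists at that path):
#
#     from PIL import Image
#     result = analyze_input(Image.open("sample_xray.png"), "What modality is this?")
#     print(result)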

# Launch the Gradio app
if __name__ == "__main__":
    # Print installation instructions if Flash Attention is not available
    if not FLASH_ATTN_AVAILABLE:
        print("\nTo enable Flash Attention 2 for better performance, please install it using:")
        print("pip install flash-attn --no-build-isolation")
    
    demo.launch(
        share=True,
        server_name="0.0.0.0",
        server_port=7860
    )