File size: 2,816 Bytes
2974476
72ff248
2bdc9ef
2974476
65272a9
 
832ce7b
2974476
2bf9d03
 
a0ba541
72ff248
3d49b4f
 
 
a0ba541
fa9231d
0b2a88c
2bdc9ef
 
 
eae216c
 
 
2bf9d03
2bdc9ef
 
72ff248
9998c92
2bdc9ef
 
 
 
 
65272a9
9698346
a0ba541
c9c43bc
2bdc9ef
9698346
eae216c
2bdc9ef
 
6a9197e
2bdc9ef
 
 
 
 
 
 
 
8327db6
eae216c
2bdc9ef
 
 
 
 
 
 
eae216c
a0ba541
6a9197e
8327db6
 
9698346
4f5fa66
2974476
 
 
4f5fa66
2bdc9ef
 
 
 
 
 
4f5fa66
2974476
2bdc9ef
 
4f5fa66
2629ae5
9698346
6a9197e
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
import os
import torch
from transformers import AutoModel, AutoTokenizer, BitsAndBytesConfig
import gradio as gr
from PIL import Image
from torchvision.transforms import ToTensor

# Get API token from environment variable
api_token = os.getenv("HF_TOKEN").strip()

# Load the weights in 4-bit NF4 with nested ("double") quantization to reduce
# memory use; the quantized matmuls are computed in float16.
_quant_kwargs = dict(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.float16,
)
bnb_config = BitsAndBytesConfig(**_quant_kwargs)

# Hub repository shared by the model and its tokenizer.
_MODEL_ID = "ContactDoctor/Bio-Medical-MultiModal-Llama-3-8B-V1"

# trust_remote_code lets transformers run the modelling code shipped in the
# repository (presumably where `model.chat` used below is defined).
# NOTE(review): attn_implementation="flash_attention_2" requires the
# flash-attn package and a supported GPU; loading fails without them.
model = AutoModel.from_pretrained(
    _MODEL_ID,
    token=api_token,
    trust_remote_code=True,
    quantization_config=bnb_config,
    torch_dtype=torch.float16,
    device_map="auto",
    attn_implementation="flash_attention_2",
)

tokenizer = AutoTokenizer.from_pretrained(
    _MODEL_ID,
    token=api_token,
    trust_remote_code=True,
)

def analyze_input(image, question):
    """Run the multimodal model on an optional image and a text question.

    Args:
        image: PIL image from the Gradio ``gr.Image(type="pil")`` input, or
            None when nothing was uploaded (a text-only query).
        question: The user's question from the Gradio textbox.

    Returns:
        ``{"status": "success", "response": <generated text>}`` on success, or
        ``{"status": "error", "message": <reason>}`` when the question is
        missing/blank or generation raises. Errors are returned rather than
        raised so the Gradio JSON output can display them.
    """
    # Reject blank prompts up front instead of spending a generation on them.
    if not isinstance(question, str) or not question.strip():
        return {"status": "error", "message": "Please enter a question."}
    question = question.strip()

    try:
        if image is not None:
            # Normalize uploads (RGBA, palette, grayscale, ...) to 3-channel RGB.
            image = image.convert('RGB')

        # Only put the image in the message when one was provided, so the
        # content list never carries a None entry.
        content = [image, question] if image is not None else [question]
        msgs = [{'role': 'user', 'content': content}]

        # Generate response using the model's chat method in streaming mode.
        response_stream = model.chat(
            image=image,
            msgs=msgs,
            tokenizer=tokenizer,
            sampling=True,
            temperature=0.95,
            stream=True,
        )

        # Echo tokens to the server console as they arrive and collect them.
        chunks = []
        for new_text in response_stream:
            chunks.append(new_text)
            print(new_text, flush=True, end='')
        print(flush=True)  # terminate the streamed console line

        return {"status": "success", "response": "".join(chunks)}

    except Exception as e:
        # UI boundary: log the full traceback server-side, but hand the UI a
        # structured error instead of letting the exception escape.
        import traceback
        error_trace = traceback.format_exc()
        print(f"Error occurred: {error_trace}")
        return {"status": "error", "message": str(e)}

# Prompt used both as the textbox placeholder and as its pre-filled value.
_DEFAULT_QUESTION = (
    "Give the modality, organ, analysis, abnormalities (if any), "
    "treatment (if abnormalities are present)?"
)

# Input widgets, in the positional order analyze_input(image, question) expects.
_image_input = gr.Image(type="pil", label="Upload Medical Image")
_question_input = gr.Textbox(
    label="Medical Question",
    placeholder=_DEFAULT_QUESTION,
    value=_DEFAULT_QUESTION,
)

# Wire the inputs to analyze_input and render its result dict as JSON.
demo = gr.Interface(
    fn=analyze_input,
    inputs=[_image_input, _question_input],
    outputs=gr.JSON(label="Analysis"),
    title="Medical Image Analysis Assistant",
    description=(
        "Upload a medical image and ask questions about it. "
        "The AI will analyze the image and provide detailed responses."
    ),
)

# Start the web server only when executed as a script, not on import.
if __name__ == "__main__":
    # NOTE(review): share=True additionally requests a public gradio.live
    # link, exposing this app to anyone with the URL - confirm intended.
    launch_options = {
        "share": True,
        "server_name": "0.0.0.0",  # listen on all network interfaces
        "server_port": 7860,
    }
    demo.launch(**launch_options)