File size: 2,660 Bytes
2974476
72ff248
 
a0ba541
2974476
 
 
832ce7b
2974476
2bf9d03
 
a0ba541
72ff248
3d49b4f
 
 
a0ba541
fa9231d
0b2a88c
2974476
72ff248
3d49b4f
 
 
 
2bf9d03
 
fa9231d
882bd69
72ff248
3d49b4f
2bf9d03
 
72ff248
9998c92
3d49b4f
a0ba541
2974476
 
 
 
 
 
 
 
 
 
 
 
 
 
a0ba541
 
 
 
 
 
 
2974476
a0ba541
2974476
 
 
 
 
a0ba541
2974476
 
 
 
4f5fa66
2974476
 
 
4f5fa66
3d49b4f
2974476
4f5fa66
2974476
 
 
 
4f5fa66
2629ae5
2974476
3d49b4f
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
import os
import torch
from transformers import AutoModel, AutoTokenizer, BitsAndBytesConfig
from PIL import Image
import gradio as gr
import base64
import io

# --- Model / tokenizer setup -------------------------------------------------
# Read the Hugging Face access token from the environment. Fail fast with a
# clear message: `os.getenv(...)` returns None when unset, and calling
# `.strip()` on None would otherwise crash with an opaque AttributeError.
api_token = os.getenv("HF_TOKEN")
if api_token is None:
    raise RuntimeError("HF_TOKEN environment variable is not set")
api_token = api_token.strip()

# Single source of truth for the checkpoint name (used by both loads below).
MODEL_ID = "ContactDoctor/Bio-Medical-MultiModal-Llama-3-8B-V1"

# 4-bit NF4 quantization with double quantization; compute in fp16.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.float16
)

# Load the multimodal model. `trust_remote_code=True` executes the repo's
# custom modeling code; `device_map="auto"` lets accelerate place weights.
model = AutoModel.from_pretrained(
    MODEL_ID,
    quantization_config=bnb_config,
    device_map="auto",
    torch_dtype=torch.float16,
    trust_remote_code=True,
    token=api_token
)

tokenizer = AutoTokenizer.from_pretrained(
    MODEL_ID,
    trust_remote_code=True,
    token=api_token
)

def analyze_input(image_data, question):
    """Answer a free-text medical question, optionally grounded in an image.

    Args:
        image_data: Either a ``data:image/...;base64,`` URI string, a numpy
            array (as delivered by ``gr.Image(type="numpy")``), or None.
        question: The user's question as a string.

    Returns:
        dict: ``{"status": "success", "response": <text>}`` on success, or
        ``{"status": "error", "message": <str(exc)>}`` on any failure —
        errors are reported in-band so the Gradio JSON output always renders.
    """
    try:
        # Decode a base64 data-URI into a PIL image if one was provided.
        if isinstance(image_data, str) and image_data.startswith('data:image'):
            # The payload follows the first comma; maxsplit=1 keeps any
            # further commas (none are expected in base64) intact.
            base64_data = image_data.split(',', 1)[1]
            image_bytes = base64.b64decode(base64_data)
            image = Image.open(io.BytesIO(image_bytes)).convert('RGB')
        # Direct numpy-array input from the Gradio image component.
        elif image_data is not None:
            image = Image.fromarray(image_data).convert('RGB')
        else:
            image = None

        # Tokenized tensors default to CPU, but the model is dispatched with
        # device_map="auto" — move inputs to the model's device so generate()
        # doesn't fail on a CPU-tensor/GPU-weight mismatch.
        if image is not None:
            # NOTE(review): prepare_inputs_for_generation is normally a
            # generate() internal; this relies on the repo's remote code
            # accepting `images=` here — confirm against the model card.
            inputs = model.prepare_inputs_for_generation(
                input_ids=tokenizer(question, return_tensors="pt")
                .input_ids.to(model.device),
                images=[image]
            )
        else:
            inputs = tokenizer(question, return_tensors="pt").to(model.device)

        outputs = model.generate(**inputs, max_new_tokens=256)
        response = tokenizer.decode(outputs[0], skip_special_tokens=True)

        return {
            "status": "success",
            "response": response
        }
    except Exception as e:
        # Boundary handler: surface the failure to the UI instead of a 500.
        return {
            "status": "error",
            "message": str(e)
        }

# Create Gradio interface: an optional image plus a free-text question are
# passed positionally to analyze_input(image_data, question); the dict it
# returns is rendered by the JSON output component.
demo = gr.Interface(
    fn=analyze_input,
    inputs=[
        gr.Image(type="numpy", label="Medical Image"),  # Removed optional parameter
        gr.Textbox(label="Question", placeholder="Enter your medical query...")
    ],
    outputs=gr.JSON(label="Analysis"),
    title="Bio-Medical MultiModal Analysis",
    description="Ask questions with or without an image",
    allow_flagging="never",  # NOTE(review): deprecated in newer Gradio (flagging_mode) — verify installed version
)

# Launch with API access enabled.
# `enable_queue=` was removed from launch() in Gradio 4.x and would raise a
# TypeError at startup; configuring the request queue via demo.queue() works
# on both the 3.x and 4.x lines.
demo.queue()
demo.launch(
    share=True,
    server_name="0.0.0.0",  # listen on all interfaces (container/Space use)
    server_port=7860,
)