File size: 2,871 Bytes
2974476
72ff248
e8eeeb2
a0ba541
2974476
 
 
832ce7b
2974476
2bf9d03
 
a0ba541
72ff248
3d49b4f
 
 
a0ba541
fa9231d
0b2a88c
c600b9f
e8eeeb2
3d49b4f
 
 
 
2bf9d03
5da7650
9da0a3e
fa9231d
882bd69
72ff248
3d49b4f
2bf9d03
5da7650
9da0a3e
72ff248
9998c92
3d49b4f
a0ba541
e8eeeb2
 
 
2974476
e8eeeb2
2974476
e8eeeb2
c600b9f
 
 
 
 
 
 
9da0a3e
e8eeeb2
c600b9f
 
 
 
 
 
 
9da0a3e
c600b9f
 
9da0a3e
 
e8eeeb2
a0ba541
2974476
e8eeeb2
 
 
 
2974476
 
 
 
a0ba541
2974476
 
 
 
4f5fa66
2974476
 
 
4f5fa66
e8eeeb2
2974476
4f5fa66
2974476
e8eeeb2
c600b9f
9da0a3e
4f5fa66
2629ae5
9da0a3e
3d49b4f
 
 
5da7650
3d49b4f
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
import os
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from PIL import Image
import gradio as gr
import base64
import io

# Hugging Face access token, read from the HF_TOKEN environment variable.
# An unset or whitespace-only value becomes None rather than crashing here
# with AttributeError on ``None.strip()``; ``from_pretrained(token=None)`` then
# falls back to any locally stored Hugging Face credentials.
api_token = (os.getenv("HF_TOKEN") or "").strip() or None

# bitsandbytes quantization settings for loading the model:
#   - weights stored in 4-bit using the NF4 quantization type,
#   - "double quant" also quantizes the quantization constants,
#   - compute dtype for the quantized layers is float16.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_use_double_quant=True,
)

# Hub repository shared by the model and tokenizer loaders.
MODEL_ID = "ContactDoctor/Bio-Medical-MultiModal-Llama-3-8B-V1"

# Repository revision to load. "main" is a moving branch, NOT a pin: because
# trust_remote_code=True executes Python code shipped in the repository, set
# MODEL_REVISION to a specific commit hash to lock both the weights and the
# remote code to a reviewed version. Defaults to "main" when unset or empty.
MODEL_REVISION = os.getenv("MODEL_REVISION") or "main"

# Quantized model, placed across available devices by accelerate.
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    quantization_config=bnb_config,
    device_map="auto",
    torch_dtype=torch.float16,
    trust_remote_code=True,
    token=api_token,
    revision=MODEL_REVISION,
)

# Tokenizer loaded from the same repository and revision as the model, so the
# two cannot drift apart.
tokenizer = AutoTokenizer.from_pretrained(
    MODEL_ID,
    trust_remote_code=True,
    token=api_token,
    revision=MODEL_REVISION,
)

def analyze_input(image_data, question):
    """Answer a medical question with the loaded model and report the outcome.

    Args:
        image_data: Value from the ``gr.Image(type="numpy")`` input, or ``None``
            when no image was uploaded. It currently only selects the prompt
            template; the pixels are not passed to the model.
        question: Free-text question from the Gradio textbox.

    Returns:
        ``{"status": "success", "response": <generated text>}`` on success, or
        ``{"status": "error", "message": <description>}`` when the question is
        empty or when tokenization/generation raises.
    """
    # Reject empty input up front instead of spending a generation call on it.
    if question is None or not str(question).strip():
        return {"status": "error", "message": "Please enter a question."}

    try:
        if image_data is not None:
            # NOTE(review): the image itself is never given to the model, so it
            # is asked to analyse an image it cannot see. Wire the pixels through
            # the model's image-input API before relying on this branch.
            prompt = f"Given the medical image and the question: {question}\nPlease provide a detailed analysis."
        else:
            prompt = f"Medical question: {question}\nAnswer: "

        encoded = tokenizer(prompt, return_tensors="pt").to(model.device)
        prompt_ids = encoded["input_ids"]

        # Llama-3 tokenizers often define no pad token; fall back to EOS so
        # generate() has an explicit pad id.
        pad_token_id = tokenizer.pad_token_id
        if pad_token_id is None:
            pad_token_id = tokenizer.eos_token_id

        # The prompt must reach generate() as ``input_ids`` (with its attention
        # mask); an arbitrary ``model_inputs=`` keyword is not a generate()
        # argument. Assumes the remote-code model keeps the standard
        # transformers generate() signature — TODO confirm.
        outputs = model.generate(
            input_ids=prompt_ids,
            attention_mask=encoded.get("attention_mask"),
            max_new_tokens=256,
            do_sample=True,
            temperature=0.7,
            top_p=0.9,
            pad_token_id=pad_token_id,
        )

        # A decoder-only generate() returns prompt + continuation. Strip the
        # prompt at token level (exact), rather than by string matching on
        # decoded text, which breaks when decoding normalises whitespace.
        generated = outputs[0]
        prompt_len = prompt_ids.shape[-1]
        if generated.shape[-1] >= prompt_len and torch.equal(
            generated[:prompt_len], prompt_ids[0]
        ):
            generated = generated[prompt_len:]
        response = tokenizer.decode(generated, skip_special_tokens=True).strip()

        return {
            "status": "success",
            "response": response,
        }
    except Exception as e:
        # UI boundary: report the failure to the user instead of crashing the app.
        return {
            "status": "error",
            "message": str(e),
        }

# Gradio front end: an optional image and a question textbox feed
# analyze_input, and its returned dict is displayed as JSON.
_image_input = gr.Image(type="numpy", label="Medical Image (Optional)")
_question_input = gr.Textbox(
    label="Question",
    placeholder="Enter your medical query...",
)

demo = gr.Interface(
    fn=analyze_input,
    inputs=[_image_input, _question_input],
    outputs=gr.JSON(label="Analysis"),
    title="Medical Query Analysis",
    description="Ask medical questions. For now, please focus on text-based queries without images.",
    flagging_mode="never",
)

# Start the web server only when this file is executed as a script (e.g.
# ``python app.py``). Importing the module to reuse ``demo`` or
# ``analyze_input`` therefore no longer starts, and blocks on, a server.
if __name__ == "__main__":
    demo.launch(
        share=True,  # also request a publicly shareable Gradio link
        server_name="0.0.0.0",  # listen on all network interfaces
        server_port=7860,
    )