File size: 3,058 Bytes
b582acf
 
 
 
8aff603
 
 
 
b582acf
b63f8db
c4d4c92
 
 
 
 
b582acf
 
8aff603
 
 
 
 
 
 
 
 
d605348
40b3e0b
d605348
 
 
 
 
40b3e0b
d605348
1a70836
b63f8db
 
 
 
 
 
 
0a459f9
1a70836
 
9a00230
9e4b9cd
6086049
9a00230
 
 
 
 
 
 
9e4b9cd
 
 
1a70836
9e4b9cd
641194a
 
 
 
 
9a00230
1a70836
56274b2
40b3e0b
 
 
 
 
 
 
556d95d
1a70836
5d29f36
1a70836
 
 
5d29f36
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
import gradio as gr
from PIL import Image
import io
import base64
from huggingface_hub import InferenceClient

# Module-level Hugging Face Inference client, created once at import time.
# respond() uses it to request chat completions from the model named here.
# NOTE(review): the "-delta" suffix may mean this repo holds delta weights
# rather than a directly servable checkpoint — confirm it is deployable.
client = InferenceClient("microsoft/llava-med-7b-delta")

# Function to encode image as base64
def image_to_base64(image):
    """Return *image* serialized as PNG and encoded as a base64 text string.

    ``image`` must expose a ``save(fp, format=...)`` method (as PIL images
    do); it is asked to write itself into an in-memory buffer as ``"PNG"``.
    The buffer contents are base64-encoded and decoded to ``str`` as UTF-8.
    """
    png_buffer = io.BytesIO()
    image.save(png_buffer, format="PNG")
    encoded_bytes = base64.b64encode(png_buffer.getvalue())
    return encoded_bytes.decode("utf-8")

# Function to interact with LLAVA model
def respond(
    message,
    history: list[tuple[str, str]],
    system_message,
    max_tokens,
    temperature,
    top_p,
    image=None
):
    """Send a chat request (optionally with images) and collect the reply.

    The prompt is assembled in this order: the system message, every past
    exchange from ``history``, one "Image uploaded" user message per supplied
    image, and finally ``message`` as the newest user turn.

    Args:
        message: Text of the new user turn.
        history: Earlier exchanges as ``(user_text, assistant_text)`` pairs.
            Falsy members of a pair are skipped; ``None`` means no history.
        system_message: Content of the leading system message.
        max_tokens: Forwarded to ``client.chat_completion``.
        temperature: Forwarded to ``client.chat_completion``.
        top_p: Forwarded to ``client.chat_completion``.
        image: Optional single PIL image, or an iterable of images.

    Returns:
        A list of strings: the streamed text fragments in arrival order, or
        a single-element list ``["Error: ..."]`` if the request raised.
    """
    messages = [{"role": "system", "content": system_message}]

    for user_text, assistant_text in history or []:
        if user_text:
            messages.append({"role": "user", "content": user_text})
        if assistant_text:
            messages.append({"role": "assistant", "content": assistant_text})

    if image:
        # Normalize to an iterable so single and multiple images share a path.
        images = [image] if isinstance(image, Image.Image) else image
        # NOTE(review): the extra "image" key is a custom message field; the
        # chat-completion API may expect image content in a different shape
        # (e.g. an image_url content part) — confirm against the endpoint.
        for img in images:
            messages.append(
                {
                    "role": "user",
                    "content": "Image uploaded",
                    "image": image_to_base64(img),
                }
            )

    messages.append({"role": "user", "content": message})

    try:
        stream = client.chat_completion(
            messages,
            max_tokens=max_tokens,
            stream=True,
            temperature=temperature,
            top_p=top_p,
        )
        tokens = []
        for chunk in stream:
            token = chunk.choices[0].delta.content
            # A stream chunk may carry no text; keep the result a list of str.
            if token is not None:
                tokens.append(token)
        return tokens
    except Exception as e:
        # Top-level boundary for the UI: report the failure as the reply text
        # instead of letting the request crash the handler.
        return [f"Error: {str(e)}"]

# Gradio UI wiring. Progress is printed to make startup problems visible.
def _respond_from_interface(
    image, message, system_message, max_tokens, temperature, top_p
):
    """Adapt the Interface's positional inputs to ``respond``.

    ``gr.Interface`` calls its function with the values of ``inputs`` followed
    by ``additional_inputs``, in declaration order. ``respond`` takes the
    text first, then a history, with the image last, so the values are
    reordered here and an empty history is supplied (the form is
    single-turn). ``respond`` returns a list of text fragments; they are
    joined into one string for the Textbox output.
    """
    fragments = respond(
        message, [], system_message, max_tokens, temperature, top_p, image=image
    )
    return "".join(fragment for fragment in fragments if fragment)


print("Starting Gradio interface setup...")
try:
    # Create a Gradio interface
    demo = gr.Interface(
        fn=_respond_from_interface,
        inputs=[
            gr.Image(label="Upload Medical Image", type="pil"),
            gr.Textbox(label="Message")
        ],
        outputs=gr.Textbox(label="Response", placeholder="Model response will appear here..."),
        title="LLAVA Model - Medical Image and Question",
        description="Upload a medical image and ask a specific question about the image for a medical description.",
        additional_inputs=[
            gr.Textbox(label="System message", value="You are a friendly Chatbot."),
            gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
            gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
            gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p (nucleus sampling)")
        ]
    )

    # Launch the Gradio interface only when run as a script, not on import.
    if __name__ == "__main__":
        print("Launching Gradio interface...")
        demo.launch()

except Exception as e:
    # Best-effort startup: report the failure rather than raising.
    print(f"Error during Gradio setup: {str(e)}")