File size: 6,258 Bytes
6079c6e
d2c3421
b35805b
9a97411
 
6079c6e
9a97411
 
6079c6e
9a97411
 
8cd3c65
9a97411
 
6079c6e
9a97411
29cb53e
 
 
9a97411
6079c6e
9a97411
 
35446dd
9a97411
 
 
6079c6e
9a97411
 
 
 
 
 
 
 
 
 
 
 
 
 
25ef0fa
b78d721
 
f804d88
9a97411
 
 
 
 
ae66ad0
9a97411
8cd3c65
 
 
 
9a97411
25ef0fa
8cd3c65
9a97411
 
8cd3c65
9a97411
 
25ef0fa
8cd3c65
25ef0fa
 
9a97411
25ef0fa
8cd3c65
 
f804d88
25ef0fa
9a97411
 
f804d88
 
 
 
 
25ef0fa
9a97411
25ef0fa
 
9a97411
b35805b
25ef0fa
9a97411
25ef0fa
d2c3421
f804d88
d2c3421
9a97411
d2c3421
 
 
 
f804d88
d2c3421
 
9a97411
25ef0fa
d2c3421
 
25ef0fa
 
 
9a97411
25ef0fa
 
f804d88
25ef0fa
8cd3c65
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9a97411
8cd3c65
9a97411
8cd3c65
 
 
9a97411
8cd3c65
 
9a97411
8cd3c65
 
 
 
 
 
 
 
 
25ef0fa
f804d88
9a97411
29cb53e
 
 
 
 
f804d88
9a97411
f804d88
9a97411
 
f804d88
9a97411
29cb53e
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
import gradio as gr
import os
import spaces
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
from threading import Thread

# Optional Hugging Face access token, read from the environment (None when unset).
HF_TOKEN = os.getenv("HF_TOKEN")

# Page heading rendered above the chat widget.
DESCRIPTION = (
    "\n"
    "<div>\n"
    '  <h1 style="text-align: center;">A.I. Healthcare</h1>\n'
    "</div>\n"
)

# Medical disclaimer rendered below the chat widget.
LICENSE = """
<p>
This Health Assistant is designed to provide helpful healthcare information; however, it may make mistakes and is not designed to replace professional medical care. It is not intended to diagnose any condition or disease. Always consult with a qualified healthcare provider for any medical concerns.
</p>
"""

# Markup used as the chatbot placeholder.
PLACEHOLDER = """
<div style="padding: 30px; text-align: center; display: flex; flex-direction: column; align-items: center;">
   <h1 style="font-size: 28px; margin-bottom: 2px; opacity: 0.55;">A.I. Healthcare</h1>
   <p style="font-size: 18px; margin-bottom: 2px; opacity: 0.65;">Ask me anything...</p>
</div>
"""

# Stylesheet handed to the Gradio Blocks container.
css = (
    "\n"
    "h1 {\n"
    "  text-align: center;\n"
    "  display: block;\n"
    "}\n"
    "\n"
    "#duplicate-button {\n"
    "  margin: auto;\n"
    "  color: white;\n"
    "  background: #1565c0;\n"
    "  border-radius: 100vh;\n"
    "}\n"
)

# Hub repository that provides both the tokenizer and the model weights.
# Kept in one place so the two loads below cannot drift apart.
MODEL_ID = "reedmayhew/HealthCare-Reasoning-Assistant-Llama-3.1-8B-HF"

# The tokenizer holds no device-resident weights, so no device placement is
# requested for it; only the model is mapped onto the GPU.
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForCausalLM.from_pretrained(MODEL_ID, device_map="cuda")

# Token ids treated as end-of-generation markers: the tokenizer's regular EOS
# token plus the id of the "<|eot_id|>" token.
terminators = [
    tokenizer.eos_token_id,
    tokenizer.convert_tokens_to_ids("<|eot_id|>"),
]

def _build_conversation(message: str, history: list) -> list:
    """Turn the chat history plus the new user message into role/content dicts.

    Accepts history entries either as ``(user, assistant)`` pairs or as
    ``{"role": ..., "content": ...}`` dicts, so both Gradio history formats work.
    """
    conversation = []
    for turn in history or []:
        if isinstance(turn, dict):
            # Already in role/content form; copy only the keys the template needs.
            conversation.append({"role": turn["role"], "content": turn["content"]})
            continue
        user, assistant = turn
        # A pair may carry None for a missing side; skip it rather than
        # feeding a None content into the chat template.
        if user is not None:
            conversation.append({"role": "user", "content": user})
        if assistant is not None:
            conversation.append({"role": "assistant", "content": assistant})
    conversation.append({"role": "user", "content": message})
    return conversation


@spaces.GPU(duration=60)
def chat_llama3_8b(message: str,
                   history: list,
                   temperature: float,
                   max_new_tokens: int,
                   confirm: bool) -> str:
    """
    Generate a streaming response using the Healthcare-Reasoning-Assistant-Llama-3.1-8B-HF model.

    This function is a generator: every yielded value is the full response
    accumulated so far, so a consumer that replaces the displayed message with
    each yielded value shows the text growing as tokens arrive. (The ``str``
    return annotation is kept unchanged; it describes each yielded item.)

    Args:
        message (str): The input message.
        history (list): The conversation history, as (user, assistant) pairs
            or role/content dicts. It is read but never modified.
        temperature (float): The sampling temperature; 0 selects greedy decoding.
        max_new_tokens (int): The maximum number of new tokens to generate.
        confirm (bool): Whether the user has confirmed the usage disclaimer.

    Yields:
        str: The response generated so far, or a single warning message when
        the disclaimer has not been confirmed.
    """
    # Ensure the user has confirmed the disclaimer. Because this function is a
    # generator, the warning has to be yielded; a returned value would never
    # reach the caller iterating over it.
    if not confirm:
        yield "⚠️ You must confirm that you meet the usage requirements before sending a message."
        return

    conversation = _build_conversation(message, history)

    # add_generation_prompt=True makes the template end with the assistant
    # header so the model answers instead of continuing the user's turn.
    input_ids = tokenizer.apply_chat_template(
        conversation,
        add_generation_prompt=True,
        return_tensors="pt",
    ).to(model.device)

    # The streamer hands decoded text chunks from the generation thread to this one.
    streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)

    generate_kwargs = dict(
        input_ids=input_ids,
        streamer=streamer,
        max_new_tokens=int(max_new_tokens),
        eos_token_id=terminators,
    )
    if temperature == 0:
        # Greedy decoding; temperature is meaningless here, so it is not passed.
        generate_kwargs["do_sample"] = False
    else:
        generate_kwargs["do_sample"] = True
        generate_kwargs["temperature"] = temperature

    # Run generation in a separate thread so tokens can be consumed as they arrive.
    worker = Thread(target=model.generate, kwargs=generate_kwargs)
    worker.start()

    full_response = ""
    for text in streamer:
        full_response += text
        # Yield the accumulated text, not just the newest chunk.
        yield full_response

    # The streamer is exhausted, so generation has finished; reap the thread.
    worker.join()

# Custom JavaScript intended to disable the send button until confirmation is given.
# Once the DOM is loaded it polls every 500 ms until both the confirmation
# checkbox and the send button exist, mirrors the checkbox state onto the
# button's `disabled` property, registers a `change` listener to keep them in
# sync, and then stops polling.
# NOTE(review): this is a client-side convenience only; it is not a substitute
# for checking the confirmation on the server.
# NOTE(review): the selectors assume the checkbox exposes its label through
# `aria-label` and that the send button carries title="Send" -- confirm against
# the markup produced by the installed Gradio version.
# NOTE(review): <script> tags inserted through an HTML component may not be
# executed by the browser, and DOMContentLoaded may already have fired by the
# time this markup is injected -- verify that this script actually runs.
CUSTOM_JS = """
<script>
document.addEventListener("DOMContentLoaded", function() {
    const interval = setInterval(() => {
        const checkbox = document.querySelector('input[type="checkbox"][aria-label*="I hereby confirm that I am at least 18 years of age"]');
        const sendButton = document.querySelector('button[title="Send"]');
        if (checkbox && sendButton) {
            sendButton.disabled = !checkbox.checked;
            checkbox.addEventListener('change', function() {
                sendButton.disabled = !checkbox.checked;
            });
            clearInterval(interval);
        }
    }, 500);
});
</script>
"""

# Build the page: heading, helper script, chat interface, and disclaimer footer.
with gr.Blocks(css=css, title="A.I. Healthcare") as demo:
    gr.Markdown(DESCRIPTION)
    gr.HTML(CUSTOM_JS)

    chat_interface = gr.ChatInterface(
        fn=chat_llama3_8b,
        title="A.I. Healthcare Chat",
        chatbot=gr.Chatbot(height=450, placeholder=PLACEHOLDER, label='Conversation'),
        # Additional inputs are passed to `fn` positionally after
        # (message, history), so their order must match the callback signature:
        # temperature, max_new_tokens, confirm. Listing the checkbox first would
        # feed its boolean into `temperature` and the max-token count into
        # `confirm`, silently bypassing the disclaimer check.
        additional_inputs=[
            gr.Slider(minimum=0.6, maximum=0.6, step=0.1, value=0.6, label="Temperature", visible=False),
            gr.Slider(minimum=128, maximum=4096, step=64, value=1024, label="Max new tokens", visible=False),
            gr.Checkbox(
                value=False,
                label=("I hereby confirm that I am at least 18 years of age (or accompanied by a legal guardian "
                       "who is at least 18 years old), understand that the information provided by this service "
                       "is for informational purposes only and is not intended to diagnose or treat any medical condition, "
                       "and acknowledge that I am solely responsible for verifying any information provided."),
                elem_id="age_confirm_checkbox"
            ),
        ],
        examples=[
            ['What are the common symptoms of diabetes?'],
            ['How can I manage high blood pressure with lifestyle changes?'],
            ['What nutritional advice can help improve heart health?'],
            ['Can you explain the benefits of regular exercise for mental well-being?'],
            ['What should I know about the side effects of common medications?']
        ],
        cache_examples=False,
    )

    gr.Markdown(LICENSE)

# Start the web server only when executed as a script, not when imported.
if __name__ == "__main__":
    demo.launch()