reedmayhew committed
Commit f98d1cf · verified · 1 Parent(s): a14d521

Update app.py

Files changed (1): app.py (+52, -67)
app.py CHANGED
```diff
@@ -9,7 +9,7 @@ HF_TOKEN = os.environ.get("HF_TOKEN", None)
 
 DESCRIPTION = '''
 <div>
-<h1 style="text-align: center;">A.I. Healthcare</h1>
+<h1 style="text-align: center;">A.I. Healthcare</h1>
 </div>
 '''
 
@@ -40,7 +40,7 @@ h1 {
 }
 """
 
-# Load the tokenizer and model with the updated model name
+# Load the tokenizer and model
 tokenizer = AutoTokenizer.from_pretrained("reedmayhew/HealthCare-Reasoning-Assistant-Llama-3.1-8B-HF", device_map="cuda")
 model = AutoModelForCausalLM.from_pretrained("reedmayhew/HealthCare-Reasoning-Assistant-Llama-3.1-8B-HF", device_map="cuda")
 
@@ -51,28 +51,21 @@ terminators = [
 
 @spaces.GPU(duration=60)
 def chat_llama3_8b(message: str,
-                   history: list,
-                   temperature: float,
-                   max_new_tokens: int,
-                   confirm: bool) -> str:
+                   history: list,
+                   temperature: float,
+                   max_new_tokens: int
+                   ) -> str:
     """
-    Generate a streaming response using the Healthcare-Reasoning-Assistant-Llama-3.1-8B-HF model.
-
+    Generate a streaming response using the llama3-8b model.
     Args:
         message (str): The input message.
-        history (list): The conversation history.
+        history (list): The conversation history used by ChatInterface.
         temperature (float): The temperature for generating the response.
         max_new_tokens (int): The maximum number of new tokens to generate.
-        confirm (bool): Whether the user has confirmed the usage disclaimer.
-
-    Yields:
-        str: The generated response, streamed token-by-token.
+    Returns:
+        str: The generated response.
     """
-    # Ensure the user has confirmed the disclaimer
-    if not confirm:
-        return "⚠️ You must confirm that you meet the usage requirements before sending a message."
 
-    # Prepare the conversation history for the model input
     conversation = []
     for user, assistant in history:
         conversation.extend([
@@ -80,15 +73,14 @@ def chat_llama3_8b(message: str,
             {"role": "assistant", "content": assistant}
         ])
 
-    # Append the current user message
+    # Ensure the model starts with "<think>"
     conversation.append({"role": "user", "content": message})
-
-    # Convert the conversation into input ids using the chat template
+    conversation.append({"role": "assistant", "content": "<think> "})  # Force <think> at start
+
     input_ids = tokenizer.apply_chat_template(conversation, return_tensors="pt").to(model.device)
 
-    # Set up the streamer to stream text output
     streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
-
+
     generate_kwargs = dict(
         input_ids=input_ids,
         streamer=streamer,
@@ -101,57 +93,50 @@
     if temperature == 0:
         generate_kwargs['do_sample'] = False
 
-    # Launch the generation in a separate thread
     t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()
-
-    full_response = ""
-    # Simply stream each token as it comes from the model
+
+    outputs = []
+    buffer = ""
+    think_detected = False
+    thinking_message_sent = False
+    full_response = ""  # Store the full assistant response
+
     for text in streamer:
-        full_response += text
-        yield text
-
-    # Save the full response (for context in the conversation history)
-    history.append((message, full_response))
-
-    # Custom JavaScript to disable the send button until confirmation is given.
-    CUSTOM_JS = """
-    <script>
-    document.addEventListener("DOMContentLoaded", function() {
-      const interval = setInterval(() => {
-        const checkbox = document.querySelector('input[type="checkbox"][aria-label*="I hereby confirm that I am at least 18 years of age"]');
-        const sendButton = document.querySelector('button[title="Send"]');
-        if (checkbox && sendButton) {
-          sendButton.disabled = !checkbox.checked;
-          checkbox.addEventListener('change', function() {
-            sendButton.disabled = !checkbox.checked;
-          });
-          clearInterval(interval);
-        }
-      }, 500);
-    });
-    </script>
-    """
+        buffer += text
+        full_response += text  # Store raw assistant response (includes <think>)
 
-with gr.Blocks(css=css, title="A.I. Healthcare") as demo:
-    gr.Markdown(DESCRIPTION)
-    gr.HTML(CUSTOM_JS)
+        # Send the "thinking" message once text starts generating
+        if not thinking_message_sent:
+            thinking_message_sent = True
+            yield "A.I. Healthcare is Thinking...\n\n"
+
+        # Wait until </think> is detected before streaming output
+        if not think_detected:
+            if "</think>" in buffer:
+                think_detected = True
+                buffer = buffer.split("</think>", 1)[1]  # Remove <think> section
+        else:
+            outputs.append(text)
+            yield "".join(outputs)
+
+    # Store the full response (including <think>) in history, but only show the user the cleaned response
+    history.append((message, full_response))  # Full assistant response saved for context
+
+# Gradio block
+chatbot = gr.Chatbot(height=450, placeholder=PLACEHOLDER, label='Gradio ChatInterface')
+
+with gr.Blocks(fill_height=True, css=css) as demo:
 
-    chat_interface = gr.ChatInterface(
+    gr.Markdown(DESCRIPTION)
+    gr.ChatInterface(
         fn=chat_llama3_8b,
-        title="A.I. Healthcare Chat",
-        chatbot=gr.Chatbot(height=450, placeholder=PLACEHOLDER, label='Conversation'),
+        chatbot=chatbot,
+        fill_height=True,
+        additional_inputs_accordion=gr.Accordion(label="⚙️ Parameters", open=False, render=False),
         additional_inputs=[
-            gr.Checkbox(
-                value=False,
-                label=("I hereby confirm that I am at least 18 years of age (or accompanied by a legal guardian "
-                       "who is at least 18 years old), understand that the information provided by this service "
-                       "is for informational purposes only and is not intended to diagnose or treat any medical condition, "
-                       "and acknowledge that I am solely responsible for verifying any information provided."),
-                elem_id="age_confirm_checkbox"
-            ),
-            gr.Slider(minimum=0.6, maximum=0.6, step=0.1, value=0.6, label="Temperature", visible=False),
-            gr.Slider(minimum=128, maximum=4096, step=64, value=1024, label="Max new tokens", visible=False),
+            gr.Slider(minimum=0.6, maximum=0.6, step=0.1, value=0.6, label="Temperature", render=False),
+            gr.Slider(minimum=128, maximum=4096, step=64, value=1024, label="Max new tokens", render=False),
         ],
         examples=[
            ['What are the common symptoms of diabetes?'],
@@ -166,4 +151,4 @@ with gr.Blocks(css=css, title="A.I. Healthcare") as demo:
     gr.Markdown(LICENSE)
 
 if __name__ == "__main__":
-    demo.launch()
+    demo.launch()
```
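The core change is seeding the assistant turn with `"<think> "` so generation begins inside a reasoning block. One caveat worth noting: by default, `apply_chat_template` treats a trailing assistant message as a *finished* turn and may emit an end-of-turn marker after the prefix. Recent `transformers` releases expose `continue_final_message` for exactly this prefill pattern; the sketch below assumes a version that has it, which is the only assumption beyond the commit itself.

```python
# Sketch: keeping the assistant turn open so generation continues after the
# "<think> " prefix. Assumes a transformers version with continue_final_message
# (present in recent releases); older versions may close the turn instead.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(
    "reedmayhew/HealthCare-Reasoning-Assistant-Llama-3.1-8B-HF"
)

conversation = [
    {"role": "user", "content": "What are the common symptoms of diabetes?"},
    {"role": "assistant", "content": "<think> "},  # partial turn to continue
]

input_ids = tokenizer.apply_chat_template(
    conversation,
    continue_final_message=True,  # do not emit an end-of-turn after the prefix
    return_tensors="pt",
)
```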
 
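The new streaming loop hides everything up to `</think>`, but it has an edge case: in the chunk where `</think>` arrives, the text after the tag is assigned to `buffer` and never yielded, and later chunks are re-emitted cumulatively via `"".join(outputs)`. Below is a standalone sketch of the same idea that keeps the tail; `filter_think` is a hypothetical helper, not part of app.py.

```python
# Sketch of a stream filter with the same intent as the committed loop: hide
# the <think>...</think> prefix, but keep visible text that arrives in the
# same chunk as the closing tag (the committed loop drops that tail).
def filter_think(stream):
    buffer, think_done = "", False
    for chunk in stream:
        if think_done:
            yield chunk            # past the tag: pass chunks straight through
            continue
        buffer += chunk
        if "</think>" in buffer:
            think_done = True
            tail = buffer.split("</think>", 1)[1]
            if tail:
                yield tail         # emit what followed the tag in this chunk

# The closing tag and visible text share a chunk here:
print(list(filter_think(["<think> reason", "ing </think>Hel", "lo!"])))
# -> ['Hel', 'lo!']
```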
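The generation machinery the loop consumes (unchanged context lines in the diff) is the standard `TextIteratorStreamer` recipe: `model.generate` blocks, so it runs in a worker thread while the caller iterates over decoded text. A minimal self-contained sketch, using `gpt2` purely as a stand-in so it runs anywhere; the Space loads its own checkpoint.

```python
# Minimal version of the background-generation pattern app.py relies on:
# generate() runs in a worker thread, the main thread streams decoded text.
from threading import Thread
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

tok = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

inputs = tok("The three most common symptoms of diabetes are", return_tensors="pt")
streamer = TextIteratorStreamer(tok, skip_prompt=True, skip_special_tokens=True)

Thread(target=model.generate,
       kwargs=dict(**inputs, streamer=streamer, max_new_tokens=32)).start()

for piece in streamer:
    print(piece, end="", flush=True)
```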
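On the UI side, the commit drops the confirmation checkbox and the custom JavaScript, and pins sampling by giving the temperature slider `minimum == maximum == 0.6`; both sliders are declared with `render=False` and only materialize through `additional_inputs`. A reduced sketch of that wiring, with a hypothetical `echo` generator standing in for `chat_llama3_8b`:

```python
# Reduced sketch of the ChatInterface wiring after this commit. The sliders
# surface inside the (unrendered) accordion via additional_inputs, and the
# degenerate min == max range effectively hard-codes temperature at 0.6.
import gradio as gr

def echo(message, history, temperature, max_new_tokens):
    yield f"(temperature={temperature}, max_new_tokens={max_new_tokens}) {message}"

with gr.Blocks(fill_height=True) as demo:
    gr.ChatInterface(
        fn=echo,
        chatbot=gr.Chatbot(height=450, label="Gradio ChatInterface"),
        fill_height=True,
        additional_inputs_accordion=gr.Accordion(label="⚙️ Parameters", open=False, render=False),
        additional_inputs=[
            gr.Slider(minimum=0.6, maximum=0.6, step=0.1, value=0.6, label="Temperature", render=False),
            gr.Slider(minimum=128, maximum=4096, step=64, value=1024, label="Max new tokens", render=False),
        ],
        examples=[["What are the common symptoms of diabetes?"]],
    )

if __name__ == "__main__":
    demo.launch()
```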