reedmayhew committed on
Commit 25ef0fa · verified · 1 Parent(s): b7ef472

Update app.py

Files changed (1):
  1. app.py +21 -42
app.py CHANGED
@@ -40,7 +40,7 @@ h1 {
 }
 """
 
-# Load the tokenizer and model
+# Load the tokenizer and model with the updated model name
 tokenizer = AutoTokenizer.from_pretrained("reedmayhew/HealthCare-Reasoning-Assistant-Llama-3.1-8B-HF", device_map="cuda")
 model = AutoModelForCausalLM.from_pretrained("reedmayhew/HealthCare-Reasoning-Assistant-Llama-3.1-8B-HF", device_map="cuda")
 
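
Note on this hunk: device_map is an accelerate feature consumed by AutoModelForCausalLM.from_pretrained; tokenizers are plain CPU objects, so the same kwarg on AutoTokenizer appears to be a harmless no-op. A minimal sketch of an equivalent load (the torch_dtype=torch.bfloat16 setting is my assumption, not part of the commit):

    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    MODEL_ID = "reedmayhew/HealthCare-Reasoning-Assistant-Llama-3.1-8B-HF"

    # Tokenizers run on the CPU; no device argument is needed here.
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)

    # device_map="cuda" places the whole model on the GPU; bfloat16 halves
    # memory versus float32 (assumes the GPU supports bf16).
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        device_map="cuda",
        torch_dtype=torch.bfloat16,
    )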
@@ -56,22 +56,23 @@ def chat_llama3_8b(message: str,
                     max_new_tokens: int,
                     confirm: bool) -> str:
     """
-    Generate a streaming response using the llama3-8b model.
+    Generate a streaming response using the Healthcare-Reasoning-Assistant-Llama-3.1-8B-HF model.
 
     Args:
         message (str): The input message.
         history (list): The conversation history.
         temperature (float): The temperature for generating the response.
         max_new_tokens (int): The maximum number of new tokens to generate.
-        confirm (bool): Whether the user has confirmed the age/disclaimer.
+        confirm (bool): Whether the user has confirmed the usage disclaimer.
 
-    Returns:
-        str: The generated response.
+    Yields:
+        str: The generated response, streamed token-by-token.
     """
-    # If the confirmation checkbox is not checked, return a short message immediately.
+    # Ensure the user has confirmed the disclaimer
     if not confirm:
         return "⚠️ You must confirm that you meet the usage requirements before sending a message."
 
+    # Prepare the conversation history for the model input
     conversation = []
     for user, assistant in history:
         conversation.extend([
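
Note on this hunk: after this commit the function body contains yield, which turns chat_llama3_8b into a generator, and a bare return "..." inside a generator does not deliver the string to Gradio; it only stops iteration. If the warning should actually reach the chat window, the guard would need to yield it instead (a suggested sketch, not part of the commit):

    if not confirm:
        yield "⚠️ You must confirm that you meet the usage requirements before sending a message."
        return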
@@ -79,14 +80,15 @@ def chat_llama3_8b(message: str,
             {"role": "assistant", "content": assistant}
         ])
 
-    # Ensure the model starts with "<think>"
+    # Append the current user message
     conversation.append({"role": "user", "content": message})
-    conversation.append({"role": "assistant", "content": "<think> "})  # Force <think> at start
-
+
+    # Convert the conversation into input ids using the chat template
     input_ids = tokenizer.apply_chat_template(conversation, return_tensors="pt").to(model.device)
 
+    # Set up the streamer to stream text output
     streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
-
+
     generate_kwargs = dict(
         input_ids=input_ids,
         streamer=streamer,
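
For context on apply_chat_template: with the assistant "<think>" primer removed, the usual pattern is to pass add_generation_prompt=True so the rendered prompt ends with an open assistant turn; the commit calls the method without it. A minimal sketch of that variant (the flag is a suggestion, not part of the commit):

    conversation = [
        {"role": "user", "content": "What are the common symptoms of diabetes?"},
    ]
    # add_generation_prompt=True appends the assistant header so the model
    # begins a new reply rather than continuing the user's turn.
    input_ids = tokenizer.apply_chat_template(
        conversation,
        add_generation_prompt=True,
        return_tensors="pt",
    ).to(model.device)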
@@ -99,47 +101,25 @@ def chat_llama3_8b(message: str,
     if temperature == 0:
         generate_kwargs['do_sample'] = False
 
+    # Launch the generation in a separate thread
     t = Thread(target=model.generate, kwargs=generate_kwargs)
     t.start()
-
-    outputs = []
-    buffer = ""
-    think_detected = False
-    thinking_message_sent = False
-    full_response = ""  # Store the full assistant response
-
+
+    full_response = ""
+    # Simply stream each token as it comes from the model
     for text in streamer:
-        buffer += text
-        full_response += text  # Store raw assistant response (includes <think>)
-
-        # Send the "thinking" message once text starts generating
-        if not thinking_message_sent:
-            thinking_message_sent = True
-            yield "A.I. Healthcare is Thinking... Please wait...\n\n"
-
-        # Wait until </think> is detected before streaming output
-        if not think_detected:
-            print(buffer)
-            if "</think>" in buffer:
-                think_detected = True
-                buffer = buffer.split("</think>", 1)[1]  # Remove <think> section
-        else:
-            outputs.append(text)
-            yield "".join(outputs)
+        full_response += text
+        yield text
 
-    # Store the full response (including <think>) in history for context
+    # Save the full response (for context in the conversation history)
     history.append((message, full_response))
 
 # Custom JavaScript to disable the send button until confirmation is given.
-# (The JS waits for the checkbox with a label containing the specified text and then monitors its state.)
 CUSTOM_JS = """
 <script>
 document.addEventListener("DOMContentLoaded", function() {
-    // Poll for the confirmation checkbox and the send button inside the ChatInterface.
     const interval = setInterval(() => {
-        // The checkbox is rendered as an <input type="checkbox"> with an associated label.
         const checkbox = document.querySelector('input[type="checkbox"][aria-label*="I hereby confirm that I am at least 18 years of age"]');
-        // The send button might be a <button> element with a title or specific text. Adjust the selector as needed.
         const sendButton = document.querySelector('button[title="Send"]');
         if (checkbox && sendButton) {
             sendButton.disabled = !checkbox.checked;
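
The deleted block buffered output until "</think>" appeared, so the model's reasoning stayed hidden; the new loop streams everything verbatim. Since skip_special_tokens=True only strips registered special tokens, a literal "<think>...</think>" section is ordinary text and will now be shown to the user unless the tokenizer treats those tags as special. If hiding it is still wanted, a small stream filter could restore the old behavior (a self-contained sketch, not part of the commit):

    def hide_think(stream, marker="</think>"):
        """Suppress chunks until `marker` has been seen, then pass text through."""
        buffer, seen = "", False
        for chunk in stream:
            if seen:
                yield chunk
                continue
            buffer += chunk
            if marker in buffer:
                seen = True
                tail = buffer.split(marker, 1)[1]
                if tail:
                    yield tail

    # usage: for text in hide_think(streamer): ...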
@@ -155,10 +135,8 @@ document.addEventListener("DOMContentLoaded", function() {
 
 with gr.Blocks(css=css, title="A.I. Healthcare") as demo:
     gr.Markdown(DESCRIPTION)
-    # Inject the custom JavaScript.
     gr.HTML(CUSTOM_JS)
 
-    # The ChatInterface below now includes additional inputs: the confirmation checkbox and the parameter sliders.
     chat_interface = gr.ChatInterface(
         fn=chat_llama3_8b,
         title="A.I. Healthcare Chat",
@@ -173,7 +151,7 @@ with gr.Blocks(css=css, title="A.I. Healthcare") as demo:
             elem_id="age_confirm_checkbox"
         ),
         gr.Slider(minimum=0.6, maximum=0.6, step=0.1, value=0.6, label="Temperature", visible=False),
-        gr.Slider(minimum=1024, maximum=4096, step=128, value=2048, label="Max new tokens", visible=False),
+        gr.Slider(minimum=128, maximum=4096, step=64, value=1024, label="Max new tokens", visible=False),
     ],
     examples=[
         ['What are the common symptoms of diabetes?'],
@@ -183,6 +161,7 @@ with gr.Blocks(css=css, title="A.I. Healthcare") as demo:
         ['What should I know about the side effects of common medications?']
     ],
     cache_examples=False,
+    allow_screenshot=False,
 )
 
 gr.Markdown(LICENSE)
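
Worth verifying before merging: allow_screenshot was a flag on the old gr.Interface API and, as far as I can tell, is not accepted by gr.ChatInterface in current Gradio releases, so this new kwarg may raise a TypeError at startup. A quick local check:

    import inspect
    import gradio as gr

    # Lists the constructor's accepted parameters; if "allow_screenshot"
    # is absent, passing it will fail with a TypeError.
    print(sorted(inspect.signature(gr.ChatInterface.__init__).parameters))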
 