import gradio as gr
from typing import Iterator

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftConfig, PeftModel
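# Assumed runtime dependencies (sketch of a requirements list): gradio, torch,
# transformers, and peft, plus bitsandbytes and accelerate, which back the
# load_in_4bit=True and device_map='auto' options used below.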
# Base Mistral model and the healthcare fine-tuned LoRA adapter
base_model = "mistralai/Mistral-7B-Instruct-v0.2"
adapter = "GRMenon/mental-health-mistral-7b-instructv0.2-finetuned-V2"
# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained(
    base_model,
    add_bos_token=True,
    trust_remote_code=True,
    padding_side='left',
)
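# Mistral's tokenizer ships without a pad token; fall back to EOS so that
# generate() receives a valid pad_token_id below.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token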
# Create peft model using base_model and finetuned adapter
config = PeftConfig.from_pretrained(adapter)
model = AutoModelForCausalLM.from_pretrained(
    config.base_model_name_or_path,
    load_in_4bit=True,
    device_map='auto',
    torch_dtype='auto',
)
model = PeftModel.from_pretrained(model, adapter)
# device_map='auto' has already placed the 4-bit weights; calling .to(device)
# on a quantized model raises an error, so `device` is only used to place
# input tensors.
device = "cuda" if torch.cuda.is_available() else "cpu"
model.eval()
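# Optional smoke test (illustrative sketch; the prompt text is arbitrary):
# uncomment to verify the adapter loads and generates before building the UI.
# _ids = tokenizer("Hello, how are you?", return_tensors="pt").input_ids.to(device)
# print(tokenizer.decode(model.generate(input_ids=_ids, max_new_tokens=8)[0]))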
DEFAULT_SYSTEM_PROMPT = "You are Phoenix AI Healthcare. You are professional and polite, you give only truthful information, and you are based on the Mistral-7B model from Mistral AI, specialized in Healthcare and Wellness. You can communicate equally well in different languages."
MAX_MAX_NEW_TOKENS = 4096
DEFAULT_MAX_NEW_TOKENS = 256
MAX_INPUT_TOKEN_LENGTH = 4000
DESCRIPTION = """ | |
# Simple Healthcare Chatbot | |
### Powered by Mistral-7B with Healthcare Fine-Tuning | |
""" | |
def clear_and_save_textbox(message: str) -> tuple[str, str]:
    return "", message

def display_input(message: str, history: list[tuple[str, str]]) -> list[tuple[str, str]]:
    history.append((message, ""))
    return history

def delete_prev_fn(history: list[tuple[str, str]]) -> tuple[list[tuple[str, str]], str]:
    try:
        message, _ = history.pop()
    except IndexError:
        message = ""
    return history, message or ""
def generate(
    message: str,
    history_with_input: list[tuple[str, str]],
    system_prompt: str,
    max_new_tokens: int,
    temperature: float,
    top_p: float,
    top_k: int,
) -> Iterator[list[tuple[str, str]]]:
    if max_new_tokens > MAX_MAX_NEW_TOKENS:
        raise ValueError("Max new tokens exceeded")
    history = history_with_input[:-1]
    # Rebuild the conversation with both sides of each past exchange so the
    # model sees the full dialogue context.
    conversation = [{"role": "system", "content": system_prompt}]
    for user_input, assistant_response in history:
        conversation.append({"role": "user", "content": user_input})
        conversation.append({"role": "assistant", "content": assistant_response})
    conversation.append({"role": "user", "content": message})
    input_ids = tokenizer.apply_chat_template(
        conversation=conversation,
        tokenize=True,
        add_generation_prompt=True,
        return_tensors='pt',
    ).to(device)
    output_ids = model.generate(
        input_ids=input_ids,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        temperature=temperature,
        top_p=top_p,
        top_k=top_k,
        pad_token_id=tokenizer.pad_token_id,
    )
    # Decode only the newly generated tokens, skipping the echoed prompt.
    response_text = tokenizer.decode(output_ids[0, input_ids.shape[-1]:], skip_special_tokens=True)
    yield history + [(message, response_text)]
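# A streaming alternative (sketch, not wired into the handlers below): uses
# transformers' TextIteratorStreamer so the Chatbot can update token by token
# instead of waiting for the full response. Swap it in for `generate` in the
# event chains below if incremental output is preferred.
def generate_stream(
    message: str,
    history_with_input: list[tuple[str, str]],
    system_prompt: str,
    max_new_tokens: int,
    temperature: float,
    top_p: float,
    top_k: int,
) -> Iterator[list[tuple[str, str]]]:
    from threading import Thread
    from transformers import TextIteratorStreamer

    history = history_with_input[:-1]
    conversation = [{"role": "system", "content": system_prompt}]
    for user_input, assistant_response in history:
        conversation.append({"role": "user", "content": user_input})
        conversation.append({"role": "assistant", "content": assistant_response})
    conversation.append({"role": "user", "content": message})
    input_ids = tokenizer.apply_chat_template(
        conversation=conversation,
        tokenize=True,
        add_generation_prompt=True,
        return_tensors='pt',
    ).to(device)
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    # Run generation in a background thread; the streamer is fed as tokens
    # are produced and iterated over in the foreground.
    Thread(target=model.generate, kwargs=dict(
        input_ids=input_ids,
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        temperature=temperature,
        top_p=top_p,
        top_k=top_k,
        pad_token_id=tokenizer.pad_token_id,
    )).start()
    partial = ""
    for new_text in streamer:  # yields decoded text chunks as they arrive
        partial += new_text
        yield history + [(message, partial)]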
def check_input_token_length(message: str, chat_history: list[tuple[str, str]], system_prompt: str) -> None:
    # Count the system prompt and both sides of each past exchange, since all
    # of them end up in the prompt built by generate().
    input_token_length = (
        len(tokenizer.encode(system_prompt))
        + len(tokenizer.encode(message))
        + sum(len(tokenizer.encode(msg)) + len(tokenizer.encode(response)) for msg, response in chat_history)
    )
    if input_token_length > MAX_INPUT_TOKEN_LENGTH:
        raise gr.Error(f"The accumulated input is too long ({input_token_length} > {MAX_INPUT_TOKEN_LENGTH}). Clear your chat history and try again.")
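# Note: the count above is approximate. Mistral's chat template wraps each
# turn in control tokens ([INST] ... [/INST]), so the real prompt is somewhat
# longer; for an exact figure, tokenize the full apply_chat_template output.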
with gr.Blocks(css="./styles/style.css") as demo:  # Link to CSS file
    gr.Markdown(DESCRIPTION)
    # gr.DuplicateButton triggers the Space-duplication flow on its own,
    # without a manually attached click handler.
    gr.DuplicateButton("Duplicate Space for private use", elem_id="duplicate-button")
    with gr.Group():
        chatbot = gr.Chatbot(label="Chat with Healthcare AI")
        with gr.Row():
            textbox = gr.Textbox(
                container=False,
                show_label=False,
                placeholder="Ask me anything about Healthcare and Wellness...",
                scale=10,
            )
            submit_button = gr.Button("Submit", variant="primary", scale=1, min_width=0)
    with gr.Row():
        retry_button = gr.Button('🔄 Retry', variant='secondary')
        undo_button = gr.Button('↩️ Undo', variant='secondary')
        clear_button = gr.Button('🗑️ Clear', variant='secondary')
    saved_input = gr.State()
    with gr.Accordion(label="⚙️ Advanced options", open=False):
        system_prompt = gr.Textbox(
            label="System prompt",
            value=DEFAULT_SYSTEM_PROMPT,
            lines=5,
            interactive=False,
        )
        max_new_tokens = gr.Slider(
            label="Max new tokens",
            minimum=1,
            maximum=MAX_MAX_NEW_TOKENS,
            step=1,
            value=DEFAULT_MAX_NEW_TOKENS,
        )
        temperature = gr.Slider(
            label="Temperature",
            minimum=0.1,
            maximum=4.0,
            step=0.1,
            value=0.1,
        )
        top_p = gr.Slider(
            label="Top-p (nucleus sampling)",
            minimum=0.05,
            maximum=1.0,
            step=0.05,
            value=0.9,
        )
        top_k = gr.Slider(
            label="Top-k",
            minimum=1,
            maximum=1000,
            step=1,
            value=10,
        )
    textbox.submit(
        fn=clear_and_save_textbox,
        inputs=textbox,
        outputs=[textbox, saved_input],
    ).then(
        fn=display_input,
        inputs=[saved_input, chatbot],
        outputs=chatbot,
    ).then(
        fn=check_input_token_length,
        inputs=[saved_input, chatbot, system_prompt],
    ).success(
        fn=generate,
        inputs=[saved_input, chatbot, system_prompt, max_new_tokens, temperature, top_p, top_k],
        outputs=chatbot,
    )
    submit_button.click(
        fn=clear_and_save_textbox,
        inputs=textbox,
        outputs=[textbox, saved_input],
    ).then(
        fn=display_input,
        inputs=[saved_input, chatbot],
        outputs=chatbot,
    ).then(
        fn=check_input_token_length,
        inputs=[saved_input, chatbot, system_prompt],
    ).success(
        fn=generate,
        inputs=[saved_input, chatbot, system_prompt, max_new_tokens, temperature, top_p, top_k],
        outputs=chatbot,
    )
    retry_button.click(
        fn=delete_prev_fn,
        inputs=chatbot,
        outputs=[chatbot, saved_input],
    ).then(
        fn=display_input,
        inputs=[saved_input, chatbot],
        outputs=chatbot,
    ).then(
        fn=generate,
        inputs=[saved_input, chatbot, system_prompt, max_new_tokens, temperature, top_p, top_k],
        outputs=chatbot,
    )
    undo_button.click(
        fn=delete_prev_fn,
        inputs=chatbot,
        outputs=[chatbot, saved_input],
    ).then(
        fn=lambda x: x,
        inputs=[saved_input],
        outputs=textbox,
    )
    clear_button.click(
        fn=lambda: ([], ""),
        outputs=[chatbot, saved_input],
    )
demo.queue(max_size=32).launch(share=False)
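# queue(max_size=32) bounds how many requests can wait at once, which keeps a
# burst of generate calls from piling onto a single GPU on a shared Space.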