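# Gradio chat app: Mistral-7B-Instruct-v0.2 with a mental-health fine-tuned
# PEFT adapter, served as a simple healthcare chatbot UI.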
import gradio as gr
from typing import Iterator

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftConfig, PeftModel

base_model = "mistralai/Mistral-7B-Instruct-v0.2"
adapter = "GRMenon/mental-health-mistral-7b-instructv0.2-finetuned-V2"
# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained(
    base_model,
    add_bos_token=True,
    trust_remote_code=True,
    padding_side='left',
)
# The Mistral tokenizer ships without a pad token; reuse EOS so that
# generate() receives a valid pad_token_id.
tokenizer.pad_token = tokenizer.eos_token
# Create PEFT model from the base model and the fine-tuned adapter
config = PeftConfig.from_pretrained(adapter)
model = AutoModelForCausalLM.from_pretrained(
    config.base_model_name_or_path,
    load_in_4bit=True,
    device_map='auto',
    torch_dtype='auto',
)
model = PeftModel.from_pretrained(model, adapter)

# device_map='auto' already places the 4-bit weights, and calling .to() on a
# quantized model raises an error, so only keep the device string for inputs.
device = "cuda" if torch.cuda.is_available() else "cpu"
model.eval()
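# Note: 4-bit loading assumes bitsandbytes and accelerate are available; on a
# CPU-only machine, drop load_in_4bit=True.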
DEFAULT_SYSTEM_PROMPT = (
    "You are Phoenix AI Healthcare. You are professional and polite, you give only "
    "truthful information, and you are based on the Mistral-7B model from Mistral AI, "
    "specialising in Healthcare and Wellness. You can communicate equally well in "
    "different languages."
)
MAX_MAX_NEW_TOKENS = 4096
DEFAULT_MAX_NEW_TOKENS = 256
MAX_INPUT_TOKEN_LENGTH = 4000
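# MAX_INPUT_TOKEN_LENGTH caps the accumulated prompt tokens before generation;
# see check_input_token_length below.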
DESCRIPTION = """
# Simple Healthcare Chatbot
### Powered by Mistral-7B with Healthcare Fine-Tuning
"""
def clear_and_save_textbox(message: str) -> tuple[str, str]:
    """Clear the textbox and stash the submitted message in saved_input."""
    return "", message


def display_input(message: str, history: list[tuple[str, str]]) -> list[tuple[str, str]]:
    """Echo the user's message into the chatbot with an empty assistant slot."""
    history.append((message, ""))
    return history


def delete_prev_fn(history: list[tuple[str, str]]) -> tuple[list[tuple[str, str]], str]:
    """Drop the last chat turn and return its user message (for retry/undo)."""
    try:
        message, _ = history.pop()
    except IndexError:
        message = ""
    return history, message or ""
def generate(
    message: str,
    history_with_input: list[tuple[str, str]],
    system_prompt: str,
    max_new_tokens: int,
    temperature: float,
    top_p: float,
    top_k: int,
) -> Iterator[list[tuple[str, str]]]:
    """Generate a reply for `message` and yield the updated chat history."""
    if max_new_tokens > MAX_MAX_NEW_TOKENS:
        raise ValueError("Max new tokens exceeded")
    history = history_with_input[:-1]

    # Mistral-Instruct's stock chat template expects strictly alternating
    # user/assistant turns and has no separate "system" role, so fold the
    # system prompt into the first user message and keep the assistant replies.
    conversation = []
    for i, (user_input, assistant_response) in enumerate(history):
        user_content = f"{system_prompt}\n\n{user_input}" if i == 0 else user_input
        conversation.append({"role": "user", "content": user_content})
        conversation.append({"role": "assistant", "content": assistant_response})
    conversation.append({"role": "user", "content": message if conversation else f"{system_prompt}\n\n{message}"})

    input_ids = tokenizer.apply_chat_template(
        conversation=conversation,
        tokenize=True,
        add_generation_prompt=True,
        return_tensors='pt',
    ).to(device)
    output_ids = model.generate(
        input_ids=input_ids,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        temperature=temperature,
        top_p=top_p,
        top_k=top_k,
        pad_token_id=tokenizer.pad_token_id,
    )
    # Decode only the newly generated tokens, not the echoed prompt.
    response_text = tokenizer.batch_decode(
        output_ids[:, input_ids.shape[-1]:], skip_special_tokens=True
    )[0].strip()
    yield history + [(message, response_text)]
def check_input_token_length(message: str, chat_history: list[tuple[str, str]], system_prompt: str) -> None:
    """Raise a gr.Error if the accumulated prompt would exceed the input budget."""
    input_token_length = len(tokenizer.encode(system_prompt)) + len(tokenizer.encode(message)) + \
        sum(len(tokenizer.encode(user)) + len(tokenizer.encode(assistant)) for user, assistant in chat_history)
    if input_token_length > MAX_INPUT_TOKEN_LENGTH:
        raise gr.Error(f"The accumulated input is too long ({input_token_length} > {MAX_INPUT_TOKEN_LENGTH}). Clear your chat history and try again.")
with gr.Blocks(css="./styles/style.css") as demo:  # Link to CSS file
    gr.Markdown(DESCRIPTION)
    gr.Button("Duplicate Space for private use", elem_id="duplicate-button")

    with gr.Group():
        chatbot = gr.Chatbot(label="Chat with Healthcare AI")
        with gr.Row():
            textbox = gr.Textbox(
                container=False,
                show_label=False,
                placeholder="Ask me anything about Healthcare and Wellness...",
                scale=10,
            )
            submit_button = gr.Button("Submit", variant="primary", scale=1, min_width=0)

    with gr.Row():
        retry_button = gr.Button('🔄 Retry', variant='secondary')
        undo_button = gr.Button('↩️ Undo', variant='secondary')
        clear_button = gr.Button('🗑️ Clear', variant='secondary')

    saved_input = gr.State()

    with gr.Accordion(label="⚙️ Advanced options", open=False):
        system_prompt = gr.Textbox(
            label="System prompt",
            value=DEFAULT_SYSTEM_PROMPT,
            lines=5,
            interactive=False,
        )
        max_new_tokens = gr.Slider(
            label="Max new tokens",
            minimum=1,
            maximum=MAX_MAX_NEW_TOKENS,
            step=1,
            value=DEFAULT_MAX_NEW_TOKENS,
        )
        temperature = gr.Slider(
            label="Temperature",
            minimum=0.1,
            maximum=4.0,
            step=0.1,
            value=0.1,
        )
        top_p = gr.Slider(
            label="Top-p (nucleus sampling)",
            minimum=0.05,
            maximum=1.0,
            step=0.05,
            value=0.9,
        )
        top_k = gr.Slider(
            label="Top-k",
            minimum=1,
            maximum=1000,
            step=1,
            value=10,
        )
    textbox.submit(
        fn=clear_and_save_textbox,
        inputs=textbox,
        outputs=[textbox, saved_input],
    ).then(
        fn=display_input,
        inputs=[saved_input, chatbot],
        outputs=chatbot,
    ).then(
        fn=check_input_token_length,
        inputs=[saved_input, chatbot, system_prompt],
    ).success(
        fn=generate,
        inputs=[saved_input, chatbot, system_prompt, max_new_tokens, temperature, top_p, top_k],
        outputs=chatbot,
    )

    submit_button.click(
        fn=clear_and_save_textbox,
        inputs=textbox,
        outputs=[textbox, saved_input],
    ).then(
        fn=display_input,
        inputs=[saved_input, chatbot],
        outputs=chatbot,
    ).then(
        fn=check_input_token_length,
        inputs=[saved_input, chatbot, system_prompt],
    ).success(
        fn=generate,
        inputs=[saved_input, chatbot, system_prompt, max_new_tokens, temperature, top_p, top_k],
        outputs=chatbot,
    )

    retry_button.click(
        fn=delete_prev_fn,
        inputs=chatbot,
        outputs=[chatbot, saved_input],
    ).then(
        fn=display_input,
        inputs=[saved_input, chatbot],
        outputs=chatbot,
    ).then(
        fn=generate,
        inputs=[saved_input, chatbot, system_prompt, max_new_tokens, temperature, top_p, top_k],
        outputs=chatbot,
    )

    undo_button.click(
        fn=delete_prev_fn,
        inputs=chatbot,
        outputs=[chatbot, saved_input],
    ).then(
        fn=lambda x: x,
        inputs=[saved_input],
        outputs=textbox,
    )

    clear_button.click(
        fn=lambda: ([], ""),
        outputs=[chatbot, saved_input],
    )
demo.queue(max_size=32).launch(share=False)
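# queue(max_size=32) bounds how many requests can wait on the single loaded model;
# share=False keeps the app local to this deployment.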