"""
Diabetes Version
@aim: Demo for testing purposes only
@inquiries: Dr M As'ad
@email: [email protected]
"""
# Import necessary libraries
import streamlit as st
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, pipeline
from openai import OpenAI
import os
import torch
from peft import LoraConfig, PeftModel, prepare_model_for_kbit_training, get_peft_model
from huggingface_hub import login
# Initialize the OpenAI client (pointed at the Hugging Face Inference API)
client = OpenAI(
    base_url="https://api-inference.huggingface.co/v1",
    api_key=os.environ.get("HUGGINGFACEHUB_API_TOKEN"),
)
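# Note: this client is not referenced again below; responses are generated
# locally through the transformers pipeline.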
api_token = os.getenv("HUGGINGFACEHUB_API_TOKEN")
if api_token:
    login(token=api_token)
else:
    st.error("API token is not set in the environment variables.")
# Define model links
model_links = {
    "HAH-2024-v0.1": "drmasad/HAH-2024-v0.11"
}
# Set selected model
selected_model = "HAH-2024-v0.1"
# Display welcome message
st.title("Welcome to HAH-2024-v0.1")
# Sidebar setup
# Minimum of 0.1 keeps the temperature strictly positive, as sampling requires
temp_values = st.sidebar.slider("Select a temperature value", 0.1, 1.0, 0.5)

def reset_conversation():
    st.session_state.conversation = []
    st.session_state.messages = []

st.sidebar.button("Reset Chat", on_click=reset_conversation)
st.sidebar.write(f"You're now chatting with **{selected_model}**")
st.sidebar.image("https://www.hmgaihub.com/untitled.png")
st.sidebar.markdown("*Generated content may be inaccurate or false.*")
st.sidebar.markdown("*This is an under development project.*")
st.sidebar.markdown("*Not a replacement for medical advice from a doctor.*")
# Placeholder for the loading status message
loading_message = st.empty()

# Cache the loaded model and tokenizer so Streamlit reruns (one per chat
# interaction) don't reload the 4-bit model from scratch.
@st.cache_resource
def load_model(selected_model_name):
    loading_message.info("Loading the model, please wait...")
    model_name = model_links[selected_model_name]
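    # NF4 4-bit quantization with bfloat16 compute keeps the 7B model small
    # enough for a single GPU; fp32 CPU offload lets layers that don't fit
    # fall back to the CPU.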
    # Define the BitsAndBytesConfig for quantization
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16,
        bnb_4bit_use_double_quant=False,
        llm_int8_enable_fp32_cpu_offload=True,
    )
    # Load the model with the quantization settings applied directly
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        quantization_config=bnb_config,
        trust_remote_code=True,
    )
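    # Caveat: the steps below mirror a fine-tuning setup. get_peft_model
    # attaches freshly initialized LoRA adapters (a no-op for outputs under
    # the default init), so at inference they add overhead without changing
    # responses.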
    # Additional configurations and training enhancements
    model.config.use_cache = False
    model = prepare_model_for_kbit_training(model)
    # If using PEFT or other enhancements, configure here
    peft_config = LoraConfig(
        lora_alpha=16,
        lora_dropout=0.1,
        r=64,
        bias="none",
        task_type="CAUSAL_LM",
        target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj"],
    )
    model = get_peft_model(model, peft_config)
    tokenizer = AutoTokenizer.from_pretrained(
        "mistralai/Mistral-7B-Instruct-v0.2", trust_remote_code=True
    )
    # Replace the loading message with a success note
    loading_message.success("Model is ready.")
    return model, tokenizer
# Load model and tokenizer
model, tokenizer = load_model(selected_model)
# Chat application logic
if "messages" not in st.session_state:
st.session_state.messages = []
for message in st.session_state.messages:
with st.chat_message(message["role"]):
st.markdown(message["content"])
if prompt := st.chat_input("Ask me anything about diabetes"):
    with st.chat_message("user"):
        st.markdown(prompt)
    st.session_state.messages.append({"role": "user", "content": prompt})

    instructions = """
    Act as a highly knowledgeable doctor with a special interest in diabetes,
    skilled at explaining complex medical information in a way that is easy to
    understand for patients without a medical background. Your responses should
    demonstrate empathy and care while upholding a high standard of medical
    accuracy and reliability. Respond precisely to what the patient needs in a
    professional, accurate, and reassuring manner, avoiding unnecessary
    information.
    """
    full_prompt = f"<s>[INST] {instructions.strip()}\n\n{prompt} [/INST]"
    with st.chat_message("assistant"):
        # Build the generation pipeline; do_sample=True is required for the
        # temperature slider to have any effect.
        generator = pipeline(
            task="text-generation",
            model=model,
            tokenizer=tokenizer,
            max_length=1024,
            do_sample=True,
            temperature=temp_values,
        )
        result = generator(full_prompt)
        generated_text = result[0]["generated_text"]
        # The pipeline output includes the prompt, so keep only the text
        # after the closing [/INST] tag.
        response = generated_text.split("[/INST]")[-1].strip()
        st.markdown(response)
    st.session_state.messages.append({"role": "assistant", "content": response})
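
# To run locally (assuming streamlit, transformers, peft, bitsandbytes, and
# torch are installed):
#   streamlit run app.py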