# Import necessary libraries
import streamlit as st
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, pipeline
from openai import OpenAI
import os
import torch
from peft import LoraConfig, PeftModel, prepare_model_for_kbit_training, get_peft_model
from huggingface_hub import login
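# Assumed dependencies (not pinned in the original): streamlit, torch, transformers,
# peft, bitsandbytes, accelerate, openai, huggingface_hub. For example:
#   pip install streamlit torch transformers peft bitsandbytes accelerate openai huggingface_hub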
# Read the Hugging Face API token from the environment and log in before
# initializing the OpenAI-compatible Inference API client
api_token = os.getenv("HUGGINGFACEHUB_API_TOKEN")
if api_token:
    login(token=api_token)
else:
    st.error("API token is not set in the environment variables.")

# Initialize the OpenAI client against the Hugging Face Inference API
client = OpenAI(
    base_url="https://api-inference.huggingface.co/v1",
    api_key=api_token,
)
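# Note: on a Hugging Face Space, HUGGINGFACEHUB_API_TOKEN is typically stored as a
# repository secret, which is exposed to the running app as an environment variable.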
# Define model links
model_links = {
    "HAH-2024-v0.1": "drmasad/HAH-2024-v0.11"
}

# Set selected model
selected_model = "HAH-2024-v0.1"

# Display welcome message
st.title("Welcome to HAH-2024-v0.1")

# Sidebar setup
temp_values = st.sidebar.slider("Select a temperature value", 0.0, 1.0, 0.5)
def reset_conversation():
    st.session_state.conversation = []
    st.session_state.messages = []

st.sidebar.button("Reset Chat", on_click=reset_conversation)
st.sidebar.write(f"You're now chatting with **{selected_model}**")
st.sidebar.image("https://www.hmgaihub.com/untitled.png")
st.sidebar.markdown("*Generated content may be inaccurate or false.*")
st.sidebar.markdown("*This project is under development.*")
# Function to load the model and tokenizer (cached so the model is not reloaded on every rerun)
@st.cache_resource
def load_model(selected_model_name):
    st.info("Loading the model, please wait...")
    model_name = model_links[selected_model_name]

    # bitsandbytes 4-bit quantization requires a CUDA device
    device = "cuda" if torch.cuda.is_available() else "cpu"
    if device != "cuda":
        st.warning("No GPU detected; 4-bit quantization requires CUDA.")

    # 4-bit NF4 quantization with bfloat16 compute, allowing fp32 CPU offload
    # for layers that do not fit on the GPU
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16,
        bnb_4bit_use_double_quant=False,
        llm_int8_enable_fp32_cpu_offload=True,
    )

    # Let Accelerate place layers across the available GPU(s) and CPU automatically
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        quantization_config=bnb_config,
        device_map="auto",
        trust_remote_code=True,
    )
    model.config.use_cache = False
    model = prepare_model_for_kbit_training(model)

    # Attach the LoRA adapter configuration on top of the quantized base model
    peft_config = LoraConfig(
        lora_alpha=16,
        lora_dropout=0.1,
        r=64,
        bias="none",
        task_type="CAUSAL_LM",
        target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj"],
    )
    model = get_peft_model(model, peft_config)

    tokenizer = AutoTokenizer.from_pretrained(
        "mistralai/Mistral-7B-Instruct-v0.2", trust_remote_code=True
    )

    st.success("Model is ready.")
    return model, tokenizer
# Load model and tokenizer
model, tokenizer = load_model(selected_model)

# Chat application logic
if "messages" not in st.session_state:
    st.session_state.messages = []

for message in st.session_state.messages:
    with st.chat_message(message["role"]):
        st.markdown(message["content"])
if prompt := st.chat_input("Ask me anything about diabetes"):
    with st.chat_message("user"):
        st.markdown(prompt)
    st.session_state.messages.append({"role": "user", "content": prompt})

    instructions = """
    Act as a highly knowledgeable endocrinology doctor with expertise in explaining complex medical information in an understandable way to patients who do not have a medical background. Your responses should convey empathy and care while remaining medically accurate and reliable.
    Answer only what the patient asks, in a professional way, without adding unnecessary information. You may, however, chat with the patient casually.
    """

    # Mistral-Instruct prompt format: the instructions and the user message go inside
    # [INST] ... [/INST]; the model's reply follows the closing [/INST] tag
    full_prompt = f"<s>[INST] {instructions}\n\n{prompt} [/INST]"
    with st.chat_message("assistant"):
        # Build a text-generation pipeline over the loaded model and sample with the
        # user-selected temperature (sampling must be enabled for temperature to apply)
        generator = pipeline(
            task="text-generation",
            model=model,
            tokenizer=tokenizer,
            max_length=1024,
            do_sample=True,
            temperature=max(temp_values, 1e-3),
        )
        result = generator(full_prompt)
        generated_text = result[0]["generated_text"]
        # The pipeline returns the prompt followed by the completion; keep only the
        # text after the closing [/INST] tag
        response = generated_text.split("[/INST]")[-1].strip()
        st.markdown(response)
    st.session_state.messages.append({"role": "assistant", "content": response})
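# Minimal usage sketch (assuming this file is saved as app.py and the token is set):
#   HUGGINGFACEHUB_API_TOKEN=<your token> streamlit run app.py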