# Import necessary libraries
import os

import streamlit as st
import torch
from huggingface_hub import login
from openai import OpenAI
from peft import LoraConfig, PeftModel, get_peft_model, prepare_model_for_kbit_training
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, pipeline
# Initialize the OpenAI client (if needed for the Hugging Face Inference API)
client = OpenAI(
    base_url="https://api-inference.huggingface.co/v1",
    api_key=os.environ.get("HUGGINGFACEHUB_API_TOKEN"),
)
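# Note: `client` is not used in the generation path below, which runs a local
# transformers pipeline. If remote inference were preferred instead, the
# OpenAI-compatible endpoint could be called roughly like this (a sketch,
# assuming the Inference API accepts the chat-completions format):
#   completion = client.chat.completions.create(
#       model="mistralai/Mistral-7B-Instruct-v0.2",
#       messages=[{"role": "user", "content": "Hello"}],
#   )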
api_token = os.getenv("HUGGINGFACEHUB_API_TOKEN")
if api_token:
    login(token=api_token)
else:
    print("API token is not set in the environment variables.")
# Define model links and configurations
model_links = {
    "HAH-2024-v0.1": "drmasad/HAH-2024-v0.11",
    "Mistral": "mistralai/Mistral-7B-Instruct-v0.2",
}

# Define sidebar options
models = list(model_links.keys())

# Sidebar model selection
selected_model = st.sidebar.selectbox("Select Model", models)

# Sidebar temperature control
temp_values = st.sidebar.slider("Select a temperature value", 0.0, 1.0, 0.5)
# Reset conversation functionality
def reset_conversation():
    st.session_state.conversation = []
    st.session_state.messages = []

st.sidebar.button("Reset Chat", on_click=reset_conversation)
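# Streamlit runs on_click callbacks before the script reruns, so the cleared
# chat history is reflected immediately on the next render.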
# Display model information on the sidebar
model_info = {
    "HAH-2024-v0.1": {
        "description": "HAH-2024-v0.1 is a fine-tuned model based on Mistral 7B. It's designed for conversations on diabetes.",
        "logo": "https://www.hmgaihub.com/untitled.png",
    },
    "Mistral": {
        "description": "Mistral is a large language model with multi-task capabilities.",
        "logo": "https://mistral.ai/images/logo_hubc88c4ece131b91c7cb753f40e9e1cc5_2589_256x0_resize_q97_h2_lanczos_3.webp",
    },
}

st.sidebar.write(f"You're now chatting with **{selected_model}**")
st.sidebar.markdown(model_info[selected_model]["description"])
st.sidebar.image(model_info[selected_model]["logo"])
# Load the appropriate model based on user selection.
# Cache the load so the 7B weights are not re-initialized on every Streamlit rerun.
@st.cache_resource
def load_model(selected_model_name):
    if selected_model_name == "HAH-2024-v0.1":
        # Setup for HAH-2024-v0.1
        model_name = model_links["HAH-2024-v0.1"]
        base_model = "mistralai/Mistral-7B-Instruct-v0.2"
        # Load model with quantization configuration
        bnb_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.bfloat16,
            bnb_4bit_use_double_quant=False,
        )
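        # The config above stores weights in 4-bit NF4 precision and runs the
        # matmuls in bfloat16, which keeps a 7B model within a single GPU's
        # memory; double quantization is left disabled.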
        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            quantization_config=bnb_config,
            torch_dtype=torch.bfloat16,
            device_map="auto",
            trust_remote_code=True,
        )
        model.config.use_cache = False
        model = prepare_model_for_kbit_training(model)
        peft_config = LoraConfig(
            lora_alpha=16,
            lora_dropout=0.1,
            r=64,
            bias="none",
            task_type="CAUSAL_LM",
            target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj"],
        )
        model = get_peft_model(model, peft_config)
        tokenizer = AutoTokenizer.from_pretrained(base_model, trust_remote_code=True)
    elif selected_model_name == "Mistral":
        # Setup for Mistral 7B: load in bfloat16 with automatic device
        # placement to keep memory usage manageable
        model = AutoModelForCausalLM.from_pretrained(
            model_links[selected_model_name],
            torch_dtype=torch.bfloat16,
            device_map="auto",
        )
        tokenizer = AutoTokenizer.from_pretrained(model_links[selected_model_name])
    return model, tokenizer
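# Note: PeftModel is imported above but never used. If drmasad/HAH-2024-v0.11
# were published as a LoRA adapter rather than merged weights (an assumption,
# not verified here), the usual pattern would be to attach the adapter to the
# base model instead of wrapping the loaded checkpoint with a fresh, untrained
# adapter via get_peft_model, e.g.:
#   base = AutoModelForCausalLM.from_pretrained(
#       base_model, quantization_config=bnb_config, device_map="auto"
#   )
#   model = PeftModel.from_pretrained(base, model_name)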
# Initialize chat history
if "messages" not in st.session_state:
    st.session_state.messages = []

# Load the selected model
model, tokenizer = load_model(selected_model)

st.subheader(f"AI - {selected_model}")
# Display previous chat messages
for message in st.session_state.messages:
    with st.chat_message(message["role"]):
        st.markdown(message["content"])
# User input for conversation
if prompt := st.chat_input("Ask a question"):
    # Display user input
    with st.chat_message("user"):
        st.markdown(prompt)

    # Store the user message
    st.session_state.messages.append({"role": "user", "content": prompt})

    # Generate the assistant's response
    with st.chat_message("assistant"):
        pipe = pipeline(
            task="text-generation",
            model=model,
            tokenizer=tokenizer,
            max_new_tokens=1024,
            temperature=temp_values,
            return_full_text=False,  # return only the newly generated tokens
        )
        # Mistral-Instruct expects the prompt wrapped as <s>[INST] ... [/INST]
        result = pipe(
            f"<s>[INST] {prompt} [/INST]",
            do_sample=temp_values > 0,  # greedy decoding when temperature is 0
        )
        response = result[0]["generated_text"].strip()
        st.markdown(response)

    # Store the assistant's response
    st.session_state.messages.append({"role": "assistant", "content": response})
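# Note: only the latest user prompt is sent to the model; earlier turns in
# st.session_state.messages are not included. One possible alternative (not in
# the original) is to let the tokenizer build the full conversation from its
# chat template, e.g.:
#   prompt_text = tokenizer.apply_chat_template(
#       st.session_state.messages, tokenize=False, add_generation_prompt=True
#   )
#   result = pipe(prompt_text, do_sample=temp_values > 0)
# (sketch only; this assumes the tokenizer ships a chat template, which the
# Mistral-7B-Instruct-v0.2 tokenizer does.)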