# HAH-2024-v0.1 / app.py
# Import necessary libraries
import streamlit as st
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, pipeline
from openai import OpenAI
import os
import torch
from peft import LoraConfig, PeftModel, prepare_model_for_kbit_training, get_peft_model
from huggingface_hub import login
# Initialize the OpenAI client (if needed for Hugging Face Inference API)
client = OpenAI(
    base_url="https://api-inference.huggingface.co/v1",
    api_key=os.environ.get("HUGGINGFACEHUB_API_TOKEN"),
)
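# The OpenAI-compatible client above is not used further in this script; it is kept for
# optional remote inference through the Hugging Face Inference API. A hypothetical usage
# sketch (not exercised here):
#   completion = client.chat.completions.create(
#       model="mistralai/Mistral-7B-Instruct-v0.2",
#       messages=[{"role": "user", "content": "Hello"}],
#   )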
api_token = os.getenv("HUGGINGFACEHUB_API_TOKEN")
if api_token:
    login(token=api_token)
else:
    print("API token is not set in the environment variables.")
# Define model links and configurations
model_links = {
    "HAH-2024-v0.1": "drmasad/HAH-2024-v0.11",
    "Mistral": "mistralai/Mistral-7B-Instruct-v0.2",
}
# Define sidebar options
models = list(model_links.keys())
# Sidebar model selection
selected_model = st.sidebar.selectbox("Select Model", models)
# Sidebar temperature control
temp_values = st.sidebar.slider("Select a temperature value", 0.0, 1.0, 0.5)
# Reset conversation functionality
def reset_conversation():
    st.session_state.conversation = []
    st.session_state.messages = []


st.sidebar.button("Reset Chat", on_click=reset_conversation)
# Display model information on the sidebar
model_info = {
    "HAH-2024-v0.1": {
        "description": "HAH-2024-v0.1 is a fine-tuned model based on Mistral 7B. It's designed for conversations on diabetes.",
        "logo": "https://www.hmgaihub.com/untitled.png",
    },
    "Mistral": {
        "description": "Mistral is a large language model with multi-task capabilities.",
        "logo": "https://mistral.ai/images/logo_hubc88c4ece131b91c7cb753f40e9e1cc5_2589_256x0_resize_q97_h2_lanczos_3.webp",
    },
}
st.sidebar.write(f"You're now chatting with **{selected_model}**")
st.sidebar.markdown(model_info[selected_model]["description"])
st.sidebar.image(model_info[selected_model]["logo"])
# Load the appropriate model based on user selection.
# st.cache_resource keeps the loaded model/tokenizer across Streamlit reruns,
# so the multi-GB checkpoint is not reloaded on every user interaction.
@st.cache_resource
def load_model(selected_model_name):
    if selected_model_name == "HAH-2024-v0.1":
        # Setup for HAH-2024-v0.1
        model_name = model_links["HAH-2024-v0.1"]
        base_model = "mistralai/Mistral-7B-Instruct-v0.2"
        # Load model with 4-bit quantization configuration
        bnb_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.bfloat16,
            bnb_4bit_use_double_quant=False,
        )
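        # Rough arithmetic: a 7B-parameter model needs about 14 GB for weights in
        # bf16 but only about 3.5 GB at 4 bits, which is why NF4 quantization with
        # bf16 compute is used here to fit the model on a single GPU.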
        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            quantization_config=bnb_config,
            torch_dtype=torch.bfloat16,
            device_map="auto",
            trust_remote_code=True,
        )
        model.config.use_cache = False
        model = prepare_model_for_kbit_training(model)
        peft_config = LoraConfig(
            lora_alpha=16,
            lora_dropout=0.1,
            r=64,
            bias="none",
            task_type="CAUSAL_LM",
            target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj"],
        )
        model = get_peft_model(model, peft_config)
        tokenizer = AutoTokenizer.from_pretrained(base_model, trust_remote_code=True)
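        # Note: get_peft_model attaches a fresh (untrained) LoRA adapter on top of the
        # already fine-tuned checkpoint, mirroring the training setup. If the fine-tuned
        # weights were instead published as a standalone LoRA adapter, they could be
        # attached via the (currently unused) PeftModel import, e.g., assuming the repo
        # holds adapter weights:
        #   model = PeftModel.from_pretrained(model, "drmasad/HAH-2024-v0.11")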
    elif selected_model_name == "Mistral":
        # Setup for Mistral 7B
        model = AutoModelForCausalLM.from_pretrained(
            model_links[selected_model_name]
        )
        tokenizer = AutoTokenizer.from_pretrained(model_links[selected_model_name])
    return model, tokenizer
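# Usage note: because load_model is wrapped in st.cache_resource above, each checkpoint
# is loaded at most once per server process; switching models in the sidebar reuses the
# cached (model, tokenizer) pair rather than reloading it.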
# Initialize chat history
if "messages" not in st.session_state:
    st.session_state.messages = []
# Load the selected model
model, tokenizer = load_model(selected_model)
st.subheader(f"AI - {selected_model}")
# Display previous chat messages
for message in st.session_state.messages:
    with st.chat_message(message["role"]):
        st.markdown(message["content"])
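# Streamlit re-executes this script top to bottom on every interaction, so the loop
# above redraws the stored history before any new input is handled.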
# User input for conversation
if prompt := st.chat_input("Ask a question"):
    # Display user input
    with st.chat_message("user"):
        st.markdown(prompt)
    # Store the user message
    st.session_state.messages.append({"role": "user", "content": prompt})
    # Generate the assistant's response
    with st.chat_message("assistant"):
        pipe = pipeline(
            task="text-generation",
            model=model,
            tokenizer=tokenizer,
            max_length=1024,
            temperature=temp_values,
        )
        # Use the Mistral-Instruct prompt format ([INST] ... [/INST]) and return only
        # the newly generated text rather than echoing the prompt back.
        result = pipe(
            f"<s>[INST] {prompt} [/INST]",
            do_sample=True,
            return_full_text=False,
        )
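        # Alternative sketch: the tokenizer's chat template could build the prompt
        # instead of hard-coding the [INST] tags (assuming the tokenizer ships a chat
        # template, as Mistral-Instruct tokenizers do):
        #   prompt_text = tokenizer.apply_chat_template(
        #       [{"role": "user", "content": prompt}],
        #       tokenize=False,
        #       add_generation_prompt=True,
        #   )
        #   result = pipe(prompt_text, do_sample=True, return_full_text=False)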
        response = result[0]["generated_text"]
        st.markdown(response)
    # Store the assistant's response
    st.session_state.messages.append({"role": "assistant", "content": response})