# HAH-2024-v0.1 / app.py
# Import necessary libraries
import streamlit as st
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, pipeline
from openai import OpenAI
import os
import torch
from peft import LoraConfig, PeftModel, prepare_model_for_kbit_training, get_peft_model
from huggingface_hub import login
# Initialize the OpenAI client
client = OpenAI(
    base_url="https://api-inference.huggingface.co/v1",
    api_key=os.environ.get("HUGGINGFACEHUB_API_TOKEN"),
)
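# Note: `client` points at the Hugging Face Inference API through its OpenAI-compatible
# endpoint, but it is not used by the chat loop below (generation runs locally via a
# transformers pipeline). A hedged sketch of a remote call, if one were wanted
# (model id and message are illustrative):
#
#     completion = client.chat.completions.create(
#         model="drmasad/HAH-2024-v0.11",
#         messages=[{"role": "user", "content": "What is HbA1c?"}],
#         temperature=0.5,
#     )
#     print(completion.choices[0].message.content)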
api_token = os.getenv("HUGGINGFACEHUB_API_TOKEN")
if api_token:
    login(token=api_token)
else:
    print("API token is not set in the environment variables.")
# Define model links
model_links = {
    "HAH-2024-v0.1": "drmasad/HAH-2024-v0.11",
}
# Set selected model
selected_model = "HAH-2024-v0.1"
# Sidebar setup
temp_values = st.sidebar.slider("Select a temperature value", 0.0, 1.0, 0.5)
def reset_conversation():
    st.session_state.conversation = []
    st.session_state.messages = []
st.sidebar.button("Reset Chat", on_click=reset_conversation)
st.sidebar.write(f"You're now chatting with **{selected_model}**")
st.sidebar.image("https://www.hmgaihub.com/untitled.png")
@st.cache_resource  # cache the model so it is loaded once, not on every Streamlit rerun
def load_model(selected_model_name):
    model_name = model_links[selected_model_name]

    # Pick a specific device
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # 4-bit NF4 quantization config (with fp32 CPU offload allowed)
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16,
        bnb_4bit_use_double_quant=False,
        llm_int8_enable_fp32_cpu_offload=True,
    )

    device_map = {"": device}  # place all components on the selected device

    # Load the quantized model with the device mapping above
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        quantization_config=bnb_config,
        device_map=device_map,
        trust_remote_code=True,
    )

    # Disable the KV cache and prepare the quantized model for k-bit (LoRA) use
    model.config.use_cache = False
    model = prepare_model_for_kbit_training(model)

    # Attach a LoRA adapter configuration
    peft_config = LoraConfig(
        lora_alpha=16,
        lora_dropout=0.1,
        r=64,
        bias="none",
        task_type="CAUSAL_LM",
        target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj"],
    )
    model = get_peft_model(model, peft_config)

    # Tokenizer comes from the base Mistral-7B-Instruct model
    tokenizer = AutoTokenizer.from_pretrained(
        "mistralai/Mistral-7B-Instruct-v0.2", trust_remote_code=True
    )

    return model, tokenizer
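# Note: `get_peft_model` above attaches a freshly initialized LoRA adapter on top of the
# loaded weights. If `drmasad/HAH-2024-v0.11` were an adapter-only repository, a hedged
# alternative inside load_model (using the already-imported PeftModel) would be to load
# the trained adapter weights instead:
#
#     model = PeftModel.from_pretrained(model, model_name)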
# Load model and tokenizer
model, tokenizer = load_model(selected_model)
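# Note: Mistral-7B-Instruct expects its [INST] chat format. The chat loop below passes the
# raw prompt string to the pipeline; a hedged sketch of formatting a turn with the
# tokenizer's chat template instead (example message is illustrative):
#
#     chat = [{"role": "user", "content": "What is HbA1c?"}]
#     formatted = tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=True)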
# Chat application logic
if "messages" not in st.session_state:
st.session_state.messages = []
for message in st.session_state.messages:
with st.chat_message(message["role"]):
st.markdown(message["content"])
if prompt := st.chat_input("Ask me anything about diabetes"):
with st.chat_message("user"):
st.markdown(prompt)
st.session_state.messages.append({"role": "user", "content": prompt})
with st.chat_message("assistant"):
result = pipeline(
task="text-generation",
model=model,
tokenizer=tokenizer,
max_length=1024,
temperature=temp_values
)(prompt)
response = result[0]['generated_text']
st.markdown(response)
st.session_state.messages.append({"role": "assistant", "content": response})