# Import necessary libraries
import os

import streamlit as st
import torch
from dotenv import load_dotenv
from huggingface_hub import login
from openai import OpenAI
from peft import LoraConfig, PeftModel, prepare_model_for_kbit_training, get_peft_model
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, pipeline

# Load environment variables from a local .env file
load_dotenv()

# Initialize the OpenAI client (if needed for the Hugging Face Inference API)
client = OpenAI(
    base_url="https://api-inference.huggingface.co/v1",
    api_key=os.environ.get("HUGGINGFACEHUB_API_TOKEN"),
)

# Log in to the Hugging Face Hub if a token is available
api_token = os.getenv("HUGGINGFACEHUB_API_TOKEN")
if api_token:
    login(token=api_token)
else:
    print("API token is not set in the environment variables.")

# Define model links and configurations
model_links = {
    "HAH-2024-v0.1": "drmasad/HAH-2024-v0.11",
    "Mistral": "mistralai/Mistral-7B-Instruct-v0.2",
}

# Define sidebar options
models = list(model_links.keys())

# Sidebar model selection
selected_model = st.sidebar.selectbox("Select Model", models)

# Sidebar temperature control
temp_values = st.sidebar.slider("Select a temperature value", 0.0, 1.0, 0.5)

# Reset conversation functionality
def reset_conversation():
    st.session_state.conversation = []
    st.session_state.messages = []

st.sidebar.button("Reset Chat", on_click=reset_conversation)

# Display model information on the sidebar
model_info = {
    "HAH-2024-v0.1": {
        "description": "HAH-2024-v0.1 is a fine-tuned model based on Mistral 7B. It is designed for conversations about diabetes.",
        "logo": "https://www.hmgaihub.com/untitled.png",
    },
    "Mistral": {
        "description": "Mistral is a large language model with multi-task capabilities.",
        "logo": "https://mistral.ai/images/logo_hubc88c4ece131b91c7cb753f40e9e1cc5_2589_256x0_resize_q97_h2_lanczos_3.webp",
    },
}

st.sidebar.write(f"You're now chatting with **{selected_model}**")
st.sidebar.markdown(model_info[selected_model]["description"])
st.sidebar.image(model_info[selected_model]["logo"])

# Load the appropriate model based on user selection.
# Cached so the model is loaded once per selection rather than on every rerun.
@st.cache_resource
def load_model(selected_model_name):
    if selected_model_name == "HAH-2024-v0.1":
        # Setup for HAH-2024-v0.1
        model_name = model_links["HAH-2024-v0.1"]
        base_model = "mistralai/Mistral-7B-Instruct-v0.2"

        # Load the model with a 4-bit quantization configuration
        bnb_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.bfloat16,
            bnb_4bit_use_double_quant=False,
        )
        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            quantization_config=bnb_config,
            torch_dtype=torch.bfloat16,
            device_map="auto",
            trust_remote_code=True,
        )
        model.config.use_cache = False
        model = prepare_model_for_kbit_training(model)
        peft_config = LoraConfig(
            lora_alpha=16,
            lora_dropout=0.1,
            r=64,
            bias="none",
            task_type="CAUSAL_LM",
            target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj"],
        )
        model = get_peft_model(model, peft_config)
        tokenizer = AutoTokenizer.from_pretrained(base_model, trust_remote_code=True)
    elif selected_model_name == "Mistral":
        # Setup for Mistral 7B
        model = AutoModelForCausalLM.from_pretrained(model_links[selected_model_name])
        tokenizer = AutoTokenizer.from_pretrained(model_links[selected_model_name])
    return model, tokenizer

# Initialize chat history
if "messages" not in st.session_state:
    st.session_state.messages = []

# Load the selected model
model, tokenizer = load_model(selected_model)

st.subheader(f"AI - {selected_model}")

# Display previous chat messages
for message in st.session_state.messages:
    with st.chat_message(message["role"]):
        st.markdown(message["content"])

# User input for conversation
if prompt := st.chat_input("Ask a question"):
    # Display user input
    with st.chat_message("user"):
        st.markdown(prompt)

    # Store the user message
    st.session_state.messages.append({"role": "user", "content": prompt})

    # Generate the assistant's response
    with st.chat_message("assistant"):
        pipe = pipeline(
            task="text-generation",
            model=model,
            tokenizer=tokenizer,
            max_length=1024,
            temperature=temp_values,
        )
        # Wrap the prompt in Mistral's instruction tags and return only the
        # newly generated text rather than the echoed prompt.
        result = pipe(
            f"[INST] {prompt} [/INST]",
            do_sample=True,
            return_full_text=False,
        )
        response = result[0]["generated_text"]
        st.markdown(response)

    # Store the assistant's response
    st.session_state.messages.append({"role": "assistant", "content": response})
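# Usage (a minimal sketch; the filename app.py and the environment setup below
# are assumptions, adjust them to your deployment):
#
#   pip install streamlit transformers peft bitsandbytes torch openai python-dotenv huggingface_hub
#   HUGGINGFACEHUB_API_TOKEN=<your token> streamlit run app.py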