# Import necessary libraries
import streamlit as st
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, pipeline
from openai import OpenAI
import os
import torch
from peft import PeftModel  # only needed if the fine-tuned weights ship as a LoRA adapter (see sketch below)
from huggingface_hub import login


# Initialize the OpenAI client (if needed for Hugging Face Inference API)
client = OpenAI(
    base_url="https://api-inference.huggingface.co/v1",
    api_key=os.environ.get("HUGGINGFACEHUB_API_TOKEN", ""),  # empty fallback: the client rejects a missing (None) key
)
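# Note: this OpenAI-compatible client targets the Hub's Inference API and is kept
# only as an optional remote-inference path; the replies below are generated with
# a locally loaded pipeline instead.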

api_token = os.getenv("HUGGINGFACEHUB_API_TOKEN")
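# Logging in authenticates model downloads, which matters if any of the model
# repos used below are private or gated on the Hub.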

if api_token:
    login(token=api_token)
else:
    print("API token is not set in the environment variables.")

# Define model links and configurations
model_links = {
    "HAH-2024-v0.1": "drmasad/HAH-2024-v0.11",
    "Mistral": "mistralai/Mistral-7B-Instruct-v0.2",
}
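# Keys are the display names shown in the sidebar; values are the Hugging Face repo ids.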

# Define sidebar options
models = list(model_links.keys())

# Sidebar model selection
selected_model = st.sidebar.selectbox("Select Model", models)

# Sidebar temperature control
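# Temperature controls sampling randomness: lower values give more focused,
# repeatable replies, higher values more varied ones. The minimum stays above
# zero because sampling requires a strictly positive temperature.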
temp_values = st.sidebar.slider("Select a temperature value", 0.1, 1.0, 0.5)

# Reset conversation functionality
def reset_conversation():
    st.session_state.conversation = []
    st.session_state.messages = []

st.sidebar.button("Reset Chat", on_click=reset_conversation)

# Display model information on the sidebar
model_info = {
    "HAH-2024-v0.1": {
        "description": "HAH-2024-v0.1 is a fine-tuned model based on Mistral 7B. It's designed for conversations on diabetes.",
        "logo": "https://www.hmgaihub.com/untitled.png",
    },
    "Mistral": {
        "description": "Mistral is a large language model with multi-task capabilities.",
        "logo": "https://mistral.ai/images/logo_hubc88c4ece131b91c7cb753f40e9e1cc5_2589_256x0_resize_q97_h2_lanczos_3.webp",
    },
}

st.sidebar.write(f"You're now chatting with **{selected_model}**")
st.sidebar.markdown(model_info[selected_model]["description"])
st.sidebar.image(model_info[selected_model]["logo"])

# Load the appropriate model based on user selection.
# st.cache_resource keeps the loaded model in memory across Streamlit reruns,
# so the weights are downloaded and initialized only once per process.
@st.cache_resource
def load_model(selected_model_name):
    if selected_model_name == "HAH-2024-v0.1":
        # Setup for HAH-2024-v0.1
        model_name = model_links["HAH-2024-v0.1"]
        base_model = "mistralai/Mistral-7B-Instruct-v0.2"

        # Load model with quantization configuration
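        # NF4 4-bit quantization keeps the 7B weights small enough for a single
        # ~16 GB GPU, while computation runs in bfloat16 (assumes a CUDA GPU with
        # bf16 support).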
        bnb_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.bfloat16,
            bnb_4bit_use_double_quant=False,
        )

        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            quantization_config=bnb_config,
            torch_dtype=torch.bfloat16,
            device_map="auto",
            trust_remote_code=True,
        )
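        # If the fine-tuned weights were published as a LoRA adapter rather than a
        # merged checkpoint, they would instead be applied on top of the quantized
        # base model, roughly (sketch, assuming an adapter-only repo):
        #     model = AutoModelForCausalLM.from_pretrained(
        #         base_model, quantization_config=bnb_config, device_map="auto"
        #     )
        #     model = PeftModel.from_pretrained(model, model_name)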

        # Keep the KV cache enabled for generation; disabling it and attaching a
        # fresh LoRA adapter is only needed when fine-tuning, not for inference.
        model.config.use_cache = True

        tokenizer = AutoTokenizer.from_pretrained(base_model, trust_remote_code=True)

    elif selected_model_name == "Mistral":
        # Setup for the base Mistral 7B Instruct model
        model = AutoModelForCausalLM.from_pretrained(
            model_links[selected_model_name],
            torch_dtype=torch.bfloat16,
            device_map="auto",
        )
        tokenizer = AutoTokenizer.from_pretrained(model_links[selected_model_name])

    else:
        raise ValueError(f"Unknown model: {selected_model_name}")

    return model, tokenizer

# Initialize chat history
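# Streamlit reruns the whole script on every interaction, so the chat history is
# kept in st.session_state to survive reruns.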
if "messages" not in st.session_state:
    st.session_state.messages = []

# Load the selected model
model, tokenizer = load_model(selected_model)

st.subheader(f"AI - {selected_model}")

# Display previous chat messages
for message in st.session_state.messages:
    with st.chat_message(message["role"]):
        st.markdown(message["content"])

# User input for conversation
if prompt := st.chat_input("Ask a question"):
    # Display user input
    with st.chat_message("user"):
        st.markdown(prompt)

    # Store the user message
    st.session_state.messages.append({"role": "user", "content": prompt})

    # Generate the assistant's response
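    # Wrapping the already-loaded model in a pipeline is cheap (no weights are
    # reloaded), so rebuilding it per message keeps the code simple.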
    with st.chat_message("assistant"):
        pipe = pipeline(
            task="text-generation",
            model=model,
            tokenizer=tokenizer,
            max_new_tokens=1024,
            temperature=temp_values,
            do_sample=True,
            return_full_text=False,  # return only the newly generated text, not the prompt
        )

        # Mistral-Instruct expects the prompt wrapped in [INST] ... [/INST] tags.
        result = pipe(f"<s>[INST] {prompt} [/INST]")
        response = result[0]["generated_text"].strip()

        st.markdown(response)

    # Store the assistant's response
    st.session_state.messages.append({"role": "assistant", "content": response})