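"""Streamlit chat app for the HAH-2024-v0.1 assistant.

Loads drmasad/HAH-2024-v0.11 in 4-bit (NF4) quantization, wraps it with a LoRA
adapter configuration, pairs it with the Mistral-7B-Instruct tokenizer, and
serves a simple chat interface for questions about diabetes.
"""
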
# Import necessary libraries
import streamlit as st
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, pipeline
from openai import OpenAI
import os
import torch
from peft import LoraConfig, PeftModel, prepare_model_for_kbit_training, get_peft_model
from huggingface_hub import login

# Initialize the OpenAI client
client = OpenAI(
    base_url="https://api-inference.huggingface.co/v1",
    api_key=os.environ.get("HUGGINGFACEHUB_API_TOKEN"),
)
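
# Note: `client` is not used elsewhere in this script. As a sketch, the model could
# instead be queried remotely through this OpenAI-compatible endpoint, e.g.:
#   completion = client.chat.completions.create(
#       model="mistralai/Mistral-7B-Instruct-v0.2",  # hypothetical repo id, for illustration only
#       messages=[{"role": "user", "content": "..."}],
#   )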

api_token = os.getenv("HUGGINGFACEHUB_API_TOKEN")
if api_token:
    login(token=api_token)
else:
    st.warning("HUGGINGFACEHUB_API_TOKEN is not set in the environment variables.")

# Map display names to Hugging Face Hub repo ids
model_links = {
    "HAH-2024-v0.1": "drmasad/HAH-2024-v0.11"
}

# Set selected model
selected_model = "HAH-2024-v0.1"

# Sidebar setup
temp_values = st.sidebar.slider("Select a temperature value", 0.0, 1.0, 0.5)

def reset_conversation():
    st.session_state.conversation = []
    st.session_state.messages = []

st.sidebar.button("Reset Chat", on_click=reset_conversation)
st.sidebar.write(f"You're now chatting with **{selected_model}**")
st.sidebar.image("https://www.hmgaihub.com/untitled.png")

# Cache the loaded model so it is built once per process rather than on every Streamlit rerun
@st.cache_resource
def load_model(selected_model_name):
    model_name = model_links[selected_model_name]

    # Set a specific device
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # 4-bit NF4 quantization config (bitsandbytes), with fp32 CPU offload allowed
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16,
        bnb_4bit_use_double_quant=False,
        llm_int8_enable_fp32_cpu_offload=True,
    )

    device_map = {"": device}  # Default device for all components

    # Load model with proper device mapping
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        quantization_config=bnb_config,
        device_map=device_map,  # Assign device
        trust_remote_code=True,
    )

    model.config.use_cache = False
    model = prepare_model_for_kbit_training(model)

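    # LoRA adapter settings: rank 64, alpha 16, targeting the attention projections and gate_proj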
    peft_config = LoraConfig(
        lora_alpha=16,
        lora_dropout=0.1,
        r=64,
        bias="none",
        task_type="CAUSAL_LM",
        target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj"],
    )

    model = get_peft_model(model, peft_config)

    # The tokenizer is taken from the base Mistral-7B-Instruct repo (assumed to match
    # the fine-tuned checkpoint's vocabulary)
    tokenizer = AutoTokenizer.from_pretrained(
        "mistralai/Mistral-7B-Instruct-v0.2", trust_remote_code=True
    )

    return model, tokenizer


# Load model and tokenizer
model, tokenizer = load_model(selected_model)

# Chat application logic
if "messages" not in st.session_state:
    st.session_state.messages = []

for message in st.session_state.messages:
    with st.chat_message(message["role"]):
        st.markdown(message["content"])

if prompt := st.chat_input("Ask me anything about diabetes"):
    with st.chat_message("user"):
        st.markdown(prompt)
    st.session_state.messages.append({"role": "user", "content": prompt})

    with st.chat_message("assistant"):
        result = pipeline(
            task="text-generation",
            model=model,
            tokenizer=tokenizer,
            max_length=1024,
            temperature=temp_values
        )(prompt)
        response = result[0]['generated_text']
        st.markdown(response)
    st.session_state.messages.append({"role": "assistant", "content": response})
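
# To launch locally (assuming this file is saved as app.py):
#   streamlit run app.py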