drmasad committed on
Commit 849fdd5 · verified · 1 Parent(s): 2b5d309

Update app.py

Files changed (1):
app.py +105 -81
app.py CHANGED
@@ -1,120 +1,144 @@
-""" HAH 2024 v0.1 Chatbot
-@author: Dr M As'ad
-@email: [email protected]
-"""
-
 import streamlit as st
 from openai import OpenAI
 import os
-import sys
-from dotenv import load_dotenv, dotenv_values

 load_dotenv()

-# initialize the client
-client = OpenAI(
-    base_url="https://api-inference.huggingface.co/v1",
-    api_key=os.environ.get('HUGGINGFACEHUB_API_TOKEN')
-)
-
-# Create supported models
-model_links = {
-    "HAH-2024-v0.1": "drmasad/HAH-2024-v0.11",
-    "Mistral": "mistralai/Mistral-7B-Instruct-v0.2",
 }

-model_info = {
-    "HAH-2024-v0.1":
-        {'description': """The HAH-2024-v0.1 model is a **Large Language Model (LLM)** that's able to have question and answer interactions.\n \
-        \nIt was created by fine-tuning Mistral 7B Instruct using 3000 review articles on diabetes.\n""",
-         'logo': 'https://www.hmgaihub.com/untitled.png'},
-    "Mistral":
-        {'description':
-            """The Mistral model is a **Large Language Model (LLM)** that's able to have question and answer interactions...""",
-         'logo':
-            'https://mistral.ai/images/logo_hubc88c4ece131b91c7cb753f40e9e1cc5_2589_256x0_resize_q97_h2_lanczos_3.webp'},
-}

-def reset_conversation():
-    '''
-    Resets Conversation
-    '''
-    st.session_state.conversation = []
-    st.session_state.messages = []
-    return None
-
-# Define the available models
-models = [key for key in model_links.keys()]

-# Create the sidebar with the dropdown for model selection
 selected_model = st.sidebar.selectbox("Select Model", models)

-# Create a temperature slider
-temp_values = st.sidebar.slider('Select a temperature value', 0.0, 1.0, (0.5))
-
-# Add reset button to clear conversation
-st.sidebar.button('Reset Chat', on_click=reset_conversation)  # Reset button

-# Create model description
 st.sidebar.write(f"You're now chatting with **{selected_model}**")
-st.sidebar.markdown(model_info[selected_model]['description'])
-st.sidebar.image(model_info[selected_model]['logo'])
-st.sidebar.markdown("*Generated content may be inaccurate or false.*")

-if "prev_option" not in st.session_state:
-    st.session_state.prev_option = selected_model

-if st.session_state.prev_option != selected_model:
-    st.session_state.messages = []
-    # st.write(f"Changed to {selected_model}")
-    st.session_state.prev_option = selected_model
-    reset_conversation()

-# Pull in the model we want to use
-repo_id = model_links[selected_model]

-st.subheader(f'AI - {selected_model}')
-# st.title(f'ChatBot Using {selected_model}')

-# Set a default model
-if selected_model not in st.session_state:
-    st.session_state[selected_model] = model_links[selected_model]

 # Initialize chat history
 if "messages" not in st.session_state:
     st.session_state.messages = []

-# Display chat messages from history on app rerun
 for message in st.session_state.messages:
     with st.chat_message(message["role"]):
         st.markdown(message["content"])

-# Accept user input
-if prompt := st.chat_input(f"Hi I'm {selected_model}, ask me a question"):
-
-    # Display user message in chat message container
     with st.chat_message("user"):
         st.markdown(prompt)
-    # Add user message to chat history
-    st.session_state.messages.append({"role": "user", "content": prompt})

-    # Display assistant response in chat message container
     with st.chat_message("assistant"):
-        stream = client.chat.completions.create(
-            model=model_links[selected_model],
-            messages=[
-                {"role": m["role"], "content": m["content"]}
-                for m in st.session_state.messages
-            ],
-            temperature=temp_values,  # 0.5
-            stream=True,
-            max_tokens=3000,
         )

-        response = st.write_stream(stream)
-        st.session_state.messages.append({"role": "assistant", "content": response})
 
+# Import necessary libraries
 import streamlit as st
+from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, pipeline
 from openai import OpenAI
 import os
+import torch
+from peft import LoraConfig, PeftModel, prepare_model_for_kbit_training, get_peft_model
+from dotenv import load_dotenv
+
+# Load environment variables
 load_dotenv()

+# Define model links and configurations
+model_links = {
+    "HAH-2024-v0.1": "drmasad/HAH-2024-v0.11",
+    "Mistral": "mistralai/Mistral-7B-Instruct-v0.2",
 }

+# Initialize the OpenAI client (if needed for the Hugging Face Inference API)
+client = OpenAI(
+    base_url="https://api-inference.huggingface.co/v1",
+    api_key=os.environ.get("HUGGINGFACEHUB_API_TOKEN"),
+)

+# Define sidebar options
+models = list(model_links.keys())

+# Sidebar model selection
 selected_model = st.sidebar.selectbox("Select Model", models)

+# Sidebar temperature control
+temp_values = st.sidebar.slider("Select a temperature value", 0.0, 1.0, 0.5)

+# Reset conversation functionality
+def reset_conversation():
+    st.session_state.conversation = []
+    st.session_state.messages = []

+st.sidebar.button("Reset Chat", on_click=reset_conversation)
+
+# Display model information on the sidebar
+model_info = {
+    "HAH-2024-v0.1": {
+        "description": "HAH-2024-v0.1 is a fine-tuned model based on Mistral 7B. It's designed for conversations on diabetes.",
+        "logo": "https://www.hmgaihub.com/untitled.png",
+    },
+    "Mistral": {
+        "description": "Mistral is a large language model with multi-task capabilities.",
+        "logo": "https://mistral.ai/images/logo_hubc88c4ece131b91c7cb753f40e9e1cc5_2589_256x0_resize_q97_h2_lanczos_3.webp",
+    },
+}

 st.sidebar.write(f"You're now chatting with **{selected_model}**")
+st.sidebar.markdown(model_info[selected_model]["description"])
+st.sidebar.image(model_info[selected_model]["logo"])
+
+# Load the appropriate model based on user selection
+def load_model(selected_model_name):
+    if selected_model_name == "HAH-2024-v0.1":
+        # Setup for HAH-2024-v0.1
+        model_name = model_links["HAH-2024-v0.1"]
+        base_model = "mistralai/Mistral-7B-Instruct-v0.2"
+
+        # Load model with 4-bit NF4 quantization configuration
+        bnb_config = BitsAndBytesConfig(
+            load_in_4bit=True,
+            bnb_4bit_quant_type="nf4",
+            bnb_4bit_compute_dtype=torch.bfloat16,
+            bnb_4bit_use_double_quant=False,
+        )

+        model = AutoModelForCausalLM.from_pretrained(
+            model_name,
+            quantization_config=bnb_config,
+            torch_dtype=torch.bfloat16,
+            device_map="auto",
+            trust_remote_code=True,
+        )

+        model.config.use_cache = False
+        model = prepare_model_for_kbit_training(model)

+        peft_config = LoraConfig(
+            lora_alpha=16,
+            lora_dropout=0.1,
+            r=64,
+            bias="none",
+            task_type="CAUSAL_LM",
+            target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj"],
+        )

+        model = get_peft_model(model, peft_config)

+        tokenizer = AutoTokenizer.from_pretrained(base_model, trust_remote_code=True)

+    elif selected_model_name == "Mistral":
+        # Setup for Mistral 7B
+        model = AutoModelForCausalLM.from_pretrained(
+            model_links[selected_model_name]
+        )
+        tokenizer = AutoTokenizer.from_pretrained(model_links[selected_model_name])

+    return model, tokenizer

 # Initialize chat history
 if "messages" not in st.session_state:
     st.session_state.messages = []

+# Load the selected model
+model, tokenizer = load_model(selected_model)

+st.subheader(f"AI - {selected_model}")
+
+# Display previous chat messages
 for message in st.session_state.messages:
     with st.chat_message(message["role"]):
         st.markdown(message["content"])

+# User input for conversation
+if prompt := st.chat_input("Ask a question"):
+    # Display user input
     with st.chat_message("user"):
         st.markdown(prompt)

+    # Store the user message
+    st.session_state.messages.append({"role": "user", "content": prompt})

+    # Generate the assistant's response
     with st.chat_message("assistant"):
+        pipe = pipeline(
+            task="text-generation",
+            model=model,
+            tokenizer=tokenizer,
+            max_length=1024,
+            temperature=temp_values,
         )
+
+        # Mistral-Instruct prompt format: the instruction is closed with [/INST]
+        result = pipe(f"<s>[INST] {prompt} [/INST]", do_sample=True)
+        response = result[0]["generated_text"]
+
+        st.markdown(response)

+    # Store the assistant's response
+    st.session_state.messages.append({"role": "assistant", "content": response})
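
One practical note on the rewritten inference path: Streamlit re-executes app.py from top to bottom on every interaction, so `load_model(selected_model)` as committed will reload and re-quantize the 7B weights on each message. A minimal sketch of one way to memoize the load with Streamlit's `st.cache_resource` decorator; the wrapper name `load_model_cached` is illustrative and not part of this commit:

import streamlit as st

@st.cache_resource  # keep the loaded model/tokenizer alive across reruns and sessions
def load_model_cached(selected_model_name: str):
    # Delegate to the load_model() defined above; the cache is keyed
    # on the model name, so each model is loaded at most once.
    return load_model(selected_model_name)

model, tokenizer = load_model_cached(selected_model)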