drmasad committed
Commit bc1d2e1 · verified · 1 Parent(s): 693c6e9

Update app.py

Files changed (1)
  1. app.py +84 -95
app.py CHANGED
@@ -5,44 +5,61 @@ Diabetes Version
 @email: [email protected]
 """
 
-# Import necessary libraries
 import streamlit as st
-from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, pipeline
 from openai import OpenAI
 import os
-import torch
-from peft import LoraConfig, PeftModel, prepare_model_for_kbit_training, get_peft_model
-from huggingface_hub import login
+import sys
+from dotenv import load_dotenv, dotenv_values
+load_dotenv()
 
-# Initialize the OpenAI client
+
+# initialize the client
 client = OpenAI(
     base_url="https://api-inference.huggingface.co/v1",
-    api_key=os.environ.get("HUGGINGFACEHUB_API_TOKEN"),
+    # "hf_xxx" # Replace with your token
+    api_key=os.environ.get('HUGGINGFACEHUB_API_TOKEN')
 )
 
-api_token = os.getenv("HUGGINGFACEHUB_API_TOKEN")
-if api_token:
-    login(token=api_token)
-else:
-    st.error("API token is not set in the environment variables.")
-
-# Define model links
+# Create supported models
 model_links = {
-    "HAH-2024-v0.1": "drmasad/HAH-2024-v0.11"
+    "HAH v0.1": "drmasad/HAH-2024-v0.11",
+    "Mistral": "mistralai/Mistral-7B-Instruct-v0.2",
 }
 
-# Set selected model
-selected_model = "HAH-2024-v0.1"
+# Pull info about the model to display
+model_info = {
+    "HAH v0.1":
+        {'description': """HAH 0.1 is a fine tuned model based on Mistral 7b instruct.\n \
+        \nIt was created by Dr M. As'ad using 250k dB rows sourced from open source articles on diabetes** \n""",
+         'logo': 'https://www.hmgaihub.com/untitled.png'},
+    "Mistral":
+        {'description': """The Mistral model is a **Large Language Model (LLM)** that's able to have question and answer interactions.\n \
+        \nIt was created by the [**Mistral AI**](https://mistral.ai/news/announcing-mistral-7b/) team as has over **7 billion parameters.** \n""",
+         'logo': 'https://mistral.ai/images/logo_hubc88c4ece131b91c7cb753f40e9e1cc5_2589_256x0_resize_q97_h2_lanczos_3.webp'},
+
+}
 
-# Display welcome message
-st.title("Welcome to HAH-2024-v0.1")
 
-# Sidebar setup
-temp_values = st.sidebar.slider("Select a temperature value", 0.0, 1.0, (0.5))
 def reset_conversation():
+    '''
+    Resets Conversation
+    '''
     st.session_state.conversation = []
     st.session_state.messages = []
+    return None
+
+
+# Define the available models
+models = [key for key in model_links.keys()]
+
+# Create the sidebar with the dropdown for model selection
+selected_model = st.sidebar.selectbox("Select Model", models)
 
+# Create a temperature slider
+temp_values = st.sidebar.slider('Select a temperature value', 0.0, 1.0, (0.5))
+
+
+# Create model description
 st.sidebar.button("Reset Chat", on_click=reset_conversation)
 st.sidebar.write(f"You're now chatting with **{selected_model}**")
 st.sidebar.image("https://www.hmgaihub.com/untitled.png")
@@ -50,89 +67,61 @@ st.sidebar.markdown("*Generated content may be inaccurate or false.*")
 st.sidebar.markdown("*This is an under development project.*")
 st.sidebar.markdown("*Not a replacement for medical advice from a doctor.*")
 
-# Loading message placeholder
-loading_message = st.empty()
-
-def load_model(selected_model_name):
-    loading_message.info("Loading the model, please wait...")
-    model_name = model_links[selected_model_name]
-
-    # Define the BitsAndBytesConfig for quantization
-    bnb_config = BitsAndBytesConfig(
-        load_in_4bit=True,
-        bnb_4bit_quant_type="nf4",
-        bnb_4bit_compute_dtype=torch.bfloat16,
-        bnb_4bit_use_double_quant=False,
-        llm_int8_enable_fp32_cpu_offload=True,
-    )
-
-    # Load the model with quantization settings directly applied
-    model = AutoModelForCausalLM.from_pretrained(
-        model_name,
-        quantization_config=bnb_config,
-        trust_remote_code=True,
-    )
-
-    # Additional configurations and training enhancements
-    model.config.use_cache = False
-    model = prepare_model_for_kbit_training(model)
-
-    # If using PEFT or other enhancements, configure here
-    peft_config = LoraConfig(
-        lora_alpha=16,
-        lora_dropout=0.1,
-        r=64,
-        bias="none",
-        task_type="CAUSAL_LM",
-        target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj"],
-    )
-    model = get_peft_model(model, peft_config)
-
-    tokenizer = AutoTokenizer.from_pretrained(
-        "mistralai/Mistral-7B-Instruct-v0.2", trust_remote_code=True
-    )
-
-    # Clear the loading message
-    loading_message.success("Model is ready. Now we are ready!")
-
-    return model, tokenizer
-
-
-# Load model and tokenizer
-model, tokenizer = load_model(selected_model)
-
-# Chat application logic
+
+if "prev_option" not in st.session_state:
+    st.session_state.prev_option = selected_model
+
+if st.session_state.prev_option != selected_model:
+    st.session_state.messages = []
+    # st.write(f"Changed to {selected_model}")
+    st.session_state.prev_option = selected_model
+    reset_conversation()
+
+
+# Pull in the model we want to use
+repo_id = model_links[selected_model]
+
+
+st.subheader(f'AI - {selected_model}')
+# st.title(f'ChatBot Using {selected_model}')
+
+# Set a default model
+if selected_model not in st.session_state:
+    st.session_state[selected_model] = model_links[selected_model]
+
+# Initialize chat history
 if "messages" not in st.session_state:
     st.session_state.messages = []
 
+
+# Display chat messages from history on app rerun
 for message in st.session_state.messages:
     with st.chat_message(message["role"]):
         st.markdown(message["content"])
 
-if prompt := st.chat_input("Ask me anything about diabetes"):
+
+# Accept user input
+if prompt := st.chat_input(f"Hi I'm {selected_model}, ask me a question"):
+
+    # Display user message in chat message container
     with st.chat_message("user"):
         st.markdown(prompt)
-
+    # Add user message to chat history
     st.session_state.messages.append({"role": "user", "content": prompt})
 
-    instructions = """
-    Act as a highly knowledgeable doctor with special interest in diabetes, skilled at explaining complex medical information in a way that is easy to understand for patients without a medical background. Your responses should not only demonstrate empathy and care but also uphold a high standard of medical accuracy and reliability. Respond precisely to what the patient needs in a professional, accurate, and reassuring manner, avoiding any unnecessary information.
-    """
-
-    full_prompt = f"<s>[INST] {prompt} [/INST] {instructions}</s>"
-
+    # Display assistant response in chat message container
     with st.chat_message("assistant"):
-        result = pipeline(
-            task="text-generation",
-            model=model,
-            tokenizer=tokenizer,
-            max_length=1024,
-            temperature=temp_values
-        )(full_prompt)
-
-        generated_text = result[0]['generated_text']
-        response = generated_text.split("</s>")[-1].strip()
-
-        st.markdown(response)
-
-        st.session_state.messages.append({"role": "assistant", "content": response})
+        stream = client.chat.completions.create(
+            model=model_links[selected_model],
+            messages=[
+                {"role": m["role"], "content": m["content"]}
+                for m in st.session_state.messages
+            ],
+            temperature=temp_values,  # 0.5,
+            stream=True,
+            max_tokens=3000,
+        )
+
+        response = st.write_stream(stream)
+        st.session_state.messages.append(
+            {"role": "assistant", "content": response})
 
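Note on configuration: the new revision drops the huggingface_hub.login() flow and instead loads the token from the environment via python-dotenv. For a local run this assumes a .env file next to app.py along these lines (hf_xxx is a placeholder, echoed from the commit's own comment):

    # .env (read by load_dotenv() at startup; keep it out of version control)
    HUGGINGFACEHUB_API_TOKEN=hf_xxx

With that in place, os.environ.get('HUGGINGFACEHUB_API_TOKEN') returns the token the OpenAI-compatible client presents to https://api-inference.huggingface.co/v1.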