drmasad committed on
Commit ee59722 · verified · 1 Parent(s): 5fd79f3

Update app.py

Files changed (1)
  1. app.py +48 -76
app.py CHANGED
@@ -7,7 +7,6 @@ import torch
 from peft import LoraConfig, PeftModel, prepare_model_for_kbit_training, get_peft_model
 from huggingface_hub import login
 
-
 # Initialize the OpenAI client (if needed for Hugging Face Inference API)
 client = OpenAI(
     base_url="https://api-inference.huggingface.co/v1",
@@ -15,7 +14,6 @@ client = OpenAI(
 )
 
 api_token = os.getenv("HUGGINGFACEHUB_API_TOKEN")
-
 if api_token:
     login(token=api_token)
 else:
@@ -23,15 +21,11 @@ else:
 
 # Define model links and configurations
 model_links = {
-    "HAH-2024-v0.1": "drmasad/HAH-2024-v0.11",
-    "Mistral": "mistralai/Mistral-7B-Instruct-v0.2",
+    "HAH-2024-v0.1": "drmasad/HAH-2024-v0.11"
 }
 
 # Define sidebar options
-models = list(model_links.keys())
-
-# Sidebar model selection
-selected_model = st.sidebar.selectbox("Select Model", models)
+selected_model = "HAH-2024-v0.1" # Directly using your model
 
 # Sidebar temperature control
 temp_values = st.sidebar.slider("Select a temperature value", 0.0, 1.0, (0.5))
@@ -48,102 +42,80 @@ model_info = {
     "HAH-2024-v0.1": {
         "description": "HAH-2024-v0.1 is a fine-tuned model based on Mistral 7B. It's designed for conversations on diabetes.",
         "logo": "https://www.hmgaihub.com/untitled.png",
-    },
-    "Mistral": {
-        "description": "Mistral is a large language model with multi-task capabilities.",
-        "logo": "https://mistral.ai/images/logo_hubc88c4ece131b91c7cb753f40e9e1cc5_2589_256x0_resize_q97_h2_lanczos_3.webp",
-    },
+    }
 }
 
 st.sidebar.write(f"You're now chatting with **{selected_model}**")
 st.sidebar.markdown(model_info[selected_model]["description"])
 st.sidebar.image(model_info[selected_model]["logo"])
 
-# Load the appropriate model based on user selection
-def load_model(selected_model_name):
-    if selected_model_name == "HAH-2024-v0.1":
-        # Setup for HAH-2024-v0.1
-        model_name = model_links["HAH-2024-v0.1"]
-        base_model = "mistralai/Mistral-7B-Instruct-v0.2"
-
-        # Load model with quantization configuration
-        bnb_config = BitsAndBytesConfig(
-            load_in_4bit=True,
-            bnb_4bit_quant_type="nf4",
-            bnb_4bit_compute_dtype=torch.bfloat16,
-            bnb_4bit_use_double_quant=False,
-        )
-
-        model = AutoModelForCausalLM.from_pretrained(
-            model_name,
-            quantization_config=bnb_config,
-            torch_dtype=torch.bfloat16,
-            device_map="auto",
-            trust_remote_code=True,
-        )
-
-        model.config.use_cache = False
-        model = prepare_model_for_kbit_training(model)
-
-        peft_config = LoraConfig(
-            lora_alpha=16,
-            lora_dropout=0.1,
-            r=64,
-            bias="none",
-            task_type="CAUSAL_LM",
-            target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj"],
-        )
-
-        model = get_peft_model(model, peft_config)
-
-        tokenizer = AutoTokenizer.from_pretrained(base_model, trust_remote_code=True)
-
-    elif selected_model_name == "Mistral":
-        # Setup for Mistral 7B
-        model = AutoModelForCausalLM.from_pretrained(
-            model_links[selected_model_name]
-        )
-        tokenizer = AutoTokenizer.from_pretrained(model_links[selected_model_name])
+# Load the appropriate model
+def load_model():
+    model_name = model_links["HAH-2024-v0.1"]
+    base_model = "mistralai/Mistral-7B-Instruct-v0.2"
+
+    # Load model with quantization configuration
+    bnb_config = BitsAndBytesConfig(
+        load_in_4bit=True,
+        bnb_4bit_quant_type="nf4",
+        bnb_4bit_compute_dtype=torch.bfloat16,
+        bnb_4bit_use_double_quant=False,
+    )
+
+    model = AutoModelForCausalLM.from_pretrained(
+        model_name,
+        quantization_config=bnb_config,
+        torch_dtype=torch.bfloat16,
+        device_map="auto",
+        trust_remote_code=True,
+    )
+
+    model.config.use_cache = False
+    model = prepare_model_for_kbit_training(model)
+
+    peft_config = LoraConfig(
+        lora_alpha=16,
+        lora_dropout=0.1,
+        r=64,
+        bias="none",
+        task_type="CAUSAL_LM",
+        target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj"],
+    )
+
+    model = get_peft_model(model, peft_config)
+
+    tokenizer = AutoTokenizer.from_pretrained(base_model, trust_remote_code=True)
 
     return model, tokenizer
 
+model, tokenizer = load_model()
+
 # Initialize chat history
 if "messages" not in st.session_state:
     st.session_state.messages = []
 
-# Load the selected model
-model, tokenizer = load_model(selected_model)
-
-st.subheader(f"AI - {selected_model}")
-
 # Display previous chat messages
 for message in st.session_state.messages:
     with st.chat_message(message["role"]):
         st.markdown(message["content"])
 
 # User input for conversation
-if prompt := st.chat_input("Ask a question"):
-    # Display user input
+if prompt := st.chat_input("Ask me anything about diabetes"):
     with st.chat_message("user"):
         st.markdown(prompt)
-
-    # Store the user message
+
     st.session_state.messages.append({"role": "user", "content": prompt})
-
-    # Generate the assistant's response
+
     with st.chat_message("assistant"):
-        pipe = pipeline(
+        result = pipeline(
             task="text-generation",
             model=model,
             tokenizer=tokenizer,
             max_length=1024,
             temperature=temp_values
-        )
+        )(prompt)
 
-        result = pipe(f"<s>[INST] {prompt}</s>", do_sample=True)
-        response = result[0]["generated_text"]
-
+        response = result[0]['generated_text']
         st.markdown(response)
-
-        # Store the assistant's response
+
         st.session_state.messages.append({"role": "assistant", "content": response})
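For context, the new load_model() pulls "drmasad/HAH-2024-v0.11" in 4-bit and then wraps it with a freshly initialised LoraConfig via get_peft_model. If that repository were instead published as a PEFT adapter on top of Mistral-7B-Instruct-v0.2 rather than a merged checkpoint, the equivalent load path would go through PeftModel.from_pretrained, which app.py already imports. The snippet below is a minimal sketch under that assumption, not what this commit ships:

# Hedged sketch: attach a fine-tuned LoRA adapter to the 4-bit base model,
# assuming drmasad/HAH-2024-v0.11 is a PEFT adapter repo (not verified here).
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from peft import PeftModel

base_model = "mistralai/Mistral-7B-Instruct-v0.2"
adapter_repo = "drmasad/HAH-2024-v0.11"

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)

model = AutoModelForCausalLM.from_pretrained(
    base_model,
    quantization_config=bnb_config,
    device_map="auto",
)
model = PeftModel.from_pretrained(model, adapter_repo)  # load the trained adapter weights
tokenizer = AutoTokenizer.from_pretrained(base_model)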
 
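The assistant turn relies on the text-generation pipeline returning a list of dicts keyed by "generated_text". A minimal sketch of that call outside Streamlit, reusing model and tokenizer from the load step above (the prompt string is illustrative only):

from transformers import pipeline

# Build the same kind of generator the chat handler uses and run one prompt.
pipe = pipeline(
    task="text-generation",
    model=model,
    tokenizer=tokenizer,
    max_length=1024,
    temperature=0.7,
)

result = pipe("What should I know about HbA1c targets?")  # illustrative prompt
response = result[0]["generated_text"]  # list of dicts, one per returned sequence
print(response)

By default "generated_text" contains the prompt followed by the completion, which is what the chat handler stores in st.session_state.messages.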