ayush0504 committed
Commit 53ec504 · verified · 1 Parent(s): 3321568

Update app.py

Files changed (1)
  1. app.py +218 -183
app.py CHANGED
@@ -1,183 +1,218 @@
- import streamlit as st
- import torch
- from peft import AutoPeftModelForCausalLM
- from transformers import AutoTokenizer, TextStreamer
- import io
- import sys
- import threading
- import time
- import queue # Import the queue module
-
- # --- Configuration ---
- DEFAULT_MODEL_PATH = "lora_model" # Or your default path
- DEFAULT_LOAD_IN_4BIT = True
-
- # --- Page Configuration ---
- st.set_page_config(page_title="Fine-tuned LLM Chat Interface", layout="wide")
- st.title("Fine-tuned LLM Chat Interface")
-
- # --- Model Loading (Cached) ---
- @st.cache_resource(show_spinner="Loading model and tokenizer...")
- def load_model_and_tokenizer(model_path, load_in_4bit):
-     """Loads the PEFT model and tokenizer."""
-     try:
-         torch_dtype = torch.bfloat16 if load_in_4bit else torch.float16 # bfloat16 often better for 4-bit
-
-         model = AutoPeftModelForCausalLM.from_pretrained(
-             model_path,
-             torch_dtype=torch_dtype,
-             load_in_4bit=load_in_4bit,
-             device_map="auto",
-         )
-         tokenizer = AutoTokenizer.from_pretrained(model_path)
-         model.eval()
-         print("Model and tokenizer loaded successfully.")
-         return model, tokenizer
-     except Exception as e:
-         st.error(f"Error loading model from path '{model_path}': {e}", icon="🚨")
-         print(f"Error loading model: {e}")
-         return None, None
-
- # --- Custom Streamer Class (Modified for Queue) ---
- class QueueStreamer(TextStreamer):
-     def __init__(self, tokenizer, skip_prompt, q):
-         super().__init__(tokenizer, skip_prompt=skip_prompt)
-         self.queue = q
-         self.stop_signal = None # Can be used if needed, but queue is primary
-
-     def on_finalized_text(self, text: str, stream_end: bool = False):
-         """Puts the text onto the queue."""
-         self.queue.put(text)
-         if stream_end:
-             self.end()
-
-     def end(self):
-         """Signals the end of generation by putting None in the queue."""
-         self.queue.put(self.stop_signal) # Put None (or a specific sentinel)
-
-
- # --- Sidebar for Settings ---
- with st.sidebar:
-     st.header("Model Configuration")
-     st.info(f"Model loaded on startup: `{DEFAULT_MODEL_PATH}`, 4-bit: `{DEFAULT_LOAD_IN_4BIT}`.")
-
-     st.header("Generation Settings")
-     temperature = st.slider("Temperature", min_value=0.0, max_value=2.0, value=0.7, step=0.05)
-     min_p = st.slider("Min P", min_value=0.01, max_value=1.0, value=0.1, step=0.01)
-     max_tokens = st.slider("Max New Tokens", min_value=50, max_value=2048, value=512, step=50)
-
-     if st.button("Clear Chat History"):
-         st.session_state.messages = []
-         st.rerun() # Rerun to clear display immediately
-
-
- # --- Load Model (runs only once on first run or if cache is cleared) ---
- model, tokenizer = load_model_and_tokenizer(DEFAULT_MODEL_PATH, DEFAULT_LOAD_IN_4BIT)
-
- # --- Initialize Session State ---
- if "messages" not in st.session_state:
-     st.session_state.messages = []
-
- # --- Main Chat Interface ---
- if model is None or tokenizer is None:
-     st.error("Model loading failed. Please check the path and logs. Cannot proceed.")
-     st.stop()
-
- # Display conversation history
- for message in st.session_state.messages:
-     with st.chat_message(message["role"]):
-         st.markdown(message["content"])
-
- # Handle user input
- user_input = st.chat_input("Ask the fine-tuned model...")
-
- if user_input:
-     # Add user message to history and display it
-     st.session_state.messages.append({"role": "user", "content": user_input})
-     with st.chat_message("user"):
-         st.markdown(user_input)
-
-     # Prepare for model response
-     with st.chat_message("assistant"):
-         response_placeholder = st.empty()
-         text_queue = queue.Queue() # Create a queue for this specific response
-         # Initialize the modified streamer
-         text_streamer = QueueStreamer(tokenizer, skip_prompt=True, q=text_queue)
-
-         # Prepare input for the model
-         messages_for_model = st.session_state.messages
-
-         try:
-             if tokenizer.chat_template:
-                 inputs = tokenizer.apply_chat_template(
-                     messages_for_model,
-                     tokenize=True,
-                     add_generation_prompt=True,
-                     return_tensors="pt"
-                 ).to(model.device)
-             else:
-                 prompt_text = "\n".join([f"{msg['role']}: {msg['content']}" for msg in messages_for_model]) + "\nassistant:"
-                 inputs = tokenizer(prompt_text, return_tensors="pt").input_ids.to(model.device)
-
-             # Generation arguments
-             generation_kwargs = dict(
-                 input_ids=inputs,
-                 streamer=text_streamer, # Use the QueueStreamer
-                 max_new_tokens=max_tokens,
-                 use_cache=True,
-                 temperature=temperature if temperature > 0 else None,
-                 top_p=None,
-                 min_p=min_p,
-                 do_sample=True if temperature > 0 else False,
-                 eos_token_id=tokenizer.eos_token_id,
-                 pad_token_id=tokenizer.pad_token_id if tokenizer.pad_token_id else tokenizer.eos_token_id
-             )
-
-             # Define the target function for the thread
-             def generation_thread_func():
-                 try:
-                     # Run generation in the background thread
-                     model.generate(**generation_kwargs)
-                 except Exception as e:
-                     # If error occurs in thread, signal stop and maybe log
-                     print(f"Error in generation thread: {e}")
-                     text_streamer.end() # Ensure the queue loop terminates
-
-             # Start the generation thread
-             thread = threading.Thread(target=generation_thread_func)
-             thread.start()
-
-             # --- Main thread: Read from queue and update UI ---
-             generated_text = ""
-             while True:
-                 try:
-                     # Get the next text chunk from the queue
-                     # Use timeout to prevent blocking indefinitely if thread hangs
-                     chunk = text_queue.get(block=True, timeout=1)
-                     if chunk is text_streamer.stop_signal: # Check for end signal (None)
-                         break
-                     generated_text += chunk
-                     response_placeholder.markdown(generated_text + "▌") # Update placeholder
-                 except queue.Empty:
-                     # If queue is empty, check if the generation thread is still running
-                     if not thread.is_alive():
-                         # Thread finished, but maybe didn't put the stop signal (error?)
-                         break # Exit loop
-                     # Otherwise, continue waiting for next chunk
-                     continue
-
-             # Final update without the cursor
-             response_placeholder.markdown(generated_text)
-
-             # Add the complete assistant response to history *after* generation
-             st.session_state.messages.append({"role": "assistant", "content": generated_text})
-
-             # Wait briefly for the thread to finish if it hasn't already
-             thread.join(timeout=2.0)
-
-
-         except Exception as e:
-             st.error(f"Error during generation setup or queue handling: {e}", icon="🔥")
-             print(f"Error setting up generation or handling queue: {e}")
-             st.session_state.messages.append({"role": "assistant", "content": f"*Error generating response: {e}*"})
-             response_placeholder.error(f"Error generating response: {e}")
+ import streamlit as st
+ import torch
+ from peft import AutoPeftModelForCausalLM
+ from transformers import AutoTokenizer, TextStreamer
+ # bitsandbytes is no longer needed
+ import io
+ import sys
+ import threading
+ import time
+ import queue # Import the queue module
+
+ # --- Configuration ---
+ DEFAULT_MODEL_PATH = "lora_model" # Or your default path
+ # DEFAULT_LOAD_IN_4BIT is removed as we are not using quantization
+
+ # --- Page Configuration ---
+ st.set_page_config(page_title="Fine-tuned LLM Chat Interface (CPU)", layout="wide")
+ st.title("Fine-tuned LLM Chat Interface (CPU Mode)")
+ st.warning("Running in CPU mode. Expect slower generation times and higher RAM usage.", icon="⚠️")
+
+ # --- Model Loading (Cached for CPU) ---
+ @st.cache_resource(show_spinner="Loading model and tokenizer onto CPU...")
+ def load_model_and_tokenizer_cpu(model_path):
+     """Loads the PEFT model and tokenizer onto the CPU."""
+     try:
+         # Use standard float32 for CPU compatibility and stability
+         torch_dtype = torch.float32
+
+         model = AutoPeftModelForCausalLM.from_pretrained(
+             model_path,
+             torch_dtype=torch_dtype,
+             # load_in_4bit=False, # Explicitly removed/not needed
+             device_map="cpu", # Force loading onto CPU
+         )
+         tokenizer = AutoTokenizer.from_pretrained(model_path)
+         model.eval() # Set model to evaluation mode
+         print("Model and tokenizer loaded successfully onto CPU.")
+         return model, tokenizer
+     except Exception as e:
+         st.error(f"Error loading model from path '{model_path}' onto CPU: {e}", icon="🚨")
+         print(f"Error loading model onto CPU: {e}")
+         return None, None
+
+ # --- Custom Streamer Class (Modified for Queue) ---
+ class QueueStreamer(TextStreamer):
+     def __init__(self, tokenizer, skip_prompt, q):
+         super().__init__(tokenizer, skip_prompt=skip_prompt)
+         self.queue = q
+         self.stop_signal = None # Can be used if needed, but queue is primary
+
+     def on_finalized_text(self, text: str, stream_end: bool = False):
+         """Puts the text onto the queue."""
+         self.queue.put(text)
+         if stream_end:
+             self.end()
+
+     def end(self):
+         """Signals the end of generation by putting None in the queue."""
+         self.queue.put(self.stop_signal) # Put None (or a specific sentinel)
+
+
+ # --- Sidebar for Settings ---
+ with st.sidebar:
+     st.header("Model Configuration")
+     st.info(f"Model loaded on startup: `{DEFAULT_MODEL_PATH}` (CPU Mode).")
+
+     st.header("Generation Settings")
+     temperature = st.slider("Temperature", min_value=0.0, max_value=2.0, value=0.7, step=0.05)
+     # min_p might not be as commonly used or effective without top_p/top_k,
+     # but keeping it allows experimentation. Consider using top_k or top_p instead.
+     # Example: top_p = st.slider("Top P", min_value=0.01, max_value=1.0, value=0.9, step=0.01)
+     min_p = st.slider("Min P", min_value=0.01, max_value=1.0, value=0.1, step=0.01) # Keep for now
+     max_tokens = st.slider("Max New Tokens", min_value=50, max_value=2048, value=256, step=50) # Reduced default for CPU
+
+     if st.button("Clear Chat History"):
+         st.session_state.messages = []
+         st.rerun() # Rerun to clear display immediately
+
+
+ # --- Load Model (runs only once on first run or if cache is cleared) ---
+ model, tokenizer = load_model_and_tokenizer_cpu(DEFAULT_MODEL_PATH)
+
+ # --- Initialize Session State ---
+ if "messages" not in st.session_state:
+     st.session_state.messages = []
+
+ # --- Main Chat Interface ---
+ if model is None or tokenizer is None:
+     st.error("CPU Model loading failed. Please check the path, available RAM, and logs. Cannot proceed.")
+     st.stop()
+
+ # Display conversation history
+ for message in st.session_state.messages:
+     with st.chat_message(message["role"]):
+         st.markdown(message["content"])
+
+ # Handle user input
+ user_input = st.chat_input("Ask the fine-tuned model (CPU)...")
+
+ if user_input:
+     # Add user message to history and display it
+     st.session_state.messages.append({"role": "user", "content": user_input})
+     with st.chat_message("user"):
+         st.markdown(user_input)
+
+     # Prepare for model response
+     with st.chat_message("assistant"):
+         response_placeholder = st.empty()
+         response_placeholder.markdown("Generating response on CPU... please wait... ▌") # Initial message
+         text_queue = queue.Queue() # Create a queue for this specific response
+         # Initialize the modified streamer
+         text_streamer = QueueStreamer(tokenizer, skip_prompt=True, q=text_queue)
+
+         # Prepare input for the model
+         messages_for_model = st.session_state.messages
+
+         try:
+             # Ensure inputs are on the CPU (model.device should be 'cpu' now)
+             target_device = model.device
+             # print(f"Model device: {target_device}") # Debugging: should print 'cpu'
+
+             if tokenizer.chat_template:
+                 inputs = tokenizer.apply_chat_template(
+                     messages_for_model,
+                     tokenize=True,
+                     add_generation_prompt=True,
+                     return_tensors="pt"
+                 ).to(target_device) # Send input tensors to CPU
+             else:
+                 prompt_text = "\n".join([f"{msg['role']}: {msg['content']}" for msg in messages_for_model]) + "\nassistant:"
+                 inputs = tokenizer(prompt_text, return_tensors="pt").input_ids.to(target_device) # Send input tensors to CPU
+
+             # Generation arguments
+             generation_kwargs = dict(
+                 input_ids=inputs,
+                 streamer=text_streamer, # Use the QueueStreamer
+                 max_new_tokens=max_tokens,
+                 use_cache=True, # Caching can still help CPU generation speed
+                 temperature=temperature if temperature > 0 else None,
+                 top_p=None, # Consider adding top_p slider in UI
+                 # top_k=50, # Example: Or use top_k
+                 min_p=min_p,
+                 do_sample=True if temperature > 0 else False,
+                 eos_token_id=tokenizer.eos_token_id,
+                 pad_token_id=tokenizer.pad_token_id if tokenizer.pad_token_id else tokenizer.eos_token_id
+             )
+
+             # Define the target function for the thread
+             def generation_thread_func():
+                 try:
+                     # Run generation in the background thread (on CPU)
+                     # Wrap in torch.no_grad() to save memory during inference
+                     with torch.no_grad():
+                         model.generate(**generation_kwargs)
+                 except Exception as e:
+                     # If error occurs in thread, signal stop and maybe log
+                     print(f"Error in generation thread: {e}")
+                     # Attempt to put error message in queue? Or just rely on main thread error handling
+                     st.error(f"Error during generation: {e}") # Show error in UI too
+                 finally:
+                     # Ensure the queue loop terminates even if error occurred
+                     text_streamer.end()
+
+
+             # Start the generation thread
+             thread = threading.Thread(target=generation_thread_func)
+             thread.start()
+
+             # --- Main thread: Read from queue and update UI ---
+             generated_text = ""
+             while True:
+                 try:
+                     # Get the next text chunk from the queue
+                     # Use timeout to prevent blocking indefinitely if thread hangs
+                     chunk = text_queue.get(block=True, timeout=1) # Short timeout OK for slow CPU gen
+                     if chunk is text_streamer.stop_signal: # Check for end signal (None)
+                         break
+                     generated_text += chunk
+                     response_placeholder.markdown(generated_text + "▌") # Update placeholder
+                 except queue.Empty:
+                     # If queue is empty, check if the generation thread is still running
+                     if not thread.is_alive():
+                         # Thread finished, but maybe didn't put the stop signal (error?)
+                         break # Exit loop
+                     # Otherwise, continue waiting for next chunk
+                     continue
+                 except Exception as e:
+                     st.error(f"Error reading from generation queue: {e}")
+                     print(f"Error reading from queue: {e}")
+                     break # Exit loop on queue error
+
+             # Final update without the cursor
+             response_placeholder.markdown(generated_text)
+
+             # Add the complete assistant response to history *after* generation
+             if generated_text: # Only add if something was generated
+                 st.session_state.messages.append({"role": "assistant", "content": generated_text})
+             else:
+                 # Handle case where generation failed silently in thread or produced nothing
+                 if not any(m['role'] == 'assistant' and m['content'].startswith("*Error") for m in st.session_state.messages):
+                     st.warning("Assistant produced no output.", icon="⚠️")
+
+
+             # Wait briefly for the thread to finish if it hasn't already
+             thread.join(timeout=5.0) # Longer timeout might be needed if cleanup is slow
+
+
+         except Exception as e:
+             st.error(f"Error during generation setup or queue handling: {e}", icon="🔥")
+             print(f"Error setting up generation or handling queue: {e}")
+             # Add error to chat history for context
+             error_message = f"*Error generating response: {e}*"
+             if not generated_text: # Add if no text was generated at all
+                 st.session_state.messages.append({"role": "assistant", "content": error_message})
+                 response_placeholder.error(f"Error generating response: {e}")
+             else: # Append error notice if some text was generated before error
+                 st.session_state.messages.append({"role": "assistant", "content": generated_text + "\n\n" + error_message})
+                 response_placeholder.markdown(generated_text + f"\n\n*{error_message}*")
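
For reference, the streaming mechanism that both the old and the new version of app.py rely on is a producer/consumer hand-off: model.generate runs in a background thread, QueueStreamer pushes each finalized text chunk into a queue.Queue, and the Streamlit main thread drains that queue until it sees the None sentinel that QueueStreamer.end() puts there. The sketch below is not part of the commit; it exercises the same pattern in isolation, with a hypothetical fake_generate standing in for model.generate plus the streamer, and print standing in for response_placeholder.markdown.

import queue
import threading
import time

def fake_generate(q: queue.Queue) -> None:
    """Hypothetical producer standing in for model.generate + QueueStreamer."""
    for chunk in ["Hello", ", ", "world", "!"]:
        time.sleep(0.2)  # simulate slow (e.g. CPU-bound) token generation
        q.put(chunk)     # QueueStreamer.on_finalized_text does this per chunk
    q.put(None)          # QueueStreamer.end() pushes the None stop signal

text_queue: queue.Queue = queue.Queue()
thread = threading.Thread(target=fake_generate, args=(text_queue,))
thread.start()

generated_text = ""
while True:
    try:
        chunk = text_queue.get(block=True, timeout=1)
    except queue.Empty:
        if not thread.is_alive():  # producer died without sending the sentinel
            break
        continue
    if chunk is None:              # stop signal: generation finished
        break
    generated_text += chunk
    print(generated_text + "▌")    # app.py updates response_placeholder.markdown(...) here

thread.join(timeout=2.0)
print(generated_text)              # final render without the cursor

Using None as the sentinel is why the main loop in app.py compares chunks with `is text_streamer.stop_signal`; the queue.Empty branch is the safety net for the case where the producer thread exits before it can push that sentinel.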