ayush0504 committed on
Commit 0c92128 · verified · 1 Parent(s): 0523303

Upload 2 files

Files changed (2)
  1. app.py +183 -0
  2. requirements.txt +8 -0
app.py ADDED
@@ -0,0 +1,183 @@
+ import streamlit as st
+ import torch
+ from peft import AutoPeftModelForCausalLM
+ from transformers import AutoTokenizer, TextStreamer
+ import io
+ import sys
+ import threading
+ import time
+ import queue  # Import the queue module
+
+ # --- Configuration ---
+ DEFAULT_MODEL_PATH = "lora_model"  # Or your default path
+ DEFAULT_LOAD_IN_4BIT = True
+
+ # --- Page Configuration ---
+ st.set_page_config(page_title="Fine-tuned LLM Chat Interface", layout="wide")
+ st.title("Fine-tuned LLM Chat Interface")
+
+ # --- Model Loading (Cached) ---
+ @st.cache_resource(show_spinner="Loading model and tokenizer...")
+ def load_model_and_tokenizer(model_path, load_in_4bit):
+     """Loads the PEFT model and tokenizer."""
+     try:
+         torch_dtype = torch.bfloat16 if load_in_4bit else torch.float16  # bfloat16 often better for 4-bit
+
+         model = AutoPeftModelForCausalLM.from_pretrained(
+             model_path,
+             torch_dtype=torch_dtype,
+             load_in_4bit=load_in_4bit,
+             device_map="auto",
+         )
+         tokenizer = AutoTokenizer.from_pretrained(model_path)
+         model.eval()
+         print("Model and tokenizer loaded successfully.")
+         return model, tokenizer
+     except Exception as e:
+         st.error(f"Error loading model from path '{model_path}': {e}", icon="🚨")
+         print(f"Error loading model: {e}")
+         return None, None
+
+ # --- Custom Streamer Class (Modified for Queue) ---
+ class QueueStreamer(TextStreamer):
+     def __init__(self, tokenizer, skip_prompt, q):
+         super().__init__(tokenizer, skip_prompt=skip_prompt)
+         self.queue = q
+         self.stop_signal = None  # Sentinel placed on the queue when generation ends
+
+     def on_finalized_text(self, text: str, stream_end: bool = False):
+         """Puts each finalized text chunk onto the queue; on stream end, adds the sentinel."""
+         self.queue.put(text)
+         if stream_end:
+             self.queue.put(self.stop_signal)
+
+     def end(self):
+         """Flushes any remaining cached text via the parent, which then signals stream end."""
+         super().end()  # Parent decodes the token cache and calls on_finalized_text(..., stream_end=True)
+
+
+ # --- Sidebar for Settings ---
+ with st.sidebar:
+     st.header("Model Configuration")
+     st.info(f"Model loaded on startup: `{DEFAULT_MODEL_PATH}`, 4-bit: `{DEFAULT_LOAD_IN_4BIT}`.")
+
+     st.header("Generation Settings")
+     temperature = st.slider("Temperature", min_value=0.0, max_value=2.0, value=0.7, step=0.05)
+     min_p = st.slider("Min P", min_value=0.01, max_value=1.0, value=0.1, step=0.01)
+     max_tokens = st.slider("Max New Tokens", min_value=50, max_value=2048, value=512, step=50)
+
+     if st.button("Clear Chat History"):
+         st.session_state.messages = []
+         st.rerun()  # Rerun to clear the display immediately
+
+
+ # --- Load Model (runs only once on first run or if cache is cleared) ---
+ model, tokenizer = load_model_and_tokenizer(DEFAULT_MODEL_PATH, DEFAULT_LOAD_IN_4BIT)
+
+ # --- Initialize Session State ---
+ if "messages" not in st.session_state:
+     st.session_state.messages = []
+
+ # --- Main Chat Interface ---
+ if model is None or tokenizer is None:
+     st.error("Model loading failed. Please check the path and logs. Cannot proceed.")
+     st.stop()
+
+ # Display conversation history
+ for message in st.session_state.messages:
+     with st.chat_message(message["role"]):
+         st.markdown(message["content"])
+
+ # Handle user input
+ user_input = st.chat_input("Ask the fine-tuned model...")
+
+ if user_input:
+     # Add user message to history and display it
+     st.session_state.messages.append({"role": "user", "content": user_input})
+     with st.chat_message("user"):
+         st.markdown(user_input)
+
+     # Prepare for model response
+     with st.chat_message("assistant"):
+         response_placeholder = st.empty()
+         text_queue = queue.Queue()  # Create a queue for this specific response
+         # Initialize the modified streamer
+         text_streamer = QueueStreamer(tokenizer, skip_prompt=True, q=text_queue)
+
+         # Prepare input for the model
+         messages_for_model = st.session_state.messages
+
+         try:
+             if tokenizer.chat_template:
+                 inputs = tokenizer.apply_chat_template(
+                     messages_for_model,
+                     tokenize=True,
+                     add_generation_prompt=True,
+                     return_tensors="pt"
+                 ).to(model.device)
+             else:
+                 prompt_text = "\n".join([f"{msg['role']}: {msg['content']}" for msg in messages_for_model]) + "\nassistant:"
+                 inputs = tokenizer(prompt_text, return_tensors="pt").input_ids.to(model.device)
+
+             # Generation arguments
+             generation_kwargs = dict(
+                 input_ids=inputs,
+                 streamer=text_streamer,  # Use the QueueStreamer
+                 max_new_tokens=max_tokens,
+                 use_cache=True,
+                 temperature=temperature if temperature > 0 else None,
+                 top_p=None,
+                 min_p=min_p,
+                 do_sample=True if temperature > 0 else False,
+                 eos_token_id=tokenizer.eos_token_id,
+                 pad_token_id=tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id
+             )
+
+             # Define the target function for the thread
+             def generation_thread_func():
+                 try:
+                     # Run generation in the background thread
+                     model.generate(**generation_kwargs)
+                 except Exception as e:
+                     # If an error occurs in the thread, log it and signal stop
+                     print(f"Error in generation thread: {e}")
+                     text_streamer.end()  # Ensure the queue loop terminates
+
+             # Start the generation thread
+             thread = threading.Thread(target=generation_thread_func)
+             thread.start()
+
+             # --- Main thread: Read from queue and update UI ---
+             generated_text = ""
+             while True:
+                 try:
+                     # Get the next text chunk from the queue.
+                     # Use a timeout to avoid blocking indefinitely if the thread hangs.
+                     chunk = text_queue.get(block=True, timeout=1)
+                     if chunk is text_streamer.stop_signal:  # Check for the end signal (None)
+                         break
+                     generated_text += chunk
+                     response_placeholder.markdown(generated_text + "▌")  # Update placeholder
+                 except queue.Empty:
+                     # If the queue is empty, check whether the generation thread is still running
+                     if not thread.is_alive():
+                         # Thread finished but may not have put the stop signal (error?)
+                         break  # Exit loop
+                     # Otherwise, keep waiting for the next chunk
+                     continue
+
+             # Final update without the cursor
+             response_placeholder.markdown(generated_text)
+
+             # Add the complete assistant response to history *after* generation
+             st.session_state.messages.append({"role": "assistant", "content": generated_text})
+
+             # Wait briefly for the thread to finish if it hasn't already
+             thread.join(timeout=2.0)
+
+
+         except Exception as e:
+             st.error(f"Error during generation setup or queue handling: {e}", icon="🔥")
+             print(f"Error setting up generation or handling queue: {e}")
+             st.session_state.messages.append({"role": "assistant", "content": f"*Error generating response: {e}*"})
+             response_placeholder.error(f"Error generating response: {e}")
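
For readers skimming the diff, the core idea in app.py is a producer/consumer hand-off: model.generate() runs in a background thread, the QueueStreamer pushes finalized text chunks onto a queue.Queue and finishes with a None sentinel, and the Streamlit main thread drains the queue while redrawing the placeholder. Below is a minimal, model-free sketch of the same pattern; generate_chunks and its hard-coded strings are illustrative stand-ins for model.generate() plus the streamer, not part of the app.

import queue
import threading
import time

def generate_chunks(q):
    # Stand-in for model.generate() + QueueStreamer: emit chunks, then the sentinel.
    for chunk in ["Hello", ", ", "world", "!"]:
        q.put(chunk)       # QueueStreamer.on_finalized_text does this per chunk
        time.sleep(0.1)
    q.put(None)            # the None sentinel marks the end of the stream

text_queue = queue.Queue()
worker = threading.Thread(target=generate_chunks, args=(text_queue,))
worker.start()

generated_text = ""
while True:
    try:
        chunk = text_queue.get(block=True, timeout=1)
    except queue.Empty:
        if not worker.is_alive():   # producer died without queueing the sentinel
            break
        continue
    if chunk is None:               # sentinel reached: generation is complete
        break
    generated_text += chunk
    print(generated_text)           # app.py updates response_placeholder here instead

worker.join()

The timeout on text_queue.get is what keeps the consumer loop from blocking forever if the producer thread dies without queueing the sentinel, which is also why app.py checks thread.is_alive() when the queue is empty.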
requirements.txt ADDED
@@ -0,0 +1,8 @@
+ torch==2.6.0+cu124
+ transformers==4.49.0
+ peft==0.14.0
+ streamlit==1.37.1
+ accelerate==1.1.1
+ bitsandbytes==0.45.3
+ sentencepiece==0.2.0
+ protobuf==5.28.3
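
Assuming a CUDA 12.4 environment for the pinned torch==2.6.0+cu124 wheel, the usual workflow would be `pip install -r requirements.txt` followed by `streamlit run app.py`; both are standard pip and Streamlit commands rather than anything specific to this repository.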