Update app.py
app.py CHANGED
@@ -161,27 +161,51 @@ def load_model(model_choice: str, progress=gr.Progress()):
     print(f"Loading model {model_id}...")

     # Load model with appropriate settings for Colab
-    model = AutoModelForCausalLM.from_pretrained(
-        model_id,
-        device_map="auto",
-        trust_remote_code=True,
-        load_in_4bit=True,  # Use 4-bit quantization for memory efficiency
-        torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
-        low_cpu_mem_usage=True
-    )
-
+    # model = AutoModelForCausalLM.from_pretrained(
+    #     model_id,
+    #     device_map="auto",
+    #     trust_remote_code=True,
+    #     load_in_4bit=True,  # Use 4-bit quantization for memory efficiency
+    #     torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
+    #     low_cpu_mem_usage=True
+    # )
+    if torch.cuda.is_available():
+        # On GPU: use 4-bit quantization
+        model = AutoModelForCausalLM.from_pretrained(
+            model_id,
+            device_map="auto",
+            trust_remote_code=True,
+            load_in_4bit=True,
+            torch_dtype=torch.float16,
+            low_cpu_mem_usage=True
+        )
+    else:
+        # On CPU: do NOT use 4-bit quantization
+        model = AutoModelForCausalLM.from_pretrained(
+            model_id,
+            device_map="cpu",
+            trust_remote_code=True,
+            torch_dtype=torch.float32,
+            low_cpu_mem_usage=True
+        )
     progress(0.8, desc="Creating pipeline...")
     print("Creating text generation pipeline...")

     # Create pipeline
+    # pipe = pipeline(
+    #     "text-generation",
+    #     model=model,
+    #     tokenizer=tokenizer,
+    #     device_map="auto",
+    #     torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
+    # )
     pipe = pipeline(
         "text-generation",
         model=model,
         tokenizer=tokenizer,
-        device_map="auto",
+        device_map="auto" if torch.cuda.is_available() else "cpu",
         torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
     )
-
     # Store globally
     current_model = model
     current_tokenizer = tokenizer
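A note on the GPU branch above: newer transformers releases deprecate passing load_in_4bit directly to from_pretrained in favor of a BitsAndBytesConfig object. A minimal equivalent sketch, assuming a recent transformers with bitsandbytes installed and model_id as in load_model():

import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

# 4-bit weight quantization; matmuls still run in fp16 on the GPU.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_quant_type="nf4",  # NormalFloat4, as popularized by QLoRA
)

model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="auto",
    trust_remote_code=True,
    quantization_config=bnb_config,
    low_cpu_mem_usage=True,
)

The CPU branch avoids 4-bit loading entirely because bitsandbytes quantization requires a CUDA GPU, hence the plain float32 load there.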
@@ -733,7 +757,7 @@ def demo():
     vector_db = gr.State()
     qa_chain = gr.State()

-    gr.HTML("<center><h1>π Enhanced RAG CSV Chatbot
+    gr.HTML("<center><h1>π Enhanced RAG CSV Chatbot</h1></center>")
     gr.HTML("<center><p>Upload CSV files and chat with your data using powerful local language models</p></center>")

     with gr.Row():
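The diff stops short of showing how pipe reaches the qa_chain state. One common wiring, shown here purely as an assumption (this commit does not include it), wraps the transformers pipeline in LangChain's HuggingFacePipeline and builds a RetrievalQA chain over the CSV vector store:

from langchain_community.llms import HuggingFacePipeline
from langchain.chains import RetrievalQA

# Wrap the transformers pipeline created in load_model() as a LangChain LLM.
llm = HuggingFacePipeline(pipeline=pipe)

# vector_db stands for the vector-store value held in the gr.State() above;
# its concrete backend (FAISS, Chroma, ...) is not shown in this diff.
qa_chain = RetrievalQA.from_chain_type(
    llm=llm,
    retriever=vector_db.as_retriever(),
    return_source_documents=True,
)

result = qa_chain.invoke({"query": "What columns does the CSV contain?"})
print(result["result"])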