sandeep-huggingface committed
Commit 0bdd08d · verified · 1 Parent(s): 8d8bee7

Update app.py

Files changed (1): app.py (+36 -12)
app.py CHANGED

@@ -161,27 +161,51 @@ def load_model(model_choice: str, progress=gr.Progress()):
     print(f"Loading model {model_id}...")

     # Load model with appropriate settings for Colab
-    model = AutoModelForCausalLM.from_pretrained(
-        model_id,
-        device_map="auto",
-        trust_remote_code=True,
-        load_in_4bit=True,  # Use 4-bit quantization for memory efficiency
-        torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
-        low_cpu_mem_usage=True
-    )
-
+    # model = AutoModelForCausalLM.from_pretrained(
+    #     model_id,
+    #     device_map="auto",
+    #     trust_remote_code=True,
+    #     load_in_4bit=True,  # Use 4-bit quantization for memory efficiency
+    #     torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
+    #     low_cpu_mem_usage=True
+    # )
+    if torch.cuda.is_available():
+        # On GPU: use 4-bit quantization
+        model = AutoModelForCausalLM.from_pretrained(
+            model_id,
+            device_map="auto",
+            trust_remote_code=True,
+            load_in_4bit=True,
+            torch_dtype=torch.float16,
+            low_cpu_mem_usage=True
+        )
+    else:
+        # On CPU: do NOT use 4-bit quantization
+        model = AutoModelForCausalLM.from_pretrained(
+            model_id,
+            device_map="cpu",
+            trust_remote_code=True,
+            torch_dtype=torch.float32,
+            low_cpu_mem_usage=True
+        )
     progress(0.8, desc="Creating pipeline...")
     print("Creating text generation pipeline...")

     # Create pipeline
+    # pipe = pipeline(
+    #     "text-generation",
+    #     model=model,
+    #     tokenizer=tokenizer,
+    #     device_map="auto",
+    #     torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
+    # )
     pipe = pipeline(
         "text-generation",
         model=model,
         tokenizer=tokenizer,
-        device_map="auto",
+        device_map="auto" if torch.cuda.is_available() else "cpu",
         torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
     )
-
     # Store globally
     current_model = model
     current_tokenizer = tokenizer

@@ -733,7 +757,7 @@ def demo():
     vector_db = gr.State()
     qa_chain = gr.State()

-    gr.HTML("<center><h1>📊 Enhanced RAG CSV Chatbot with Local Transformers</h1></center>")
+    gr.HTML("<center><h1>📊 Enhanced RAG CSV Chatbot</h1></center>")
     gr.HTML("<center><p>Upload CSV files and chat with your data using powerful local language models</p></center>")

     with gr.Row():
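
For reference, below is a minimal sketch of the same GPU/CPU branching expressed with transformers' BitsAndBytesConfig, which newer transformers releases prefer over the bare load_in_4bit argument. The model_id here is a hypothetical placeholder (not the app's model list), and bitsandbytes is assumed to be installed on GPU hosts; this is a sketch, not the committed implementation.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, pipeline

model_id = "mistralai/Mistral-7B-Instruct-v0.2"  # hypothetical placeholder

tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)

if torch.cuda.is_available():
    # On GPU: 4-bit quantization via BitsAndBytesConfig for memory efficiency
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_compute_dtype=torch.float16,
    )
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        device_map="auto",
        trust_remote_code=True,
        quantization_config=bnb_config,
        low_cpu_mem_usage=True,
    )
else:
    # On CPU: no quantization, full-precision weights
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        device_map="cpu",
        trust_remote_code=True,
        torch_dtype=torch.float32,
        low_cpu_mem_usage=True,
    )

# The loaded model already carries its device placement, so the pipeline
# does not need device_map or torch_dtype again.
pipe = pipeline("text-generation", model=model, tokenizer=tokenizer)

Since from_pretrained already places the weights, the pipeline call can usually omit device_map and torch_dtype when it receives an instantiated model.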