rodrisouza committed on
Commit
1f73e37
·
verified ·
1 Parent(s): 5be5f75

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +5 -3
app.py CHANGED
@@ -102,8 +102,9 @@ def interact(user_input, history, interaction_count, model_name):
102
  # Determine the device to use (either CUDA if available, or CPU)
103
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
104
 
105
- # Ensure the model is on the correct device
106
- model.to(device)
 
107
 
108
  if interaction_count >= MAX_INTERACTIONS:
109
  user_input += ". Thank you for your questions. Our session is now over. Goodbye!"
@@ -117,7 +118,7 @@ def interact(user_input, history, interaction_count, model_name):
117
 
118
  prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
119
 
120
- # Move input tensor to the same device as the model
121
  input_ids = tokenizer(prompt, return_tensors='pt').input_ids.to(device)
122
  chat_history_ids = model.generate(input_ids, max_new_tokens=100, pad_token_id=tokenizer.eos_token_id, temperature=0.1)
123
  response = tokenizer.decode(chat_history_ids[:, input_ids.shape[-1]:][0], skip_special_tokens=True)
@@ -137,6 +138,7 @@ def interact(user_input, history, interaction_count, model_name):
137
  print(f"Error during interaction: {e}")
138
  raise gr.Error(f"An error occurred during interaction: {str(e)}")
139
 
 
140
  # Function to send selected story and initial message
141
  def send_selected_story(title, model_name, system_prompt):
142
  global chat_history
 
102
  # Determine the device to use (either CUDA if available, or CPU)
103
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
104
 
105
+ # Only move the model to the device if it's not a quantized model
106
+ if model_name not in quantized_models:
107
+ model = model.to(device)
108
 
109
  if interaction_count >= MAX_INTERACTIONS:
110
  user_input += ". Thank you for your questions. Our session is now over. Goodbye!"
 
118
 
119
  prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
120
 
121
+ # Move input tensor to the correct device
122
  input_ids = tokenizer(prompt, return_tensors='pt').input_ids.to(device)
123
  chat_history_ids = model.generate(input_ids, max_new_tokens=100, pad_token_id=tokenizer.eos_token_id, temperature=0.1)
124
  response = tokenizer.decode(chat_history_ids[:, input_ids.shape[-1]:][0], skip_special_tokens=True)
 
138
  print(f"Error during interaction: {e}")
139
  raise gr.Error(f"An error occurred during interaction: {str(e)}")
140
 
141
+
142
  # Function to send selected story and initial message
143
  def send_selected_story(title, model_name, system_prompt):
144
  global chat_history