Spaces:
Configuration error
Configuration error
Update app.py
Browse files
app.py
CHANGED
@@ -102,8 +102,9 @@ def interact(user_input, history, interaction_count, model_name):
|
|
102 |
# Determine the device to use (either CUDA if available, or CPU)
|
103 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
104 |
|
105 |
-
# [text of the removed comment was lost in extraction]
|
106 |
-
|
|
|
107 |
|
108 |
if interaction_count >= MAX_INTERACTIONS:
|
109 |
user_input += ". Thank you for your questions. Our session is now over. Goodbye!"
|
@@ -117,7 +118,7 @@ def interact(user_input, history, interaction_count, model_name):
|
|
117 |
|
118 |
prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
|
119 |
|
120 |
-
# Move input tensor to the [rest of the removed comment was lost in extraction]
|
121 |
input_ids = tokenizer(prompt, return_tensors='pt').input_ids.to(device)
|
122 |
chat_history_ids = model.generate(input_ids, max_new_tokens=100, pad_token_id=tokenizer.eos_token_id, temperature=0.1)
|
123 |
response = tokenizer.decode(chat_history_ids[:, input_ids.shape[-1]:][0], skip_special_tokens=True)
|
@@ -137,6 +138,7 @@ def interact(user_input, history, interaction_count, model_name):
|
|
137 |
print(f"Error during interaction: {e}")
|
138 |
raise gr.Error(f"An error occurred during interaction: {str(e)}")
|
139 |
|
|
|
140 |
# Function to send selected story and initial message
|
141 |
def send_selected_story(title, model_name, system_prompt):
|
142 |
global chat_history
|
|
|
102 |
# Determine the device to use (either CUDA if available, or CPU)
|
103 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
104 |
|
105 |
+
# Only move the model to the device if it's not a quantized model
|
106 |
+
if model_name not in quantized_models:
|
107 |
+
model = model.to(device)
|
108 |
|
109 |
if interaction_count >= MAX_INTERACTIONS:
|
110 |
user_input += ". Thank you for your questions. Our session is now over. Goodbye!"
|
|
|
118 |
|
119 |
prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
|
120 |
|
121 |
+
# Move input tensor to the correct device
|
122 |
input_ids = tokenizer(prompt, return_tensors='pt').input_ids.to(device)
|
123 |
chat_history_ids = model.generate(input_ids, max_new_tokens=100, pad_token_id=tokenizer.eos_token_id, temperature=0.1)
|
124 |
response = tokenizer.decode(chat_history_ids[:, input_ids.shape[-1]:][0], skip_special_tokens=True)
|
|
|
138 |
print(f"Error during interaction: {e}")
|
139 |
raise gr.Error(f"An error occurred during interaction: {str(e)}")
|
140 |
|
141 |
+
|
142 |
# Function to send selected story and initial message
|
143 |
def send_selected_story(title, model_name, system_prompt):
|
144 |
global chat_history
|