rodrisouza committed on
Commit f9160fd · verified · 1 Parent(s): eabbb32

Update app.py

Files changed (1)
  1. app.py +31 -8
app.py CHANGED
@@ -4,7 +4,7 @@ from transformers import AutoModelForCausalLM, AutoTokenizer
 import pandas as pd
 from datetime import datetime, timedelta, timezone
 import torch
-from config import hugging_face_token, init_google_sheets_client, models, default_model_name, user_names, google_sheets_name, MAX_INTERACTIONS
+from config import hugging_face_token, init_google_sheets_client, models, quantized_models, default_model_name, user_names, google_sheets_name, MAX_INTERACTIONS
 import spaces
 
 # Hack for ZeroGPU
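
Note: the import change above pulls in quantized_models alongside models. Both presumably map display names to Hugging Face repo IDs in config.py; a minimal sketch of what those entries might look like (the names and repo IDs below are hypothetical, not the Space's actual config):

    # Hypothetical config.py entries -- the real names and repo IDs
    # live in the Space's config and may differ.
    models = {
        "Llama-3-8B-Instruct": "meta-llama/Meta-Llama-3-8B-Instruct",
    }
    quantized_models = {
        "Llama-3-70B-Instruct (4-bit)": "some-org/Llama-3-70B-Instruct-4bit",  # hypothetical repo
    }
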
@@ -48,14 +48,28 @@ def load_model(model_name):
             del model
             torch.cuda.empty_cache()
 
-        tokenizer = AutoTokenizer.from_pretrained(models[model_name], padding_side='left', token=hugging_face_token, trust_remote_code=True)
+        tokenizer = AutoTokenizer.from_pretrained(
+            models[model_name],
+            padding_side='left',
+            token=hugging_face_token,
+            trust_remote_code=True
+        )
 
         # Ensure the padding token is set
         if tokenizer.pad_token is None:
             tokenizer.pad_token = tokenizer.eos_token
             tokenizer.add_special_tokens({'pad_token': tokenizer.eos_token})
 
-        model = AutoModelForCausalLM.from_pretrained(models[model_name], token=hugging_face_token, trust_remote_code=True).to("cuda")
+        model = AutoModelForCausalLM.from_pretrained(
+            models[model_name],
+            token=hugging_face_token,
+            trust_remote_code=True
+        )
+
+        # Only move to CUDA if it's not a quantized model
+        if model_name not in quantized_models:
+            model = model.to("cuda")
+
         selected_model = model_name
     except Exception as e:
         print(f"Error loading model {model_name}: {e}")
@@ -70,12 +84,18 @@ chat_history = []
 
 # Function to handle interaction with model
 @spaces.GPU
-def interact(user_input, history, interaction_count):
+def interact(user_input, history, interaction_count, model_name):
     global tokenizer, model
     try:
         if tokenizer is None or model is None:
             raise ValueError("Tokenizer or model is not initialized.")
 
+        # Determine the device to use (either CUDA if available, or CPU)
+        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+        # Ensure the model is on the correct device
+        model.to(device)
+
         if interaction_count >= MAX_INTERACTIONS:
             user_input += ". Thank you for your questions. Our session is now over. Goodbye!"
 
@@ -88,8 +108,8 @@ def interact(user_input, history, interaction_count):
 
         prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
 
-        # Generate response using selected model
-        input_ids = tokenizer(prompt, return_tensors='pt').input_ids.to("cuda")
+        # Move input tensor to the same device as the model
+        input_ids = tokenizer(prompt, return_tensors='pt').input_ids.to(device)
         chat_history_ids = model.generate(input_ids, max_new_tokens=100, pad_token_id=tokenizer.eos_token_id, temperature=0.1)
         response = tokenizer.decode(chat_history_ids[:, input_ids.shape[-1]:][0], skip_special_tokens=True)
 
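Note: together, the two hunks above make interact() device-agnostic: pick CUDA when available, fall back to CPU, and keep the model and its inputs on the same device before generate(). A condensed sketch of the pattern (assumes model, tokenizer, and prompt are already defined):

    import torch

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)  # keep model and inputs co-located
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)
    output_ids = model.generate(input_ids, max_new_tokens=100)

One caveat: the unconditional model.to(device) inside interact() could itself raise for a bitsandbytes-quantized model, so the quantized_models guard used in load_model() may be needed here as well.
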
@@ -131,7 +151,7 @@ Here is the story:
         chat_history = []  # Reset chat history
         chat_history.append({"role": "system", "content": combined_message})
         question_prompt = "Please ask a simple question about the story to encourage interaction."
-        _, formatted_history, chat_history, interaction_count = interact(question_prompt, chat_history, interaction_count)
+        _, formatted_history, chat_history, interaction_count = interact(question_prompt, chat_history, interaction_count, model_name)
 
         return formatted_history, chat_history, gr.update(value=[]), story["story"]
     else:
@@ -182,6 +202,9 @@ def load_user_guide():
     with open('user_guide.txt', 'r') as file:
         return file.read()
 
+# Combine both model dictionaries
+all_models = {**models, **quantized_models}
+
 # Create the chat interface using Gradio Blocks
 with gr.Blocks() as demo:
     with gr.Tabs():
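
Note: {**models, **quantized_models} builds a single dict via unpacking; on duplicate keys the right-hand operand wins. For example:

    a = {"x": 1, "y": 2}
    b = {"y": 20, "z": 30}
    merged = {**a, **b}  # {'x': 1, 'y': 20, 'z': 30} -- b's 'y' wins

This merged all_models dict is what feeds the model dropdown in the next hunk.
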
@@ -190,7 +213,7 @@ with gr.Blocks() as demo:
 
             gr.Markdown("## Context")
             with gr.Group():
-                model_dropdown = gr.Dropdown(choices=list(models.keys()), label="Select Model", value=selected_model)
+                model_dropdown = gr.Dropdown(choices=list(all_models.keys()), label="Select Model", value=default_model_name)
                 user_dropdown = gr.Dropdown(choices=user_names, label="Select User Name")
                 initial_story = stories[0]["title"] if stories else None
                 story_dropdown = gr.Dropdown(choices=[story["title"] for story in stories], label="Select Story", value=initial_story)
 