hanzla javaid committed
Commit 7ed8a9a · 1 Parent(s): ff120ef
Files changed (1)
  1. app.py +40 -25
app.py CHANGED
@@ -1,51 +1,66 @@
 import gradio as gr
 import torch
 from transformers import AutoModelForCausalLM, AutoTokenizer
+import logging
 import spaces
 
+# Configure logging
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
+
 # Dictionary to store loaded models and tokenizers
 loaded_models = {}
 
-# List of available models (update with your preferred models)
+# List of available models (ensure these are correct and accessible)
 models = [
     "hanzla/gemma-2b-datascience-instruct-v5",
     "hanzla/gemma-2b-datascience-instruct-v4.5"
 ]
 
 
-@spaces.GPU
 def load_all_models():
     """
     Pre-loads all models and their tokenizers into memory.
     """
     for model_name in models:
         if model_name not in loaded_models:
-            print(f"Loading model: {model_name}")
-            tokenizer = AutoTokenizer.from_pretrained(model_name)
-            model = AutoModelForCausalLM.from_pretrained(model_name).to("cuda" if torch.cuda.is_available() else "cpu")
-            loaded_models[model_name] = (model, tokenizer)
-    return "All models loaded successfully."
-
+            try:
+                logger.info(f"Loading model: {model_name}")
+                tokenizer = AutoTokenizer.from_pretrained(model_name)
+                model = AutoModelForCausalLM.from_pretrained(model_name).to(
+                    "cuda" if torch.cuda.is_available() else "cpu")
+                loaded_models[model_name] = (model, tokenizer)
+                logger.info(f"Successfully loaded {model_name}")
+            except Exception as e:
+                logger.error(f"Failed to load model {model_name}: {e}")
 
 @spaces.GPU
 def get_model_response(model_name, message):
     """
     Generates a response from the specified model given a user message.
     """
-    model, tokenizer = loaded_models[model_name]
-    inputs = tokenizer(message, return_tensors="pt").to(model.device)
-
-    # Generate response with appropriate parameters
-    outputs = model.generate(
-        **inputs,
-        max_length=512,
-        do_sample=True,
-        top_p=0.95,
-        top_k=50
-    )
-
-    response = tokenizer.decode(outputs[0], skip_special_tokens=True)
-    return response
+    try:
+        model, tokenizer = loaded_models[model_name]
+        inputs = tokenizer(message, return_tensors="pt").to(model.device)
+
+        # Generate response with appropriate parameters
+        with torch.no_grad():
+            outputs = model.generate(
+                **inputs,
+                max_length=512,
+                do_sample=True,
+                top_p=0.95,
+                top_k=50
+            )
+
+        response = tokenizer.decode(outputs[0], skip_special_tokens=True)
+        return response
+    except KeyError:
+        logger.error(f"Model {model_name} not found in loaded_models.")
+        return f"Error: Model {model_name} not loaded."
+    except Exception as e:
+        logger.error(f"Error generating response from {model_name}: {e}")
+        return f"Error generating response: {e}"
 
 
 def chat(message, history1, history2, model1, model2):
@@ -99,6 +114,9 @@ def clear_chat():
     return [], [], "Votes - 0, 0"
 
 
+# Pre-load all models before building the Gradio interface
+load_all_models()
+
 with gr.Blocks() as demo:
     gr.Markdown("# 🤖 Hugging Face Model Comparison Chat")
 
@@ -152,8 +170,5 @@ with gr.Blocks() as demo:
         outputs=[chatbot1, chatbot2, vote_text]
    )
 
-# Pre-load all models when the space starts
-load_all_models()
-
 if __name__ == "__main__":
     demo.launch()
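
For orientation, the two revised functions could be exercised outside the Gradio UI roughly as follows. This is a minimal sketch, not part of the commit: it assumes the two listed checkpoints can be downloaded and fit in memory, and that the `spaces` package is installed (its decorator is a no-op outside a ZeroGPU Space).

# Illustrative usage sketch (not part of the commit): call the revised
# functions directly in a Python session after importing them from app.py.
load_all_models()  # populates loaded_models, logging progress and any failures
reply = get_model_response(
    "hanzla/gemma-2b-datascience-instruct-v5",  # must be one of the preloaded models
    "How should I handle missing values in a pandas DataFrame?"
)
print(reply)  # returns an error string instead of raising if the model failed to load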