yabramuvdi committed on
Commit 310d018 · verified · 1 Parent(s): 93a3f9a

Update app.py

Files changed (1): app.py +62 -37
app.py CHANGED
@@ -1,5 +1,10 @@
 import os
-# Handle Spaces GPU
+import numpy as np
+import torch
+import gradio as gr
+from transformers import AutoModelForCausalLM, AutoTokenizer
+
+# Handle Hugging Face Spaces GPU
 if os.environ.get("SPACES_ZERO_GPU") is not None:
     import spaces
 else:
@@ -14,11 +19,6 @@ else:
     def fake_gpu():
         pass
 
-import numpy as np
-import torch
-import gradio as gr
-from transformers import AutoModelForCausalLM, AutoTokenizer
-
 # Available models
 AVAILABLE_MODELS = {
     "distilgpt2": "distilgpt2",
@@ -28,58 +28,67 @@ AVAILABLE_MODELS = {
     "pythia-160m": "EleutherAI/pythia-160m"
 }
 
-# Initialize model and tokenizer globally
+# Initialize model and tokenizer
 current_model = None
 current_tokenizer = None
 current_model_name = None
+device = "cuda" if torch.cuda.is_available() else "cpu"
 
 def load_model(model_name):
+    """Load the selected model and tokenizer."""
     global current_model, current_tokenizer, current_model_name
     if current_model_name != model_name:
-        current_model = AutoModelForCausalLM.from_pretrained(AVAILABLE_MODELS[model_name])
+        current_model = AutoModelForCausalLM.from_pretrained(AVAILABLE_MODELS[model_name]).to(device)
         current_tokenizer = AutoTokenizer.from_pretrained(AVAILABLE_MODELS[model_name])
         current_model_name = model_name
 
+# Load the default model at startup
+load_model("distilgpt2")
+
 def get_next_token_predictions(text, model_name, top_k=10):
+    """Generate the next token predictions with their probabilities."""
     global current_model, current_tokenizer
 
-    # Load model if needed
+    # Load the model if it has changed
     if current_model_name != model_name:
         load_model(model_name)
 
-    # Get predictions
-    inputs = current_tokenizer(text, return_tensors="pt")
+    inputs = current_tokenizer(text, return_tensors="pt").to(device)
+
     with torch.no_grad():
         outputs = current_model(**inputs)
         logits = outputs.logits[0, -1, :]
         probs = torch.nn.functional.softmax(logits, dim=-1)
-
+
     top_k_probs, top_k_indices = torch.topk(probs, k=top_k)
     top_k_tokens = [current_tokenizer.decode([idx.item()]) for idx in top_k_indices]
 
-    return top_k_tokens, top_k_probs.tolist()
+    return top_k_tokens, top_k_probs.cpu().tolist()
 
-def predict_next_token(model_name, text, custom_token=""):
-    # Add custom token if provided
+def predict_next_token(model_name, text, top_k, custom_token=""):
+    """Get predictions and update the UI."""
     if custom_token:
         text += custom_token
-
-    # Get predictions
-    tokens, probs = get_next_token_predictions(text, model_name)
-
-    # Format predictions
-    predictions = "\n".join([f"'{token}' : {prob:.4f}" for token, prob in zip(tokens, probs)])
+
+    tokens, probs = get_next_token_predictions(text, model_name, top_k)
+
+    predictions = "\n".join([f"'{token}': {prob:.4f}" for token, prob in zip(tokens, probs)])
 
     return gr.update(choices=[f"'{t}'" for t in tokens]), predictions
 
-# Create the interface
+def append_selected_token(text, selected_token):
+    """Append selected token from dropdown to the text input."""
+    if selected_token:
+        text += " " + selected_token.strip("'")
+    return text
+
+# Create the UI
 with gr.Blocks() as demo:
-    gr.Markdown("# Interactive Text Generation with Transformer Models")
-
-    gr.Markdown("""
-    This application allows you to interactively generate text using various transformer models.
-    Select a model, enter your text, and click predict to see the possible next tokens and their probabilities.
-    """)
+    gr.Markdown("# 🔥 Interactive Text Prediction with Transformers")
+    gr.Markdown(
+        "This application lets you interactively generate text using multiple transformer models. "
+        "Choose a model, type your text, and explore token predictions."
+    )
 
     with gr.Row():
         model_dropdown = gr.Dropdown(
@@ -91,31 +100,47 @@ with gr.Blocks() as demo:
     with gr.Row():
         text_input = gr.Textbox(
             lines=5,
-            label="Text",
+            label="Input Text",
             placeholder="Type your text here...",
             value="The quick brown fox"
         )
 
+    with gr.Row():
+        top_k_slider = gr.Slider(
+            minimum=1,
+            maximum=20,
+            value=10,
+            step=1,
+            label="Top-k Predictions"
+        )
+
     with gr.Row():
         predict_button = gr.Button("Predict")
-
+
     with gr.Row():
         token_dropdown = gr.Dropdown(
-            label="Predicted tokens",
+            label="Predicted Tokens",
             choices=[]
         )
-
+        append_button = gr.Button("Append Token")
+
     with gr.Row():
         predictions_output = gr.Textbox(
            lines=10,
-            label="Token probabilities"
        )
-
-    # Set up predict button event handler
+
+    # Button click events
     predict_button.click(
         predict_next_token,
-        inputs=[model_dropdown, text_input],
+        inputs=[model_dropdown, text_input, top_k_slider],
         outputs=[token_dropdown, predictions_output]
     )
 
-demo.queue().launch()
+    append_button.click(
+        append_selected_token,
+        inputs=[text_input, token_dropdown],
+        outputs=text_input
+    )
+
+demo.queue().launch()
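For quick sanity-checking outside Gradio, here is a minimal standalone sketch of the prediction path this commit converges on: device-aware loading, one forward pass, and top-k over the final-position logits. It is not part of app.py; the model name and prompt are illustrative only.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

device = "cuda" if torch.cuda.is_available() else "cpu"
tokenizer = AutoTokenizer.from_pretrained("distilgpt2")
model = AutoModelForCausalLM.from_pretrained("distilgpt2").to(device)

inputs = tokenizer("The quick brown fox", return_tensors="pt").to(device)
with torch.no_grad():
    # Logits at the last position score every vocabulary item as the next token
    logits = model(**inputs).logits[0, -1, :]
probs = torch.nn.functional.softmax(logits, dim=-1)

top_k_probs, top_k_indices = torch.topk(probs, k=10)
for idx, prob in zip(top_k_indices, top_k_probs):
    print(f"'{tokenizer.decode([idx.item()])}': {prob.item():.4f}")

Note that the commit's `top_k_probs.cpu().tolist()` makes the GPU-to-CPU transfer explicit before converting to a Python list; `.tolist()` alone would also sync, but the explicit `.cpu()` keeps the intent readable when the model runs on CUDA.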