yabramuvdi committed
Commit f586a0d · verified · 1 Parent(s): bfdc852

Update app.py

Files changed (1)
1. app.py +24 -72
app.py CHANGED
@@ -1,10 +1,7 @@
 import gradio as gr
 from transformers import AutoModelForCausalLM, AutoTokenizer
 import torch
-from typing import Tuple, List, Dict
-import numpy as np
 
-# Select smaller models that are suitable for this task
 AVAILABLE_MODELS = {
     "distilgpt2": "distilgpt2",
     "bloomz-560m": "bigscience/bloomz-560m",
@@ -19,7 +16,6 @@ class TextGenerator:
         self.tokenizer = None
 
     def load_model(self, model_name: str) -> str:
-        """Load the selected model and tokenizer"""
         try:
             self.model = AutoModelForCausalLM.from_pretrained(AVAILABLE_MODELS[model_name])
             self.tokenizer = AutoTokenizer.from_pretrained(AVAILABLE_MODELS[model_name])
@@ -27,8 +23,7 @@
         except Exception as e:
             return f"Error loading model: {str(e)}"
 
-    def get_next_token_predictions(self, text: str, top_k: int = 10) -> Tuple[List[str], List[float]]:
-        """Get predictions for the next token"""
+    def get_next_token_predictions(self, text: str, top_k: int = 10):
         if not self.model or not self.tokenizer:
             return [], []
 
@@ -40,12 +35,12 @@
 
         top_k_probs, top_k_indices = torch.topk(probs, top_k)
         top_k_tokens = [self.tokenizer.decode([idx.item()]) for idx in top_k_indices]
-        top_k_probs = top_k_probs.tolist()
 
-        return top_k_tokens, top_k_probs
+        return top_k_tokens, top_k_probs.tolist()
 
-def format_predictions(tokens: List[str], probs: List[float]) -> str:
-    """Format the predictions for display"""
+generator = TextGenerator()
+
+def format_predictions(tokens, probs):
     if not tokens or not probs:
         return "No predictions available"
 
@@ -54,83 +49,40 @@ def format_predictions(tokens: List[str], probs: List[float]) -> str:
         formatted += f"'{token}' : {prob:.4f}\n"
     return formatted
 
-generator = TextGenerator()
-
-def update_output(model_name: str, text: str, custom_token: str, selected_token: str) -> Tuple[str, str, str, Dict, str]:
-    """Update the interface based on user interactions"""
+def update_output(model_name, text, custom_token, selected_token):
     output = text
 
-    # Load model if it changed
     if not generator.model or generator.model.name_or_path != AVAILABLE_MODELS[model_name]:
         load_message = generator.load_model(model_name)
         if "Error" in load_message:
             return text, "", "", gr.update(choices=[]), load_message
 
-    # Add custom token or selected token
    if custom_token:
        output += custom_token
    elif selected_token:
        output += selected_token.strip("'")
 
-    # Get new predictions
     tokens, probs = generator.get_next_token_predictions(output)
     predictions = format_predictions(tokens, probs)
-
-    # Update dropdown choices
     token_choices = [f"'{token}'" for token in tokens]
 
     return output, "", "", gr.update(choices=token_choices), predictions
 
-# Create the interface
-demo = gr.Blocks(title="Interactive Text Generation")
-
-with demo:
-    gr.Markdown("# Interactive Text Generation")
-    gr.Markdown("Generate text by selecting predicted tokens or writing your own.")
-
-    with gr.Row():
-        model_dropdown = gr.Dropdown(
-            choices=list(AVAILABLE_MODELS.keys()),
-            value="distilgpt2",
-            label="Select Model"
-        )
-
-    with gr.Row():
-        text_input = gr.Textbox(
-            lines=5,
-            label="Generated Text",
-            placeholder="Start typing or select a token..."
-        )
-
-    with gr.Row():
-        custom_token = gr.Textbox(
-            label="Custom Token",
-            placeholder="Type your own token..."
-        )
-        token_dropdown = gr.Dropdown(
-            choices=[],
-            label="Select from predicted tokens"
-        )
-
-    with gr.Row():
-        predictions_output = gr.Textbox(
-            label="Predictions",
-            lines=12
-        )
-
-    with gr.Row():
-        status_output = gr.Textbox(
-            label="Status",
-            lines=1
-        )
-
-    # Update when model changes or token is added
-    for trigger in [model_dropdown, custom_token, token_dropdown]:
-        trigger.change(
-            fn=update_output,
-            inputs=[model_dropdown, text_input, custom_token, token_dropdown],
-            outputs=[text_input, custom_token, token_dropdown, token_dropdown, predictions_output]
-        )
-
-    # For Hugging Face Spaces, we just need to expose the demo
-    demo.launch(share=True)
+demo = gr.Interface(
+    fn=update_output,
+    inputs=[
+        gr.Dropdown(choices=list(AVAILABLE_MODELS.keys()), value="distilgpt2", label="Select Model"),
+        gr.Textbox(lines=5, label="Generated Text", placeholder="Start typing or select a token..."),
+        gr.Textbox(label="Custom Token", placeholder="Type your own token..."),
+        gr.Dropdown(choices=[], label="Select from predicted tokens")
+    ],
+    outputs=[
+        gr.Textbox(lines=5, label="Generated Text"),
+        gr.Textbox(label="Custom Token"),
+        gr.Textbox(label="Selected Token"),
+        gr.Dropdown(label="Predicted Tokens"),
+        gr.Textbox(lines=12, label="Predictions")
+    ],
+    title="Interactive Text Generation",
+    description="Generate text by selecting predicted tokens or writing your own."
+)
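
For reference, a minimal standalone sketch of the top-k next-token lookup that the unchanged prediction code performs; the model name, prompt, and top_k value here are illustrative choices, not part of this commit:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Illustrative choices; the app resolves model names through AVAILABLE_MODELS.
model = AutoModelForCausalLM.from_pretrained("distilgpt2")
tokenizer = AutoTokenizer.from_pretrained("distilgpt2")

inputs = tokenizer("The capital of France is", return_tensors="pt")
with torch.no_grad():
    logits = model(**inputs).logits

# The next-token distribution is the softmax over the last position's logits.
probs = torch.softmax(logits[0, -1], dim=-1)
top_probs, top_ids = torch.topk(probs, 10)

# Same display format as format_predictions: one "'token' : prob" line each.
for prob, idx in zip(top_probs.tolist(), top_ids.tolist()):
    print(f"'{tokenizer.decode([idx])}' : {prob:.4f}")

In the committed app, this computation lives in TextGenerator.get_next_token_predictions, and update_output feeds the decoded tokens back into the dropdown via gr.update(choices=...).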