yabramuvdi committed
Commit 6c99f7c · verified · 1 Parent(s): 175fea5

Update app.py

Files changed (1)
1. app.py +67 -71
app.py CHANGED
@@ -10,84 +10,80 @@ AVAILABLE_MODELS = {
     "pythia-160m": "EleutherAI/pythia-160m"
 }
 
-class TextGenerator:
-    def __init__(self):
-        self.model = None
-        self.tokenizer = None
-
-    def load_model(self, model_name: str) -> str:
-        try:
-            self.model = AutoModelForCausalLM.from_pretrained(AVAILABLE_MODELS[model_name])
-            self.tokenizer = AutoTokenizer.from_pretrained(AVAILABLE_MODELS[model_name])
-            return f"Successfully loaded {model_name}"
-        except Exception as e:
-            return f"Error loading model: {str(e)}"
-
-    def get_next_token_predictions(self, text: str, top_k: int = 10):
-        if not self.model or not self.tokenizer:
-            return [], []
-
-        inputs = self.tokenizer(text, return_tensors="pt")
-        with torch.no_grad():
-            outputs = self.model(**inputs)
-            logits = outputs.logits[0, -1, :]
-            probs = torch.nn.functional.softmax(logits, dim=-1)
-
-        top_k_probs, top_k_indices = torch.topk(probs, top_k)
-        top_k_tokens = [self.tokenizer.decode([idx.item()]) for idx in top_k_indices]
-
-        return top_k_tokens, top_k_probs.tolist()
-
-generator = TextGenerator()
-
-def format_predictions(tokens, probs):
-    if not tokens or not probs:
-        return "No predictions available"
-
-    formatted = "Predicted next tokens:\n\n"
-    for token, prob in zip(tokens, probs):
-        formatted += f"'{token}' : {prob:.4f}\n"
-    return formatted
-
-def update_output(model_name, text, custom_token, selected_token):
-    output = text
-
-    if not generator.model or generator.model.name_or_path != AVAILABLE_MODELS[model_name]:
-        load_message = generator.load_model(model_name)
-        if "Error" in load_message:
-            return text, "", "", gr.update(choices=[]), load_message
-
-    if custom_token:
-        output += custom_token
-    elif selected_token:
-        output += selected_token.strip("'")
-
-    tokens, probs = generator.get_next_token_predictions(output)
-    predictions = format_predictions(tokens, probs)
-    token_choices = [f"'{token}'" for token in tokens]
-
-    return output, "", "", gr.update(choices=token_choices), predictions
-
-demo = gr.Interface(
-    fn=update_output,
-    inputs=[
-        gr.Dropdown(choices=list(AVAILABLE_MODELS.keys()), value="distilgpt2", label="Select Model"),
-        gr.Textbox(lines=5, label="Generated Text", placeholder="Start typing or select a token..."),
-        gr.Textbox(label="Custom Token", placeholder="Type your own token..."),
-        gr.Dropdown(choices=[], label="Select from predicted tokens")
-    ],
-    outputs=[
-        gr.Textbox(lines=5, label="Generated Text"),
-        gr.Textbox(label="Custom Token"),
-        gr.Textbox(label="Selected Token"),
-        gr.Dropdown(label="Predicted Tokens"),
-        gr.Textbox(lines=12, label="Predictions")
-    ],
-    title="Interactive Text Generation",
-    description="Generate text by selecting predicted tokens or writing your own."
-)
-
-if __name__ == "__main__":
-    demo.launch()
-else:
-    demo.launch(show_error=True)  # Required for Hugging Face Spaces
+generator = None
+
+def load_model(model_name):
+    global generator
+    try:
+        model = AutoModelForCausalLM.from_pretrained(AVAILABLE_MODELS[model_name])
+        tokenizer = AutoTokenizer.from_pretrained(AVAILABLE_MODELS[model_name])
+        generator = (model, tokenizer)
+        return f"Successfully loaded {model_name}"
+    except Exception as e:
+        return f"Error loading model: {str(e)}"
+
+def get_predictions(text, model_name):
+    global generator
+    if not generator:
+        load_model(model_name)
+
+    model, tokenizer = generator
+    inputs = tokenizer(text, return_tensors="pt")
+
+    with torch.no_grad():
+        outputs = model(**inputs)
+        logits = outputs.logits[0, -1, :]
+        probs = torch.nn.functional.softmax(logits, dim=-1)
+
+    top_k_probs, top_k_indices = torch.topk(probs, k=10)
+    top_k_tokens = [tokenizer.decode([idx.item()]) for idx in top_k_indices]
+    predictions = "\n".join([f"'{token}' : {prob:.4f}" for token, prob in zip(top_k_tokens, top_k_probs)])
+
+    return top_k_tokens, predictions
+
+def generate(model_name, text, token_choice="", custom_token=""):
+    if token_choice:
+        text += token_choice.strip("'")
+    if custom_token:
+        text += custom_token
+
+    tokens, predictions = get_predictions(text, model_name)
+    return text, gr.Dropdown(choices=[f"'{t}'" for t in tokens]), predictions
+
+with gr.Blocks() as demo:
+    gr.Markdown("# Interactive Text Generation")
+
+    model_name = gr.Dropdown(
+        choices=list(AVAILABLE_MODELS.keys()),
+        value="distilgpt2",
+        label="Select Model"
+    )
+
+    text = gr.Textbox(
+        lines=5,
+        label="Text",
+        placeholder="Type or select tokens to generate text..."
+    )
+
+    with gr.Row():
+        token_choice = gr.Dropdown(
+            choices=[],
+            label="Select predicted token"
+        )
+        custom_token = gr.Textbox(
+            label="Or type custom token"
+        )
+
+    predictions = gr.Textbox(
+        label="Predictions",
+        lines=10
+    )
+
+    for component in [model_name, token_choice, custom_token]:
+        component.change(
+            generate,
+            inputs=[model_name, text, token_choice, custom_token],
+            outputs=[text, token_choice, predictions]
+        )
+
+demo.queue().launch()
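
For reference, the prediction logic can be exercised outside Gradio. Below is a minimal standalone sketch of the same top-k next-token loop the new get_predictions implements; k=10 matches the diff, while the model name, prompt, and the three-round greedy loop are illustrative choices, not part of the commit.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Illustrative model and prompt; any entry from AVAILABLE_MODELS works the same way.
model = AutoModelForCausalLM.from_pretrained("distilgpt2")
tokenizer = AutoTokenizer.from_pretrained("distilgpt2")

text = "The quick brown"
for _ in range(3):  # accept the top prediction three times
    inputs = tokenizer(text, return_tensors="pt")
    with torch.no_grad():
        logits = model(**inputs).logits[0, -1, :]   # logits at the last position
    probs = torch.nn.functional.softmax(logits, dim=-1)
    top_probs, top_ids = torch.topk(probs, k=10)    # same k as the app
    for tok_id, p in zip(top_ids, top_probs):
        print(f"{tokenizer.decode([tok_id.item()])!r} : {p.item():.4f}")
    text += tokenizer.decode([top_ids[0].item()])   # greedy: append the best token
print(text)

One behavioral note on the committed version: get_predictions calls load_model only when no model is cached, so changing the model dropdown after the first prediction keeps serving the previously loaded model until the Space restarts.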