yabramuvdi commited on
Commit
617bd81
·
verified ·
1 Parent(s): 4574633

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +132 -0
app.py ADDED
@@ -0,0 +1,132 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from transformers import AutoModelForCausalLM, AutoTokenizer
3
+ import torch
4
+ from typing import Tuple, List, Dict
5
+ import numpy as np
6
+
7
+ # Select smaller models that are suitable for this task
8
+ AVAILABLE_MODELS = {
9
+ "distilgpt2": "distilgpt2",
10
+ "bloomz-560m": "bigscience/bloomz-560m",
11
+ "gpt2-medium": "gpt2-medium",
12
+ "opt-350m": "facebook/opt-350m",
13
+ "pythia-160m": "EleutherAI/pythia-160m"
14
+ }
15
+
16
+ class TextGenerator:
17
+ def __init__(self):
18
+ self.model = None
19
+ self.tokenizer = None
20
+
21
+ def load_model(self, model_name: str) -> str:
22
+ """Load the selected model and tokenizer"""
23
+ try:
24
+ self.model = AutoModelForCausalLM.from_pretrained(AVAILABLE_MODELS[model_name])
25
+ self.tokenizer = AutoTokenizer.from_pretrained(AVAILABLE_MODELS[model_name])
26
+ return f"Successfully loaded {model_name}"
27
+ except Exception as e:
28
+ return f"Error loading model: {str(e)}"
29
+
30
+ def get_next_token_predictions(self, text: str, top_k: int = 10) -> Tuple[List[str], List[float]]:
31
+ """Get predictions for the next token"""
32
+ if not self.model or not self.tokenizer:
33
+ return [], []
34
+
35
+ inputs = self.tokenizer(text, return_tensors="pt")
36
+ with torch.no_grad():
37
+ outputs = self.model(**inputs)
38
+ logits = outputs.logits[0, -1, :]
39
+ probs = torch.nn.functional.softmax(logits, dim=-1)
40
+
41
+ top_k_probs, top_k_indices = torch.topk(probs, top_k)
42
+ top_k_tokens = [self.tokenizer.decode([idx.item()]) for idx in top_k_indices]
43
+ top_k_probs = top_k_probs.tolist()
44
+
45
+ return top_k_tokens, top_k_probs
46
+
47
+ def format_predictions(tokens: List[str], probs: List[float]) -> str:
48
+ """Format the predictions for display"""
49
+ if not tokens or not probs:
50
+ return "No predictions available"
51
+
52
+ formatted = "Predicted next tokens:\n\n"
53
+ for token, prob in zip(tokens, probs):
54
+ formatted += f"'{token}' : {prob:.4f}\n"
55
+ return formatted
56
+
57
+ generator = TextGenerator()
58
+
59
+ def update_output(model_name: str, text: str, custom_token: str, selected_token: str) -> Tuple[str, str, str, Dict, str]:
60
+ """Update the interface based on user interactions"""
61
+ output = text
62
+
63
+ # Load model if it changed
64
+ if not generator.model or generator.model.name_or_path != AVAILABLE_MODELS[model_name]:
65
+ load_message = generator.load_model(model_name)
66
+ if "Error" in load_message:
67
+ return text, "", "", gr.update(choices=[]), load_message
68
+
69
+ # Add custom token or selected token
70
+ if custom_token:
71
+ output += custom_token
72
+ elif selected_token:
73
+ output += selected_token.strip("'")
74
+
75
+ # Get new predictions
76
+ tokens, probs = generator.get_next_token_predictions(output)
77
+ predictions = format_predictions(tokens, probs)
78
+
79
+ # Update dropdown choices
80
+ token_choices = [f"'{token}'" for token in tokens]
81
+
82
+ return output, "", "", gr.update(choices=token_choices), predictions
83
+
84
+ with gr.Blocks() as app:
85
+ gr.Markdown("# Interactive Text Generation")
86
+
87
+ with gr.Row():
88
+ model_dropdown = gr.Dropdown(
89
+ choices=list(AVAILABLE_MODELS.keys()),
90
+ value="distilgpt2",
91
+ label="Select Model"
92
+ )
93
+
94
+ with gr.Row():
95
+ text_input = gr.Textbox(
96
+ lines=5,
97
+ label="Generated Text",
98
+ placeholder="Start typing or select a token..."
99
+ )
100
+
101
+ with gr.Row():
102
+ custom_token = gr.Textbox(
103
+ label="Custom Token",
104
+ placeholder="Type your own token..."
105
+ )
106
+ token_dropdown = gr.Dropdown(
107
+ choices=[],
108
+ label="Select from predicted tokens"
109
+ )
110
+
111
+ with gr.Row():
112
+ predictions_output = gr.Textbox(
113
+ label="Predictions",
114
+ lines=12
115
+ )
116
+
117
+ with gr.Row():
118
+ status_output = gr.Textbox(
119
+ label="Status",
120
+ lines=1
121
+ )
122
+
123
+ # Update when model changes or token is added
124
+ for trigger in [model_dropdown, custom_token, token_dropdown]:
125
+ trigger.change(
126
+ fn=update_output,
127
+ inputs=[model_dropdown, text_input, custom_token, token_dropdown],
128
+ outputs=[text_input, custom_token, token_dropdown, token_dropdown, predictions_output]
129
+ )
130
+
131
+ if __name__ == "__main__":
132
+ app.launch()