yabramuvdi commited on
Commit
ac72c21
·
verified ·
1 Parent(s): 3b82846

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +72 -61
app.py CHANGED
@@ -1,7 +1,10 @@
 
 
 
1
  import gradio as gr
2
  from transformers import AutoModelForCausalLM, AutoTokenizer
3
- import torch
4
 
 
5
  AVAILABLE_MODELS = {
6
  "distilgpt2": "distilgpt2",
7
  "bloomz-560m": "bigscience/bloomz-560m",
@@ -10,80 +13,88 @@ AVAILABLE_MODELS = {
10
  "pythia-160m": "EleutherAI/pythia-160m"
11
  }
12
 
13
- generator = None
 
 
 
 
 
 
14
 
15
  def load_model(model_name):
16
- global generator
17
- try:
18
- model = AutoModelForCausalLM.from_pretrained(AVAILABLE_MODELS[model_name])
19
- tokenizer = AutoTokenizer.from_pretrained(AVAILABLE_MODELS[model_name])
20
- generator = (model, tokenizer)
21
- return f"Successfully loaded {model_name}"
22
- except Exception as e:
23
- return f"Error loading model: {str(e)}"
24
 
25
- def get_predictions(text, model_name):
26
- global generator
27
- if not generator:
28
- load_model(model_name)
29
 
30
- model, tokenizer = generator
31
- inputs = tokenizer(text, return_tensors="pt")
 
32
 
 
 
33
  with torch.no_grad():
34
- outputs = model(**inputs)
35
  logits = outputs.logits[0, -1, :]
36
  probs = torch.nn.functional.softmax(logits, dim=-1)
37
 
38
- top_k_probs, top_k_indices = torch.topk(probs, k=10)
39
- top_k_tokens = [tokenizer.decode([idx.item()]) for idx in top_k_indices]
40
- predictions = "\n".join([f"'{token}' : {prob:.4f}" for token, prob in zip(top_k_tokens, top_k_probs)])
41
 
42
- return top_k_tokens, predictions
43
 
44
- def generate(model_name, text, token_choice="", custom_token=""):
45
- if token_choice:
46
- text += token_choice.strip("'")
47
  if custom_token:
48
  text += custom_token
49
-
50
- tokens, predictions = get_predictions(text, model_name)
51
- return text, gr.Dropdown(choices=[f"'{t}'" for t in tokens]), predictions
52
-
53
- with gr.Blocks() as demo:
54
- gr.Markdown("# Interactive Text Generation")
55
 
56
- model_name = gr.Dropdown(
57
- choices=list(AVAILABLE_MODELS.keys()),
58
- value="distilgpt2",
59
- label="Select Model"
60
- )
61
 
62
- text = gr.Textbox(
63
- lines=5,
64
- label="Text",
65
- placeholder="Type or select tokens to generate text..."
66
- )
67
 
68
- with gr.Row():
69
- token_choice = gr.Dropdown(
70
- choices=[],
71
- label="Select predicted token"
72
- )
73
- custom_token = gr.Textbox(
74
- label="Or type custom token"
75
- )
76
-
77
- predictions = gr.Textbox(
78
- label="Predictions",
79
- lines=10
80
- )
81
-
82
- for component in [model_name, token_choice, custom_token]:
83
- component.change(
84
- generate,
85
- inputs=[model_name, text, token_choice, custom_token],
86
- outputs=[text, token_choice, predictions]
87
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
88
 
89
- demo.queue().launch(share=True)
 
 
1
+ import os
2
+ import numpy as np
3
+ import torch
4
  import gradio as gr
5
  from transformers import AutoModelForCausalLM, AutoTokenizer
 
6
 
7
+ # Available models
8
  AVAILABLE_MODELS = {
9
  "distilgpt2": "distilgpt2",
10
  "bloomz-560m": "bigscience/bloomz-560m",
 
13
  "pythia-160m": "EleutherAI/pythia-160m"
14
  }
15
 
16
+ # Access token for Hugging Face
17
+ HF_TOKEN = os.getenv('HF_TOKEN')
18
+
19
+ # Initialize model and tokenizer globally
20
+ current_model = None
21
+ current_tokenizer = None
22
+ current_model_name = None
23
 
24
  def load_model(model_name):
25
+ global current_model, current_tokenizer, current_model_name
26
+ if current_model_name != model_name:
27
+ current_model = AutoModelForCausalLM.from_pretrained(AVAILABLE_MODELS[model_name], use_auth_token=HF_TOKEN)
28
+ current_tokenizer = AutoTokenizer.from_pretrained(AVAILABLE_MODELS[model_name], use_auth_token=HF_TOKEN)
29
+ current_model_name = model_name
 
 
 
30
 
31
+ def get_next_token_predictions(text, model_name, top_k=10):
32
+ global current_model, current_tokenizer
 
 
33
 
34
+ # Load model if needed
35
+ if current_model_name != model_name:
36
+ load_model(model_name)
37
 
38
+ # Get predictions
39
+ inputs = current_tokenizer(text, return_tensors="pt")
40
  with torch.no_grad():
41
+ outputs = current_model(**inputs)
42
  logits = outputs.logits[0, -1, :]
43
  probs = torch.nn.functional.softmax(logits, dim=-1)
44
 
45
+ top_k_probs, top_k_indices = torch.topk(probs, k=top_k)
46
+ top_k_tokens = [current_tokenizer.decode([idx.item()]) for idx in top_k_indices]
 
47
 
48
+ return top_k_tokens, top_k_probs.tolist()
49
 
50
+ def predict_next_token(text, model_name, custom_token=""):
51
+ # Add custom token if provided
 
52
  if custom_token:
53
  text += custom_token
 
 
 
 
 
 
54
 
55
+ # Get predictions
56
+ tokens, probs = get_next_token_predictions(text, model_name)
 
 
 
57
 
58
+ # Format predictions
59
+ predictions = "\n".join([f"'{token}' : {prob:.4f}" for token, prob in zip(tokens, probs)])
 
 
 
60
 
61
+ return text, gr.Dropdown(choices=[f"'{t}'" for t in tokens]), predictions
62
+
63
+ # Page content
64
+ title = "Interactive Text Generation with Transformer Models"
65
+ description = """
66
+ This application allows you to interactively generate text using various transformer models.
67
+ You can either select from the predicted next tokens or write your own tokens to continue the text generation.
68
+
69
+ Select a model, start typing or choose from the predicted tokens, and see how the model continues your text!
70
+ """
71
+
72
+ # Example inputs
73
+ examples = [
74
+ ["The quick brown fox", "distilgpt2"],
75
+ ["In a galaxy far", "gpt2-medium"],
76
+ ["Once upon a time", "opt-350m"],
77
+ ]
78
+
79
+ # Create the interface
80
+ app = gr.Interface(
81
+ fn=predict_next_token,
82
+ inputs=[
83
+ gr.Textbox(lines=5, label="Text"),
84
+ gr.Dropdown(choices=list(AVAILABLE_MODELS.keys()), value="distilgpt2", label="Model"),
85
+ gr.Textbox(label="Custom token (optional)")
86
+ ],
87
+ outputs=[
88
+ gr.Textbox(lines=5, label="Generated text"),
89
+ gr.Dropdown(label="Predicted tokens"),
90
+ gr.Textbox(lines=10, label="Token probabilities")
91
+ ],
92
+ theme="huggingface",
93
+ title=title,
94
+ description=description,
95
+ examples=examples,
96
+ allow_flagging="manual"
97
+ )
98
 
99
+ # Launch the app
100
+ app.launch()