aarohanverma commited on
Commit
b1d4b68
·
verified ·
1 Parent(s): dc9fd64

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +20 -6
app.py CHANGED
@@ -21,11 +21,13 @@ sp = load_sp_model()
21
  # ---------------------- Prediction Function ----------------------
22
  @torch.no_grad()
23
  def predict_next_words(text, max_predictions=3):
24
- if not text.strip():
 
 
25
  return []
26
 
27
- token_ids = sp.encode(text.strip(), out_type=int)
28
- if len(token_ids) == 0:
29
  return []
30
 
31
  input_seq = torch.tensor(token_ids, dtype=torch.long).unsqueeze(0).to(device)
@@ -38,13 +40,23 @@ def predict_next_words(text, max_predictions=3):
38
 
39
  # ---------------------- Gradio App Functions ----------------------
40
  def submit_and_predict(text):
 
41
  suggestions = predict_next_words(text)
42
- suggestions += [""] * (3 - len(suggestions)) # Ensure 3 buttons always
43
- return suggestions
 
 
 
 
 
 
 
 
44
 
45
  def append_suggestion(text, suggestion):
 
46
  if suggestion:
47
- text += suggestion + " "
48
  return text
49
 
50
  # ---------------------- Gradio Interface ----------------------
@@ -58,12 +70,14 @@ with gr.Blocks(title="Next Word Predictor") as app:
58
  with gr.Row():
59
  suggestion_buttons = [gr.Button(visible=False) for _ in range(3)]
60
 
 
61
  submit_btn.click(
62
  fn=submit_and_predict,
63
  inputs=text_input,
64
  outputs=suggestion_buttons,
65
  )
66
 
 
67
  for btn in suggestion_buttons:
68
  btn.click(
69
  fn=append_suggestion,
 
21
  # ---------------------- Prediction Function ----------------------
22
  @torch.no_grad()
23
  def predict_next_words(text, max_predictions=3):
24
+ """Predict up to max_predictions next words."""
25
+ text = text.strip()
26
+ if not text:
27
  return []
28
 
29
+ token_ids = sp.encode(text, out_type=int)
30
+ if not token_ids:
31
  return []
32
 
33
  input_seq = torch.tensor(token_ids, dtype=torch.long).unsqueeze(0).to(device)
 
40
 
41
  # ---------------------- Gradio App Functions ----------------------
42
  def submit_and_predict(text):
43
+ # Get predictions and ensure exactly 3 by padding empty strings.
44
  suggestions = predict_next_words(text)
45
+ suggestions += [""] * (3 - len(suggestions))
46
+
47
+ # Return an array of Gradio "update" objects so we can make them visible or hidden.
48
+ updates = []
49
+ for s in suggestions:
50
+ if s: # Valid prediction
51
+ updates.append(gr.update(value=s, visible=True))
52
+ else: # No prediction
53
+ updates.append(gr.update(value="", visible=False))
54
+ return updates
55
 
56
  def append_suggestion(text, suggestion):
57
+ # Only append if not empty.
58
  if suggestion:
59
+ text = text.rstrip() + " " + suggestion + " "
60
  return text
61
 
62
  # ---------------------- Gradio Interface ----------------------
 
70
  with gr.Row():
71
  suggestion_buttons = [gr.Button(visible=False) for _ in range(3)]
72
 
73
+ # 1. When user clicks 'Submit', run submit_and_predict to get suggestions.
74
  submit_btn.click(
75
  fn=submit_and_predict,
76
  inputs=text_input,
77
  outputs=suggestion_buttons,
78
  )
79
 
80
+ # 2. Each suggestion button appends the chosen word to the main text.
81
  for btn in suggestion_buttons:
82
  btn.click(
83
  fn=append_suggestion,