Update app.py
Browse files
app.py
CHANGED
@@ -21,11 +21,13 @@ sp = load_sp_model()
|
|
21 |
# ---------------------- Prediction Function ----------------------
|
22 |
@torch.no_grad()
|
23 |
def predict_next_words(text, max_predictions=3):
|
24 |
-
|
|
|
|
|
25 |
return []
|
26 |
|
27 |
-
token_ids = sp.encode(text
|
28 |
-
if
|
29 |
return []
|
30 |
|
31 |
input_seq = torch.tensor(token_ids, dtype=torch.long).unsqueeze(0).to(device)
|
@@ -38,13 +40,23 @@ def predict_next_words(text, max_predictions=3):
|
|
38 |
|
39 |
# ---------------------- Gradio App Functions ----------------------
|
40 |
def submit_and_predict(text):
|
|
|
41 |
suggestions = predict_next_words(text)
|
42 |
-
suggestions += [""] * (3 - len(suggestions))
|
43 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
44 |
|
45 |
def append_suggestion(text, suggestion):
|
|
|
46 |
if suggestion:
|
47 |
-
text
|
48 |
return text
|
49 |
|
50 |
# ---------------------- Gradio Interface ----------------------
|
@@ -58,12 +70,14 @@ with gr.Blocks(title="Next Word Predictor") as app:
|
|
58 |
with gr.Row():
|
59 |
suggestion_buttons = [gr.Button(visible=False) for _ in range(3)]
|
60 |
|
|
|
61 |
submit_btn.click(
|
62 |
fn=submit_and_predict,
|
63 |
inputs=text_input,
|
64 |
outputs=suggestion_buttons,
|
65 |
)
|
66 |
|
|
|
67 |
for btn in suggestion_buttons:
|
68 |
btn.click(
|
69 |
fn=append_suggestion,
|
|
|
21 |
# ---------------------- Prediction Function ----------------------
|
22 |
@torch.no_grad()
|
23 |
def predict_next_words(text, max_predictions=3):
|
24 |
+
"""Predict up to max_predictions next words."""
|
25 |
+
text = text.strip()
|
26 |
+
if not text:
|
27 |
return []
|
28 |
|
29 |
+
token_ids = sp.encode(text, out_type=int)
|
30 |
+
if not token_ids:
|
31 |
return []
|
32 |
|
33 |
input_seq = torch.tensor(token_ids, dtype=torch.long).unsqueeze(0).to(device)
|
|
|
40 |
|
41 |
# ---------------------- Gradio App Functions ----------------------
|
42 |
def submit_and_predict(text):
|
43 |
+
# Get predictions and ensure exactly 3 by padding empty strings.
|
44 |
suggestions = predict_next_words(text)
|
45 |
+
suggestions += [""] * (3 - len(suggestions))
|
46 |
+
|
47 |
+
# Return an array of Gradio "update" objects so we can make them visible or hidden.
|
48 |
+
updates = []
|
49 |
+
for s in suggestions:
|
50 |
+
if s: # Valid prediction
|
51 |
+
updates.append(gr.update(value=s, visible=True))
|
52 |
+
else: # No prediction
|
53 |
+
updates.append(gr.update(value="", visible=False))
|
54 |
+
return updates
|
55 |
|
56 |
def append_suggestion(text, suggestion):
|
57 |
+
# Only append if not empty.
|
58 |
if suggestion:
|
59 |
+
text = text.rstrip() + " " + suggestion + " "
|
60 |
return text
|
61 |
|
62 |
# ---------------------- Gradio Interface ----------------------
|
|
|
70 |
with gr.Row():
|
71 |
suggestion_buttons = [gr.Button(visible=False) for _ in range(3)]
|
72 |
|
73 |
+
# 1. When user clicks 'Submit', run submit_and_predict to get suggestions.
|
74 |
submit_btn.click(
|
75 |
fn=submit_and_predict,
|
76 |
inputs=text_input,
|
77 |
outputs=suggestion_buttons,
|
78 |
)
|
79 |
|
80 |
+
# 2. Each suggestion button appends the chosen word to the main text.
|
81 |
for btn in suggestion_buttons:
|
82 |
btn.click(
|
83 |
fn=append_suggestion,
|