|
import gradio as gr |
|
import torch |
|
import sentencepiece as spm |
|
|
|
|
|
def load_model(model_path="best_model_scripted.pt"):
    """Load the TorchScript next-word model onto the best available device.

    Args:
        model_path: Path to a TorchScript archive (as written by
            ``torch.jit.save``). Defaults to the file the app ships with.

    Returns:
        A ``(model, device)`` tuple: the model switched to eval mode and the
        ``torch.device`` it was mapped onto (CUDA when available, else CPU).
    """
    # No @torch.no_grad() here: loading does no autograd work; gradient
    # tracking is disabled on the inference function instead.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # map_location loads the archive's tensors directly onto the target device.
    model = torch.jit.load(model_path, map_location=device)
    model.eval()
    return model, device
|
|
|
def load_sp_model(model_path="spm.model"):
    """Load the SentencePiece tokenizer used to encode prompts.

    Args:
        model_path: Path to the trained SentencePiece ``.model`` file.
            Defaults to the file the app ships with.

    Returns:
        A loaded ``spm.SentencePieceProcessor``.
    """
    processor = spm.SentencePieceProcessor()
    processor.load(model_path)
    return processor
|
|
|
|
|
# Load the model and tokenizer once at import time. predict_next_words reads
# these module-level globals (model, device, sp) on every call, so every
# request reuses the same loaded objects.
model, device = load_model()

sp = load_sp_model()
|
|
|
|
|
@torch.no_grad()
def predict_next_words(text, max_predictions=3):
    """Predict up to ``max_predictions`` distinct next-word suggestions.

    The prompt is stripped and lower-cased, tokenized with the module-level
    SentencePiece processor ``sp``, and run through the module-level
    ``model`` on ``device``. Candidate token ids are ranked by logit.
    Control and unknown pieces, pieces that are empty once the "▁"
    word-boundary marker is removed, and duplicates are skipped.

    Args:
        text: Prompt text. ``None`` or whitespace-only input yields no
            suggestions.
        max_predictions: Maximum number of suggestions to return. Values
            <= 0 yield an empty list.

    Returns:
        A list of at most ``max_predictions`` suggestion strings, best first.
    """
    if max_predictions <= 0:
        return []

    text = (text or "").strip().lower()
    if not text:
        return []

    token_ids = sp.encode(text, out_type=int)
    if not token_ids:
        return []

    input_seq = torch.tensor(token_ids, dtype=torch.long).unsqueeze(0).to(device)
    logits = model(input_seq)
    # NOTE(review): the original code treated the output as (1, vocab)
    # next-token logits. If the model instead returns per-position logits
    # (1, seq, vocab), only the last position is scored. Confirm against
    # how the model was trained.
    if logits.dim() == 3:
        logits = logits[:, -1, :]
    scores = logits.squeeze(0)

    # Softmax is monotonic, so ranking raw logits gives the same order as
    # ranking probabilities. Sorting the whole vocabulary (instead of topk
    # with k=max_predictions) leaves room to skip filtered pieces and never
    # asks topk for more items than exist.
    ranked_ids = torch.argsort(scores, descending=True).tolist()

    vocab_size = sp.get_piece_size()
    suggestions = []
    for idx in ranked_ids:
        # Ids outside the tokenizer vocabulary cannot be decoded. Control
        # pieces (e.g. BOS/EOS) and <unk> are not words, so skip them too.
        if idx >= vocab_size or sp.is_control(idx) or sp.is_unknown(idx):
            continue
        word = sp.id_to_piece(idx).lstrip("▁")
        # A bare "▁" strips to "". "▁the" and "the" collapse to one entry.
        if not word or word in suggestions:
            continue
        suggestions.append(word)
        if len(suggestions) == max_predictions:
            break
    return suggestions
|
|
|
|
|
def submit_and_predict(text, num_slots=3):
    """Turn next-word predictions for ``text`` into suggestion-button updates.

    Args:
        text: Current contents of the input textbox.
        num_slots: Number of suggestion buttons wired as outputs of this
            handler. Exactly this many updates are returned.

    Returns:
        A list of exactly ``num_slots`` ``gr.update`` objects. Each
        prediction becomes a visible button whose value is the word. Any
        remaining slots are blanked and hidden.
    """
    # Truncate defensively: Gradio requires one update per output component.
    suggestions = predict_next_words(text, max_predictions=num_slots)[:num_slots]
    suggestions += [""] * (num_slots - len(suggestions))
    return [gr.update(value=word, visible=bool(word)) for word in suggestions]
|
|
|
def append_suggestion(text, suggestion):
    """Append a chosen suggestion word to the text, followed by a space.

    Trailing whitespace in ``text`` is collapsed to a single separating
    space. If ``text`` is empty (or ``None``), no leading space is added.

    Args:
        text: Current textbox contents; may be empty or ``None``.
        suggestion: The word to append. A falsy value leaves ``text``
            unchanged.

    Returns:
        The updated text, or ``text`` itself when ``suggestion`` is falsy.
    """
    if not suggestion:
        return text
    prefix = (text or "").rstrip()
    return f"{prefix} {suggestion} " if prefix else f"{suggestion} "
|
|
|
|
|
with gr.Blocks(title="Next Word Predictor") as app:
    gr.Markdown("# Next Word Prediction")
    gr.Markdown("Enter text and click 'Submit' to get word suggestions.")

    text_input = gr.Textbox(label="Your Text", placeholder="Type here...", lines=3)
    submit_btn = gr.Button("Submit", variant="primary")

    # Suggestion slots start hidden. submit_and_predict reveals the ones it
    # fills; its default slot count (3) must match this range.
    with gr.Row():
        suggestion_buttons = [gr.Button(visible=False) for _ in range(3)]

    submit_btn.click(
        fn=submit_and_predict,
        inputs=text_input,
        outputs=suggestion_buttons,
    )

    for btn in suggestion_buttons:
        # The button is passed as an input so append_suggestion receives its
        # value, which submit_and_predict set to the suggested word.
        btn.click(
            fn=append_suggestion,
            inputs=[text_input, btn],
            outputs=text_input,
        ).then(
            # Refresh the suggestions for the extended text; otherwise the
            # buttons keep showing predictions for the previous text.
            fn=submit_and_predict,
            inputs=text_input,
            outputs=suggestion_buttons,
        )

if __name__ == "__main__":
    app.launch()
|
|