aarohanverma commited on
Commit
a47e87d
·
verified ·
1 Parent(s): 767b8cd

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +75 -0
app.py ADDED
@@ -0,0 +1,75 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ import sentencepiece as spm
4
+
5
+ # ---------------------- Model & SentencePiece Loading ----------------------
6
+ @torch.no_grad()
7
+ def load_model():
8
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
9
+ model = torch.jit.load("best_model_scripted.pt", map_location=device).eval()
10
+ return model, device
11
+
12
+ def load_sp_model():
13
+ sp = spm.SentencePieceProcessor()
14
+ sp.load("spm.model")
15
+ return sp
16
+
17
+ # Cache models globally
18
+ model, device = load_model()
19
+ sp = load_sp_model()
20
+
21
+ # ---------------------- Prediction Function ----------------------
22
+ @torch.no_grad()
23
+ def predict_next_words(text, max_predictions=3):
24
+ if not text.strip():
25
+ return []
26
+
27
+ token_ids = sp.encode(text.strip(), out_type=int)
28
+ if len(token_ids) == 0:
29
+ return []
30
+
31
+ input_seq = torch.tensor(token_ids, dtype=torch.long).unsqueeze(0).to(device)
32
+ logits = model(input_seq)
33
+ probabilities = torch.softmax(logits, dim=-1)
34
+ top_indices = torch.topk(probabilities, k=max_predictions, dim=-1).indices.squeeze(0).tolist()
35
+
36
+ predicted_pieces = [sp.id_to_piece(idx).lstrip("▁") for idx in top_indices]
37
+ return predicted_pieces
38
+
39
+ # ---------------------- Gradio App Functions ----------------------
40
+ def submit_and_predict(text):
41
+ suggestions = predict_next_words(text)
42
+ suggestions += [""] * (3 - len(suggestions)) # Ensure 3 buttons always
43
+ return suggestions
44
+
45
+ def append_suggestion(text, suggestion):
46
+ if suggestion:
47
+ text += suggestion + " "
48
+ return text
49
+
50
+ # ---------------------- Gradio Interface ----------------------
51
+ with gr.Blocks(title="Next Word Predictor") as app:
52
+ gr.Markdown("# Next Word Prediction")
53
+ gr.Markdown("Enter text and click 'Submit' to get word suggestions.")
54
+
55
+ text_input = gr.Textbox(label="Your Text", placeholder="Type here...", lines=3)
56
+ submit_btn = gr.Button("Submit", variant="primary")
57
+
58
+ with gr.Row():
59
+ suggestion_buttons = [gr.Button(visible=False) for _ in range(3)]
60
+
61
+ submit_btn.click(
62
+ fn=submit_and_predict,
63
+ inputs=text_input,
64
+ outputs=suggestion_buttons,
65
+ )
66
+
67
+ for btn in suggestion_buttons:
68
+ btn.click(
69
+ fn=append_suggestion,
70
+ inputs=[text_input, btn],
71
+ outputs=text_input
72
+ )
73
+
74
+ if __name__ == "__main__":
75
+ app.launch()