henok3878 committed
Commit · 713ed4b
1 Parent(s): a1f27d5
add app.py with basic gradio interface and translation logic
app.py
ADDED
@@ -0,0 +1,82 @@
import torch
import gradio as gr
from tokenizers import Tokenizer

from transformer.config import load_config
from transformer.components.decoding import beam_search
from transformer.transformer import Transformer

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
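# Model artifacts loaded at startup: training config, checkpoint, and the joint
# German-English tokenizer (37k vocabulary).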
CONFIG_PATH = "configs/config.yaml"
MODEL_PATH = "model_checkpoint.pt"
TOKENIZER_PATH = "tokenizers/tokenizer-joint-de-en-vocab37000.json"
MAX_LEN = 128

config = load_config(CONFIG_PATH)
tokenizer = Tokenizer.from_file(TOKENIZER_PATH)
padding_idx = tokenizer.token_to_id("[PAD]")

model = Transformer.load_from_checkpoint(checkpoint_path=MODEL_PATH, config=config, device=DEVICE)
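# Assumes Transformer subclasses torch.nn.Module: switch to eval mode so dropout
# is disabled during decoding.
model.eval()

# Translate a single sentence: encode it, build the source padding mask, and
# decode with beam search.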
def translate(text: str, beam_size: int = 4) -> str:
    src_ids = torch.tensor([tokenizer.encode(text).ids], device=DEVICE)
    src_mask = (src_ids != padding_idx).unsqueeze(1).unsqueeze(2)
    with torch.no_grad():
        result_ids = beam_search(
            model,
            src_ids,
            src_mask,
            tokenizer,
            max_len=MAX_LEN,
            beam_size=beam_size,
        )[0]
    return tokenizer.decode(result_ids, skip_special_tokens=True)
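# Gradio UI: English input and a beam-size slider on the left, a read-only German
# output on the right, and a centered Translate button with example prompts below.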
with gr.Blocks(title="Transformer From Scratch Translation Demo") as demo:
    gr.Markdown(
        "# Transformer From Scratch Translation Demo\n"
        "Translate English to German using a custom Transformer model trained from scratch."
    )
    with gr.Row(equal_height=True):
        with gr.Column():
            input_text = gr.Textbox(
                label="English Text",
                placeholder="Enter text to translate...",
                lines=3
            )
            beam_size = gr.Slider(
                minimum=1, maximum=8, step=1, value=4, label="Beam Size"
            )
        with gr.Column():
            output_text = gr.Textbox(
                label="German Translation",
                lines=3,
                interactive=False,
                show_copy_button=True,
                show_label=True
            )
    with gr.Row():
        with gr.Column(scale=1):
            pass
        with gr.Column(scale=2, min_width=300, elem_id="centered-controls"):
            translate_btn = gr.Button("Translate")
            gr.Examples(
                examples=[
                    ["Hello, how are you?"],
                    ["The weather is nice today."],
                    ["I love machine learning."],
                ],
                inputs=[input_text]
            )
        with gr.Column(scale=1):
            pass
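    # Run translate() on the current input with the chosen beam size and show the result.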
    translate_btn.click(
        translate,
        inputs=[input_text, beam_size],
        outputs=[output_text]
    )

if __name__ == "__main__":
    demo.launch()
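As a quick sanity check, the translation logic can also be exercised without the UI. A minimal sketch, assuming the config, checkpoint, and tokenizer files referenced above are present locally and the transformer package is importable:

# Illustrative smoke test: importing app runs the module-level setup, i.e. it
# loads the config, tokenizer, and model checkpoint before translate() is called.
from app import translate

print(translate("Hello, how are you?"))                       # default beam_size=4
print(translate("The weather is nice today.", beam_size=1))   # beam_size=1 is effectively greedy decoding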