henok3878 commited on
Commit
713ed4b
·
1 Parent(s): a1f27d5

add app.py with basic gradio interface and translation logic

Browse files
Files changed (1) hide show
  1. app.py +82 -0
app.py ADDED
@@ -0,0 +1,82 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import gradio as gr
3
+ from tokenizers import Tokenizer
4
+
5
+ from transformer.config import load_config
6
+ from transformer.components.decoding import beam_search
7
+ from transformer.transformer import Transformer
8
+
9
+ DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
10
+
11
+ CONFIG_PATH = "configs/config.yaml"
12
+ MODEL_PATH = "model_checkpoint.pt"
13
+ TOKENIZER_PATH = "tokenizers/tokenizer-joint-de-en-vocab37000.json"
14
+ MAX_LEN = 128
15
+
16
+ config = load_config(CONFIG_PATH)
17
+ tokenizer = Tokenizer.from_file(TOKENIZER_PATH)
18
+ padding_idx = tokenizer.token_to_id("[PAD]")
19
+
20
+ model = Transformer.load_from_checkpoint(checkpoint_path=MODEL_PATH, config=config, device=DEVICE)
21
+
22
+ def translate(text: str, beam_size: int = 4) -> str:
23
+ src_ids = torch.tensor([tokenizer.encode(text).ids], device=DEVICE)
24
+ src_mask = (src_ids != padding_idx).unsqueeze(1).unsqueeze(2)
25
+ with torch.no_grad():
26
+ result_ids = beam_search(
27
+ model,
28
+ src_ids,
29
+ src_mask,
30
+ tokenizer,
31
+ max_len=MAX_LEN,
32
+ beam_size=beam_size,
33
+ )[0]
34
+ return tokenizer.decode(result_ids, skip_special_tokens=True)
35
+
36
+ with gr.Blocks(title="Transformer From Scratch Translation Demo") as demo:
37
+ gr.Markdown(
38
+ "# Transformer From Scratch Translation Demo\n"
39
+ "Translate English to German using a custom Transformer model trained from scratch."
40
+ )
41
+ with gr.Row(equal_height=True):
42
+ with gr.Column():
43
+ input_text = gr.Textbox(
44
+ label="English Text",
45
+ placeholder="Enter text to translate...",
46
+ lines=3
47
+ )
48
+ beam_size = gr.Slider(
49
+ minimum=1, maximum=8, step=1, value=4, label="Beam Size"
50
+ )
51
+ with gr.Column():
52
+ output_text = gr.Textbox(
53
+ label="German Translation",
54
+ lines=3,
55
+ interactive=False,
56
+ show_copy_button=True,
57
+ show_label=True
58
+ )
59
+ with gr.Row():
60
+ with gr.Column(scale=1):
61
+ pass
62
+ with gr.Column(scale=2, min_width=300, elem_id="centered-controls"):
63
+ translate_btn = gr.Button("Translate")
64
+ gr.Examples(
65
+ examples=[
66
+ ["Hello, how are you?"],
67
+ ["The weather is nice today."],
68
+ ["I love machine learning."],
69
+ ],
70
+ inputs=[input_text]
71
+ )
72
+ with gr.Column(scale=1):
73
+ pass
74
+
75
+ translate_btn.click(
76
+ translate,
77
+ inputs=[input_text, beam_size],
78
+ outputs=[output_text]
79
+ )
80
+
81
+ if __name__ == "__main__":
82
+ demo.launch()