File size: 4,605 Bytes
61e1114
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
import torch
from model import build_transformer
from train import greedy_decode, get_model, get_or_build_tokenizer

from config import get_config, get_weights_file_path
from tokenizers import Tokenizer
from pathlib import Path



# Settings shared by the whole app; infer() reads the tokenizer path template,
# the language codes and the sequence length from this mapping.
config = get_config()

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
def process_text(config, src_text, tokenizer_src, tokenizer_tgt, src_lang, tgt_lang, seq_len):
    """Tokenize ``src_text`` and pad it into a fixed-length encoder input.

    The sentence is encoded with ``tokenizer_src``, wrapped in ``[SOS]`` /
    ``[EOS]`` and right-padded with ``[PAD]`` up to ``seq_len`` tokens. The
    ids of the three special tokens are looked up in ``tokenizer_tgt``.

    Args:
        config: Unused; kept so existing callers keep working.
        src_text: The sentence to encode.
        tokenizer_src: Object whose ``encode(text).ids`` yields the token ids
            of the source sentence.
        tokenizer_tgt: Object whose ``token_to_id(name)`` yields the ids of
            the ``[SOS]``, ``[EOS]`` and ``[PAD]`` tokens.
        src_lang: Unused; kept so existing callers keep working.
        tgt_lang: Unused; kept so existing callers keep working.
        seq_len: Total length of the returned encoder input.

    Returns:
        A dict with:
            ``"encoder_input"``: int64 tensor of shape ``(seq_len,)``.
            ``"encoder_mask"``: int32 tensor of shape ``(1, 1, seq_len)``
            holding 1 where the input differs from the pad id and 0 elsewhere.

    Raises:
        ValueError: If the sentence plus its two special tokens does not fit
            into ``seq_len``.
    """
    sos_id = tokenizer_tgt.token_to_id("[SOS]")
    eos_id = tokenizer_tgt.token_to_id("[EOS]")
    pad_id = tokenizer_tgt.token_to_id("[PAD]")

    # Transform the text into token ids.
    token_ids = tokenizer_src.encode(src_text).ids

    # Two slots are reserved for [SOS] and [EOS]; a negative remainder means
    # the sentence cannot fit.
    num_padding = seq_len - len(token_ids) - 2
    if num_padding < 0:
        raise ValueError("Sentence is too long")

    # [SOS] + sentence + [EOS] + padding: exactly seq_len ids by construction.
    encoder_input = torch.tensor(
        [sos_id, *token_ids, eos_id] + [pad_id] * num_padding,
        dtype=torch.int64,
    )

    return {
        "encoder_input": encoder_input,
        # (1, 1, seq_len): padding positions are marked with 0.
        "encoder_mask": (encoder_input != pad_id).unsqueeze(0).unsqueeze(0).int(),
    }
            
def causal_mask(size):
    """Return a boolean ``(1, size, size)`` mask that hides future positions.

    Entry ``[0, i, j]`` is ``True`` when ``j <= i`` (on or below the main
    diagonal) and ``False`` when ``j > i``.
    """
    # Ones strictly above the diagonal flag the positions to hide.
    strictly_upper = torch.ones((1, size, size)).triu(diagonal=1).to(torch.int)
    # Invert: everything that is not strictly above the diagonal is allowed.
    return strictly_upper.eq(0)

def infer(text, config, weights_path='tmodel_36.pt'):
    """Translate ``text`` with the trained transformer and return the output string.

    Loads the source and target tokenizers from the paths described by
    ``config``, builds the model with ``get_model``, restores its weights from
    ``weights_path`` and decodes the translation with ``greedy_decode``.

    Args:
        text: Source-language sentence to translate.
        config: Mapping providing ``'tokenizer_file'`` (a format string taking
            a language code), ``'lang_src'``, ``'lang_tgt'`` and ``'seq_len'``.
        weights_path: Checkpoint file containing a ``'model_state_dict'``
            entry. Defaults to the previously hard-coded ``'tmodel_36.pt'``.

    Returns:
        The decoded target-language text.

    Raises:
        ValueError: Propagated from ``process_text`` when ``text`` does not
            fit into ``config['seq_len']`` tokens.
    """
    tokenizer_src = Tokenizer.from_file(str(Path(config['tokenizer_file'].format(config['lang_src']))))
    tokenizer_tgt = Tokenizer.from_file(str(Path(config['tokenizer_file'].format(config['lang_tgt']))))
    model = get_model(config, tokenizer_src.get_vocab_size(), tokenizer_tgt.get_vocab_size())

    # NOTE: torch.load unpickles the file; only point this at trusted checkpoints.
    state = torch.load(weights_path, map_location=device)
    model.load_state_dict(state['model_state_dict'])

    # greedy_decode is handed the module-level `device`, so the model and its
    # inputs must live there too. Previously they stayed on the CPU even when
    # `device` was CUDA (presumably greedy_decode allocates its own tensors on
    # `device`, which would then clash with CPU weights and inputs).
    model.to(device)
    model.eval()

    with torch.no_grad():
        processed_text = process_text(
            config,
            text,
            tokenizer_src,
            tokenizer_tgt,
            config['lang_src'],
            config['lang_tgt'],
            config['seq_len'],
        )
        encoder_input = processed_text['encoder_input'].to(device)
        encoder_mask = processed_text['encoder_mask'].to(device)

        model_out = greedy_decode(model, encoder_input, encoder_mask, tokenizer_src, tokenizer_tgt, config['seq_len'], device)

    # Bring the ids back to the CPU before handing them to the tokenizer.
    return tokenizer_tgt.decode(model_out.detach().cpu().numpy())
    

import streamlit as st

# Page header.
st.title("English to Hausa Translation")

# Only translate once the user has entered some text.
user_input = st.text_input("Enter your text:")
if user_input:
    st.write("Inference Result:", infer(user_input, config))