File size: 2,089 Bytes
9e582c5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
import torch
import pickle
import sys
import os
from inference.language import Language
from inference.utility import Encoder, Decoder, encoderBlock, decoderBlock, MultiHeadAttention, Head, FeedForward
import warnings
from typing import List
warnings.filterwarnings("ignore", category=FutureWarning)


# Run on the GPU when one is visible to torch, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# SECURITY: pickle.load and torch.load(weights_only=False) execute arbitrary
# code while unpickling. Only ever point these at artifacts shipped with this
# package; never at user-supplied files.
# NOTE(review): the classes imported from inference.language / inference.utility
# are presumably needed so the unpickler can resolve them -- confirm before
# removing those imports.
with open(os.path.join(os.path.dirname(__file__), 'input_lang.pkl'), "rb") as file:
    input_lang = pickle.load(file)

with open(os.path.join(os.path.dirname(__file__), 'output_lang.pkl'), "rb") as file:
    output_lang = pickle.load(file)

# The .pth files hold whole pickled objects rather than state dicts, so
# weights_only must be disabled explicitly: PyTorch >= 2.6 defaults it to True
# and would refuse to load them.
encoder = torch.load(
    os.path.join(os.path.dirname(__file__), 'encoder (1).pth'),
    map_location=device,
    weights_only=False,
)
decoder = torch.load(
    os.path.join(os.path.dirname(__file__), 'decoder (1).pth'),
    map_location=device,
    weights_only=False,
)

# A pickled module restores whatever train/eval flag it was saved with; force
# eval mode so layers such as dropout are disabled during inference.
encoder.eval()
decoder.eval()

input_vocab_size = input_lang.vocab_size
output_vocab_size = output_lang.vocab_size

def encode(s):
    """Convert the characters of *s* into indices from ``input_lang.char2index``.

    Characters that are missing from the table are mapped to the index stored
    under ``'$'``, which serves as the fallback entry.
    """
    table = input_lang.char2index
    indices = []
    for ch in s:
        # The fallback is looked up per character, so an empty *s* never
        # touches the '$' entry.
        indices.append(table.get(ch, table['$']))
    return indices

def generate(input: List[str]) -> List[str]:
    """Greedily decode one output string for every input string.

    Each input is truncated or right-padded with ``'#'`` to a fixed length,
    encoded with ``encode``, run through ``encoder`` once, and then ``decoder``
    is called repeatedly, appending the arg-max token at every step. The seed
    token is dropped from the result and each decoded string is cut at the
    first ``'#'``.

    Args:
        input: Strings to translate. The sequence is not modified.

    Returns:
        A list with one decoded string per input, in the same order. An empty
        input yields an empty list.
    """
    max_input_len = 33   # inputs are truncated / '#'-padded to this many characters
    max_new_tokens = 30  # number of greedy decoding steps
    seed_token_id = 2    # first decoder token; presumably the start-of-sequence index -- TODO confirm

    # Build a padded copy rather than overwriting the caller's list in place.
    padded = [text[:max_input_len].ljust(max_input_len, '#') for text in input]
    if not padded:
        # An empty batch would give a 1-D tensor, which the models cannot consume.
        return []

    batch = torch.tensor([encode(text) for text in padded], device=device, dtype=torch.long)
    batch_size = batch.shape[0]

    # Pure inference: disable autograd so no graph accumulates across the steps.
    with torch.no_grad():
        encoder_output = encoder(batch)
        # idx holds the running decoded sequence, one row per input; starts as (B, 1).
        idx = torch.full((batch_size, 1), seed_token_id, dtype=torch.long, device=device)

        for _ in range(max_new_tokens):
            # The decoder returns (logits, loss); the loss is not used here.
            logits, _loss = decoder(idx, encoder_output)
            # Keep only the last time step and take the most likely token (greedy).
            next_idx = torch.argmax(logits[:, -1, :], dim=-1, keepdim=True)  # (B, 1)
            idx = torch.cat((idx, next_idx), dim=1)

    # Skip the seed token, then cut each decoded string at the first '#'.
    return [output_lang.decode(row[1:]).split('#', 1)[0] for row in idx.tolist()]