"""Inference module: loads a trained encoder/decoder model plus its input and
output vocabularies, and exposes `generate` for batched greedy decoding."""

import torch
import pickle
import sys
import os
from inference.language import Language
from inference.utility import Encoder, Decoder, encoderBlock, decoderBlock, MultiHeadAttention, Head, FeedForward
import warnings
from typing import List

warnings.filterwarnings("ignore", category=FutureWarning)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Model constants: fixed padded input length the encoder expects, the maximum
# number of greedy decoding steps, the pad/unknown characters, and the
# start-of-sequence token index used to seed the decoder.
MAX_INPUT_LEN = 33
MAX_DECODE_STEPS = 30
PAD_CHAR = '#'
UNK_CHAR = '$'
SOS_TOKEN = 2

_here = os.path.dirname(__file__)

# NOTE(security): pickle.load and torch.load execute arbitrary code when fed
# untrusted files — these artifacts must come from a trusted source.
with open(os.path.join(_here, 'input_lang.pkl'), "rb") as file:
    input_lang = pickle.load(file)
with open(os.path.join(_here, 'output_lang.pkl'), "rb") as file:
    output_lang = pickle.load(file)

encoder = torch.load(os.path.join(_here, 'encoder (1).pth'), map_location=device)
decoder = torch.load(os.path.join(_here, 'decoder (1).pth'), map_location=device)

input_vocab_size = input_lang.vocab_size
output_vocab_size = output_lang.vocab_size


def encode(s):
    """Map a string to a list of input-vocabulary indices.

    Characters missing from the vocabulary fall back to the index of the
    unknown character '$'.
    """
    unk = input_lang.char2index[UNK_CHAR]
    return [input_lang.char2index.get(ch, unk) for ch in s]


def generate(input: List[str]) -> List[str]:
    """Greedily decode a batch of strings through the encoder/decoder.

    Each string is truncated or right-padded with '#' to length 33, encoded,
    and decoded for up to 30 steps starting from the SOS token. Returns the
    decoded strings with everything from the first pad character dropped.

    Unlike the previous version, the caller's list is NOT mutated: padding
    is done on a fresh copy.
    """
    # s[:N].ljust(N, pad) truncates long strings and pads short ones in one go.
    padded = [s[:MAX_INPUT_LEN].ljust(MAX_INPUT_LEN, PAD_CHAR) for s in input]
    batch = torch.tensor([encode(s) for s in padded], device=device, dtype=torch.long)
    B = batch.shape[0]

    # Inference only — disable autograd bookkeeping for speed and memory.
    with torch.no_grad():
        encoder_output = encoder(batch)
        # idx is the running decoded sequence for each batch item, seeded
        # with the start-of-sequence token.  Shape: (B, 1) growing to (B, T).
        idx = torch.full((B, 1), SOS_TOKEN, dtype=torch.long, device=device)
        for _ in range(MAX_DECODE_STEPS):
            logits, _ = decoder(idx, encoder_output)  # (B, T, vocab_size); loss unused
            logits = logits[:, -1, :]                 # last time step only: (B, vocab)
            idx_next = torch.argmax(logits, dim=-1, keepdim=True)  # greedy pick, (B, 1)
            idx = torch.cat((idx, idx_next), dim=1)   # append: (B, T+1)

    # Drop the SOS token, decode to text, and cut at the first pad character.
    return [output_lang.decode(seq.tolist()[1:]).split(PAD_CHAR, 1)[0] for seq in idx]