# Inference script: loads the trained character-level encoder/decoder and
# vocabulary objects, and exposes generate() for greedy decoding.
import torch
import pickle
import sys
import os
from inference.language import Language
from inference.utility import Encoder, Decoder, encoderBlock, decoderBlock, MultiHeadAttention, Head, FeedForward
import warnings
from typing import List
warnings.filterwarnings("ignore", category=FutureWarning)
# Run on GPU when available, otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Load the pickled vocabulary objects produced at training time.
# NOTE(review): pickle.load / torch.load execute arbitrary code embedded in
# the file — these artifacts must come from a trusted source.
with open(os.path.join(os.path.dirname(__file__), 'input_lang.pkl'), "rb") as file:
    input_lang = pickle.load(file)
with open(os.path.join(os.path.dirname(__file__), 'output_lang.pkl'), "rb") as file:
    output_lang = pickle.load(file)
# Full serialized models (architecture + weights), remapped onto `device`.
# The class definitions imported from inference.utility above must be in
# scope for unpickling to succeed.
encoder = torch.load(os.path.join(os.path.dirname(__file__), 'encoder (1).pth'), map_location=device)
decoder = torch.load(os.path.join(os.path.dirname(__file__), 'decoder (1).pth'), map_location=device)
input_vocab_size = input_lang.vocab_size
output_vocab_size = output_lang.vocab_size
def encode(s):
    """Translate string *s* into a list of token ids.

    Characters missing from the input vocabulary are mapped to the id of
    the '$' placeholder character.
    """
    table = input_lang.char2index
    unknown = table['$']
    return [table.get(c, unknown) for c in s]
def generate(input: List[str]) -> List[str]:
    """Greedily decode a batch of input strings with the seq2seq model.

    Each string is clipped or right-padded with '#' to exactly 33
    characters, encoded to token ids, run through the encoder, then
    decoded one token at a time for up to 30 steps.

    Args:
        input: batch of raw input strings.

    Returns:
        One decoded string per input, truncated at the first '#' pad
        character.
    """
    # Normalize every string to length 33 WITHOUT mutating the caller's list
    # (the original clobbered `input` in place).
    padded = [s[:33] if len(s) > 33 else s.ljust(33, '#') for s in input]
    src = torch.tensor([encode(s) for s in padded], device=device, dtype=torch.long)
    batch_size = src.shape[0]
    # Pure inference: disable autograd bookkeeping.
    with torch.no_grad():
        encoder_output = encoder(src)
        # Seed every sequence with token id 2 — presumably the
        # start-of-sequence marker; confirm against output_lang.
        idx = torch.full((batch_size, 1), 2, dtype=torch.long, device=device)  # (B, 1)
        for _ in range(30):
            logits, _ = decoder(idx, encoder_output)  # (B, T, vocab_size); loss unused
            logits = logits[:, -1, :]  # keep only the last time step: (B, vocab_size)
            # Greedy decoding: pick the highest-scoring token (no sampling).
            idx_next = torch.argmax(logits, dim=-1, keepdim=True)  # (B, 1)
            idx = torch.cat((idx, idx_next), dim=1)  # (B, T+1)
    # Drop the leading start token, decode ids back to characters, and cut
    # each result at the first pad marker '#'.
    return [output_lang.decode(seq.tolist()[1:]).split('#', 1)[0] for seq in idx]