# StoryLlama / inference.py
from config import ModelArgs
from model import Llama
import torch
import torch.nn.functional as F
from tokenizer import Tokenizer
import argparse

tokenizer = Tokenizer()
tokenizer = tokenizer.ready_tokenizer()

def remove_prefix(state_dict, prefix):
    # Strip `prefix` from every key that carries it; keys without the
    # prefix pass through unchanged.
    new_state_dict = {}
    for key, value in state_dict.items():
        if key.startswith(prefix):
            new_key = key[len(prefix):]  # Remove the prefix
            new_state_dict[new_key] = value
        else:
            new_state_dict[key] = value
    return new_state_dict
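
# Minimal sketch of what remove_prefix() handles (hypothetical keys, for
# illustration only): a checkpoint saved from a torch.compile()'d model
# prefixes every parameter name with '_orig_mod.', which the plain
# (uncompiled) model's load_state_dict() would reject.
#   sd = {'_orig_mod.layers.0.weight': torch.zeros(1)}
#   remove_prefix(sd, '_orig_mod.')  # -> {'layers.0.weight': tensor([0.])}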

def topk_sampling(model, prompt, device, max_length=50, top_k=50, temperature=1.0):
    input_ids = tokenizer.encode(prompt, return_tensors='pt').to(device)
    ModelArgs.inference = True
    for _ in range(max_length):
        with torch.no_grad():
            outputs = model(input_ids)
            logits = outputs[:, -1, :]  # logits for the last position only
            # Apply temperature scaling, then softmax over the vocabulary
            probs = F.softmax(logits / temperature, dim=-1)
            # Top-k filtering: keep only the k most probable tokens
            top_k_probs, top_k_indices = torch.topk(probs, top_k, dim=-1)
            # Sample a position within the top-k, then map it back to a vocab id
            next_token = torch.multinomial(top_k_probs, num_samples=1)
            xcol = torch.gather(top_k_indices, -1, next_token)
            input_ids = torch.cat([input_ids, xcol], dim=1)  # dim=1 is the sequence dimension
    return tokenizer.decode(input_ids[0], skip_special_tokens=True)
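
# Usage sketch (assumes a model already loaded on ModelArgs.device, as in
# main() below; the returned text includes the prompt):
#   text = topk_sampling(model, "Once upon a time", device=ModelArgs.device,
#                        max_length=64, top_k=50, temperature=0.8)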

def main():
    torch.set_float32_matmul_precision('high')

    parser = argparse.ArgumentParser()
    parser.add_argument("--prompt", type=str, default="Once upon a time")
    parser.add_argument("--max_length", type=int, default=128)
    parser.add_argument("--temperature", type=float, default=1.0)
    parser.add_argument("--top_k", type=int, default=50)
    args = parser.parse_args()

    model = Llama(
        device=ModelArgs.device,
        embeddings_dims=ModelArgs.embeddings_dims,
        no_of_decoder_layers=ModelArgs.no_of_decoder_layers,
        block_size=ModelArgs.block_size,
        vocab_size=ModelArgs.vocab_size,
        dropout=ModelArgs.dropout,
    )
    model = model.to(ModelArgs.device)

    # The checkpoint was saved from a torch.compile()'d model, so its keys
    # carry an '_orig_mod.' prefix that must be stripped before loading.
    dict_model = torch.load('weights/pretrained/snapshot_4650.pt', map_location=ModelArgs.device)
    dict_model['MODEL_STATE'] = remove_prefix(dict_model['MODEL_STATE'], '_orig_mod.')
    model.load_state_dict(dict_model['MODEL_STATE'])
    model.eval()
    print("Model ready")

    with torch.no_grad():
        generated_text = topk_sampling(
            model,
            args.prompt,
            device=ModelArgs.device,
            max_length=args.max_length,
            top_k=args.top_k,
            temperature=args.temperature,
        )
    # generated_text already contains the prompt (it is decoded from the
    # full sequence), so print it once rather than prepending args.prompt.
    print("Generated:", generated_text)

if __name__ == '__main__':
    main()
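
# Example invocation (assumes weights/pretrained/snapshot_4650.pt is present):
#   python inference.py --prompt "Once upon a time" --max_length 128 --top_k 50 --temperature 0.8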