Spaces:
Paused
Paused
from config import ModelArgs | |
from model import Llama | |
import torch | |
import torch.nn.functional as F | |
from tokenizer import Tokenizer | |
import argparse | |
# Module-wide tokenizer; topk_sampling uses it to encode prompts and decode output.
tokenizer = Tokenizer().ready_tokenizer()
def remove_prefix(state_dict, prefix):
    """Return a new dict whose keys have ``prefix`` stripped where present.

    Keys that do not start with ``prefix`` are copied unchanged. Insertion
    order is preserved and values are not copied. If stripping makes two keys
    collide, the later entry wins.
    """
    cut = len(prefix)
    return {
        (name[cut:] if name.startswith(prefix) else name): tensor
        for name, tensor in state_dict.items()
    }
def topk_sampling(model, prompt, device, max_length=50, top_k=50, temperature=1.0):
    """Generate a continuation of ``prompt`` using temperature-scaled top-k sampling.

    On each step the model is run on the whole token sequence so far. The
    logits of the last position are scaled by ``temperature``, restricted to the
    ``top_k`` most likely tokens, sampled from, and the chosen token is appended.

    Args:
        model: Callable language model. Its output is indexed as
            ``outputs[:, -1, :]``, so it presumably returns logits shaped
            (batch, seq, vocab) -- confirm against ``model.Llama``.
        prompt: Text to condition on.
        device: Device the encoded prompt is moved to.
        max_length: Number of new tokens to generate.
        top_k: How many of the most likely tokens to sample from. Values
            ``<= 0`` or larger than the vocabulary mean "use the whole vocabulary".
        temperature: Logits are divided by this before softmax. Values > 1
            flatten the distribution and values < 1 sharpen it. ``<= 0``
            selects the most likely token greedily.

    Returns:
        Decoded text of the prompt followed by the generated tokens, with
        special tokens skipped.
    """
    input_ids = tokenizer.encode(prompt, return_tensors='pt').to(device)
    # Global flag consumed by the project's model code; kept as in the original.
    ModelArgs.inference = True
    for _ in range(max_length):
        with torch.no_grad():
            # NOTE(review): the sequence is never cropped to ModelArgs.block_size;
            # confirm the model accepts inputs longer than its block size.
            outputs = model(input_ids)
            logits = outputs[:, -1, :]
            if temperature <= 0:
                # Greedy decoding: avoids dividing by zero.
                next_ids = torch.argmax(logits, dim=-1, keepdim=True)
            else:
                # Temperature must scale logits (not probabilities) to reshape the distribution.
                probs = F.softmax(logits / temperature, dim=-1)
                vocab = probs.size(-1)
                k = vocab if top_k <= 0 or top_k > vocab else top_k
                top_k_probs, top_k_indices = torch.topk(probs, k, dim=-1)
                # multinomial accepts unnormalized weights, so no renormalization is needed.
                choice = torch.multinomial(top_k_probs, num_samples=1)
                # Map the position within the top-k slice back to a vocabulary id.
                next_ids = torch.gather(top_k_indices, -1, choice)
        # dim=1 is the sequence dimension.
        input_ids = torch.cat([input_ids, next_ids], dim=1)
    return tokenizer.decode(input_ids[0], skip_special_tokens=True)
def _parse_args():
    """Parse the command-line options for text generation."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--prompt", type=str, default="Once upon a time")
    parser.add_argument("--max_length", type=int, default=128)
    parser.add_argument("--temperature", type=float, default=1.0)
    parser.add_argument("--top_k", type=int, default=50)
    parser.add_argument("--checkpoint", type=str, default='weights/pretrained/snapshot_4650.pt')
    return parser.parse_args()


def _load_model(checkpoint_path):
    """Build a Llama from ModelArgs, load its weights from ``checkpoint_path`` and set eval mode."""
    model = Llama(
        device=ModelArgs.device,
        embeddings_dims=ModelArgs.embeddings_dims,
        no_of_decoder_layers=ModelArgs.no_of_decoder_layers,
        block_size=ModelArgs.block_size,
        vocab_size=ModelArgs.vocab_size,
        dropout=ModelArgs.dropout,
    )
    model = model.to(ModelArgs.device)
    # map_location lets a checkpoint saved on a GPU load on a CPU-only machine.
    dict_model = torch.load(checkpoint_path, map_location=ModelArgs.device)
    # Keys carry an "_orig_mod." prefix, presumably because the checkpoint was
    # saved from a torch.compile-wrapped model; strip it before loading.
    state = remove_prefix(dict_model['MODEL_STATE'], '_orig_mod.')
    model.load_state_dict(state)
    model.eval()
    return model


def main():
    """Command-line entry point: load the pretrained model and print a sampled continuation."""
    torch.set_float32_matmul_precision('high')
    args = _parse_args()
    model = _load_model(args.checkpoint)
    print("Model ready")
    with torch.no_grad():
        generated_text = topk_sampling(
            model,
            args.prompt,
            max_length=args.max_length,
            top_k=args.top_k,  # was hard-coded to 50, ignoring --top_k
            temperature=args.temperature,
            device=ModelArgs.device,
        )
    # topk_sampling decodes the full sequence, so the text already starts with
    # the prompt; printing args.prompt + generated_text would repeat it.
    print("Generated: ", generated_text)
# Run the generation CLI only when executed as a script, not on import.
if __name__ == "__main__":
    main()