Spaces:
Paused
Paused
from config import ModelArgs | |
from model import Llama | |
import torch | |
import torch.nn.functional as F | |
from tokenizer import Tokenizer | |
import argparse | |
# Module-level tokenizer used by topk_sampling; constructed once when the module is imported.
tokenizer = Tokenizer().ready_tokenizer()
def remove_hashtag_lines(text):
    """Return ``text`` without any line that contains a ``#`` character.

    The text is split on newline characters, lines containing ``#`` are
    dropped, and the survivors are rejoined with newlines in their original order.
    """
    return "\n".join(filter(lambda row: "#" not in row, text.split("\n")))
def remove_prefix(state_dict, prefix):
    """Return a new dict whose keys have ``prefix`` stripped where present.

    Keys that do not start with ``prefix`` are kept as-is; values are not
    copied. If stripping makes two keys collide, the entry that comes later in
    ``state_dict``'s iteration order wins.
    """
    return {
        (name[len(prefix):] if name.startswith(prefix) else name): entry
        for name, entry in state_dict.items()
    }
def topk_sampling(model, prompt, device, max_length=50, top_k=50, temperature=1.0, frequency_penalty=0.5):
    """Generate a continuation of ``prompt`` with temperature-scaled top-k sampling.

    Each step runs ``model`` on the whole sequence so far, divides the
    last-position logits by ``temperature``, subtracts a frequency penalty
    (from the second step onward), and samples the next token from the
    ``top_k`` most probable candidates. Generation stops after ``max_length``
    new tokens, or when the tokenizer's EOS token is sampled (EOS is not
    appended).

    Args:
        model: Callable taking a tensor of token ids of shape ``(1, seq_len)``.
            The indexing below assumes it returns logits shaped
            ``(batch, seq_len, vocab)``.
        prompt: Text to condition on.
        device: Device the encoded prompt is moved to.
        max_length: Maximum number of new tokens to generate.
        top_k: Number of highest-probability tokens to sample from; clamped to
            the vocabulary size.
        temperature: Positive divisor applied to the logits.
        frequency_penalty: Each token id present in the current sequence
            (prompt included) has ``frequency_penalty * count ** 0.8``
            subtracted from its logit, where ``count`` is its number of
            occurrences in that sequence.

    Returns:
        The decoded prompt plus generated tokens, with special tokens skipped.

    Raises:
        ValueError: If ``temperature`` is not positive.
    """
    if temperature <= 0:
        raise ValueError(f"temperature must be positive, got {temperature}")

    input_ids = tokenizer.encode(prompt, return_tensors='pt').to(device)

    for step in range(max_length):
        # NOTE(review): the full, uncropped sequence is fed every step; confirm
        # the model accepts sequences longer than its configured block size.
        with torch.no_grad():
            outputs = model(input_ids)
        # The division yields a fresh tensor, so the in-place penalty below
        # does not modify ``outputs``.
        logits = outputs[:, -1, :] / temperature

        # Frequency penalty, skipped on the first step. Counts are recomputed
        # from the current sequence on every step. Previously they were
        # accumulated into a persistent dict each step, so a token's penalty
        # grew with the step number rather than with its occurrence count.
        if step > 0 and frequency_penalty:
            token_ids, counts = torch.unique(input_ids[0], return_counts=True)
            logits[0, token_ids] -= frequency_penalty * counts.to(logits.dtype) ** 0.8

        probs = F.softmax(logits, dim=-1)
        k = min(top_k, probs.size(-1))  # torch.topk raises if k exceeds the vocab size
        top_k_probs, top_k_indices = torch.topk(probs, k, dim=-1)
        # multinomial accepts unnormalised weights, so top_k_probs need not sum to 1.
        choice = torch.multinomial(top_k_probs, num_samples=1)
        # Map the sampled position within the top-k slice back to a vocabulary id.
        next_token = torch.gather(top_k_indices, -1, choice)

        if next_token.item() == tokenizer.eos_token_id:
            break
        input_ids = torch.cat([input_ids, next_token], dim=1)

    # Decodes the full sequence (prompt + generated tokens), not only the new tokens.
    return tokenizer.decode(input_ids[0], skip_special_tokens=True)
def main():
    """Parse CLI options, load the DPO checkpoint into a Llama model and print one sampled completion."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--prompt", type=str, default=''' Follow the given instructions carefully. My mom is about to retire from her 10 long years of service to a company. write me a message saying how grateful we are for her service to our company. ''')
    parser.add_argument("--max_length", type=int, default=256)
    parser.add_argument("--temperature", type=float, default=0.8)
    # args.top_k is passed to topk_sampling below. Without registering it here,
    # every run failed with AttributeError. The default matches topk_sampling's own.
    parser.add_argument("--top_k", type=int, default=50)
    parser.add_argument("--frequency_penalty", type=float, default=0.5)
    parser.add_argument("--checkpoint", type=str, default='DPO_model_1650.pt')
    args = parser.parse_args()

    model = Llama(device=ModelArgs.device, embeddings_dims=ModelArgs.embeddings_dims, no_of_decoder_layers=ModelArgs.no_of_decoder_layers, block_size=ModelArgs.block_size, vocab_size=ModelArgs.vocab_size, dropout=ModelArgs.dropout)
    model = model.to(ModelArgs.device)

    # map_location lets a checkpoint saved on a GPU be loaded on a CPU-only host.
    dict_model = torch.load(args.checkpoint, map_location=ModelArgs.device)
    # Strip the '_orig_mod.' key prefix (presumably left by saving a
    # torch.compile-wrapped model) so the keys match the plain Llama module.
    state_dict = remove_prefix(dict_model['MODEL_STATE'], '_orig_mod.')
    model.load_state_dict(state_dict)
    model.eval()
    print("Model ready")

    with torch.no_grad():
        generated_text = topk_sampling(
            model,
            args.prompt,
            max_length=args.max_length,
            top_k=args.top_k,
            temperature=args.temperature,
            frequency_penalty=args.frequency_penalty,
            device=ModelArgs.device,
        )
    print("Generated: ", generated_text)
if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported.
    main()