# Copyright (c) Sebastian Raschka under Apache License 2.0 (see LICENSE.txt).
# Source for "Build a Large Language Model From Scratch"
# https://github.com/rasbt/LLMs-from-scratch/blob/main/ch05/07_gpt_to_llama/standalone-llama32.ipynb

import os
import time
import urllib.request

import torch

from model import Llama3Model, generate, text_to_token_ids, token_ids_to_text
from tokenizer import Llama3Tokenizer, ChatFormat, clean_text

#######################################
# Model settings

MODEL_FILE = "llama3.2-1B-instruct.pth"
# MODEL_FILE = "llama3.2-1B-base.pth"
# MODEL_FILE = "llama3.2-3B-instruct.pth"
# MODEL_FILE = "llama3.2-3B-base.pth"

MODEL_CONTEXT_LENGTH = 8192  # Supports up to 131_072

# Text generation settings
if "instruct" in MODEL_FILE:
    PROMPT = "What do llamas eat?"
else:
    PROMPT = "Llamas eat"

MAX_NEW_TOKENS = 150
TEMPERATURE = 0.0
TOP_K = 1
#######################################
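# Note: TEMPERATURE = 0.0 together with TOP_K = 1 makes decoding effectively
# greedy (deterministic). For more varied output one might try, e.g.,
# TEMPERATURE = 0.8 and TOP_K = 50, depending on the sampling logic in
# generate().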

###################################
# Initialize model
###################################

url = f"https://huggingface.co/rasbt/llama-3.2-from-scratch/resolve/main/{MODEL_FILE}"

if not os.path.exists(MODEL_FILE):
    print(f"Downloading {MODEL_FILE}...")
    urllib.request.urlretrieve(url, MODEL_FILE)
    print(f"Downloaded to {MODEL_FILE}")
if "1B" in MODEL_FILE:
from model import LLAMA32_CONFIG_1B as LLAMA32_CONFIG
elif "3B" in MODEL_FILE:
from model import LLAMA32_CONFIG_3B as LLAMA32_CONFIG
else:
raise ValueError("Incorrect model file name")
LLAMA32_CONFIG["context_length"] = MODEL_CONTEXT_LENGTH
model = Llama3Model(LLAMA32_CONFIG)
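# weights_only=True (available in PyTorch >= 1.13) restricts torch.load to
# plain tensor/primitive data, avoiding arbitrary code execution via pickle.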
model.load_state_dict(torch.load(MODEL_FILE, weights_only=True))

# Pick the fastest available device: CUDA, then Apple Silicon (MPS), then CPU
device = (
    torch.device("cuda") if torch.cuda.is_available() else
    torch.device("mps") if torch.backends.mps.is_available() else
    torch.device("cpu")
)
model.to(device)
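# Assumption: eval mode is a no-op if the model defines no dropout or
# batch-norm layers, but switching explicitly is good practice before inference.
model.eval()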

###################################
# Initialize tokenizer
###################################

TOKENIZER_FILE = "tokenizer.model"
url = f"https://huggingface.co/rasbt/llama-3.2-from-scratch/resolve/main/{TOKENIZER_FILE}"

if not os.path.exists(TOKENIZER_FILE):
    print(f"Downloading {TOKENIZER_FILE}...")
    urllib.request.urlretrieve(url, TOKENIZER_FILE)
    print(f"Downloaded to {TOKENIZER_FILE}")

tokenizer = Llama3Tokenizer(TOKENIZER_FILE)

if "instruct" in MODEL_FILE:
    tokenizer = ChatFormat(tokenizer)
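# ChatFormat wraps prompts in the Llama 3 chat template (special tokens plus
# role headers), which the instruct-tuned weights expect.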

###################################
# Generate text
###################################

torch.manual_seed(123)

start = time.time()

token_ids = generate(
    model=model,
    idx=text_to_token_ids(PROMPT, tokenizer).to(device),
    max_new_tokens=MAX_NEW_TOKENS,
    context_size=LLAMA32_CONFIG["context_length"],
    top_k=TOP_K,
    temperature=TEMPERATURE
)
print(f"Time: {time.time() - start:.2f} sec")
if torch.cuda.is_available():
max_mem_bytes = torch.cuda.max_memory_allocated()
max_mem_gb = max_mem_bytes / (1024 ** 3)
print(f"Max memory allocated: {max_mem_gb:.2f} GB")

output_text = token_ids_to_text(token_ids, tokenizer)

if "instruct" in MODEL_FILE:
    output_text = clean_text(output_text)

print("\n\nOutput text:\n\n", output_text)