from transformers import LlamaForCausalLM, AutoTokenizer
import torch

# huggingface-cli download meta-llama/Llama-3.1-8B tokenizer_config.json --local-dir ./
# huggingface-cli download meta-llama/Llama-3.1-8B tokenizer.json --local-dir ./
# huggingface-cli download meta-llama/Llama-3.1-8B special_tokens_map.json --local-dir ./

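# Sample math prompt (solution text plus Asymptote figure source) used to compare base and TokenButler continuations.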
question = "A $y$-intercept is a point on the graph that lies on the $y$-axis, so $x = 0$. Hence, the number $y$-intercepts corresponds to the number of real solutions of the quadratic equation $y^2 - 4y - 1 = 0$. The discriminant of this quadratic equation is $(-4)^2 + 4 \\cdot 1 \\cdot (-1) = 20$, which is positive, so the quadratic has two distinct real roots. Therefore, the number of $y$-intercepts is $\\boxed{2}$. \n  \n [asy] \n size(150); \n real ticklen=3; \n real tickspace=2; \n  \n real ticklength=0.1cm; \n real axisarrowsize=0.14cm; \n pen axispen=black+1.3bp; \n real vectorarrowsize=0.2cm; \n real tickdown=-0.5; \n real tickdownlength=-0.15inch; \n real tickdownbase=0.3; \n real wholetickdown=tickdown; \n void rr_cartesian_axes(real xleft, real xright, real ybottom, real ytop, real xstep=1, real ystep=1, bool \n  \n useticks=false, bool complexplane=false, bool usegrid=true) { \n  \n import graph; \n  \n real i; \n  \n if(complexplane) { \n  \n label('$\\textnormal{Re}$',(xright,0),SE); \n  \n label('$\\textnormal{Im}$',(0,ytop),NW); \n  \n } else { \n  \n label('$x$',(xright+0.4,-0.5)); \n  \n label('$y$',(-0.5,ytop+0.2)); \n  \n } \n  \n ylimits(ybottom,ytop); \n  \n xlimits( xleft, xright); \n  \n real[] TicksArrx,TicksArry; \n  \n for(i=xleft+xstep; i<xright; i+=xstep) { \n  \n if(abs(i) >0.1) { \n  \n TicksArrx.push(i); \n  \n } \n  \n } \n  \n for(i=ybottom+ystep; i<ytop; i+=ystep) { \n  \n if(abs(i) >0.1) { \n  \n TicksArry.push(i); \n  \n } \n  \n } \n  \n if(usegrid) {"
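# Machine-specific path to the trained TokenButler predictor checkpoint; adjust for your setup.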
predictor_load_path = "/home/ya255/projects/TokenButler/expt_model/TrainTokenButler_42_finetune_None_None_500_llama_meta-llama_Llama-3.1-8B_L3_8B_1k.csv_L3_8B_1k_Cont_False_False_2000_False_redpajama_1024_1_1_20_0.001_1024/16_False_4_1000_ExpPred_fixed_40pc_True_False_0_None_False_False_4_8_2_32_1024_False_False_True_32_0.3875000000000002__best.pt"
base_model_name = "meta-llama/Llama-3.1-8B"

def get_producer_layers(model):
    """
    Traverses the model and collects the producer attention module(s) (layer_idx == 0).
    """
    producer_modules = []
    for module in model.modules():
        if module.__class__.__name__.endswith("AttentionExperimental") and module.layer_idx == 0:
            producer_modules.append(module)
    return producer_modules
    
# 1) Load the base model and tokenizer from HF, and encode the prompt
base_model = LlamaForCausalLM.from_pretrained(base_model_name, device_map="auto")
tokenizer = AutoTokenizer.from_pretrained(base_model_name)
inputs = tokenizer(question, return_tensors="pt")
inputs = {k: v.to(base_model.device) for k, v in inputs.items()}
question_length = inputs['attention_mask'].shape[1]

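# 2) Generate a baseline continuation with the unmodified base model (sampled, so outputs vary between runs)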
with torch.no_grad():
    base_output_ids = base_model.generate(
        **inputs,
        max_new_tokens=200,
        do_sample=True,
        top_p=0.95,
        temperature=0.7,
    )
base_output_text = tokenizer.decode(base_output_ids[0][question_length:], skip_special_tokens=True)

# Move the base model off the GPU, keep its state dict, and free the memory
base_model_device = base_model.device
base_model.to("cpu")
base_state_dict = base_model.state_dict()
del base_model
torch.cuda.empty_cache()

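# 3) Build the TokenButler variant from the local config and copy in the base weights
#    (strict=False because the Butler model adds extra predictor parameters)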
from modeling_llama_butler import LlamaButlerConfig, LlamaButlerForCausalLM
butler_config = LlamaButlerConfig.from_pretrained('config.json')

butler_model = LlamaButlerForCausalLM(butler_config)
butler_model.load_state_dict(base_state_dict, strict=False)

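# 4) Load the trained token-importance predictor ("producer") weights into the layer-0 attention module(s)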
model_producer_layers = get_producer_layers(butler_model)
producer_layer_weights = torch.load(predictor_load_path)
for idx, producer_layer_weight in enumerate(producer_layer_weights):
    try:
        model_producer_layers[idx].load_state_dict(producer_layer_weight, strict=False)
    except Exception as e:
        print(f"Error loading producer layer {idx}: {e}")
        print("\nContinuing without these weights; expect degraded performance if this is unintentional.\n")


butler_model.to(base_model_device)
butler_model.eval()

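# 5) Generate with the TokenButler model on the same prompt for a side-by-side comparison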
with torch.no_grad():
    butler_output_ids = butler_model.generate(
        **inputs,
        max_new_tokens=200,
        do_sample=True,
        top_p=0.95,
        temperature=0.7,
    )

butler_output_text = tokenizer.decode(butler_output_ids[0][question_length:], skip_special_tokens=True)

print("\n=== Base Model Output (Newlines Removed For Brevity) ===\n")
print(base_output_text.replace("\n", ""))
print("\n")
print("=== Butler Model Output (Newlines Removed For Brevity) ===\n")
print(butler_output_text.replace("\n", ""))
print("\n")

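# 6) Save the merged model (base weights + predictor) so it ships as a single checkpoint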
OUTPUT_DIR = "."
print(f"\nSaving final merged model to: {OUTPUT_DIR}")
butler_model.save_pretrained(OUTPUT_DIR, safe_serialization=False)

# tokenizer.save_pretrained(OUTPUT_DIR)
print("\nAll done! The folder should now have `pytorch_model.bin` and the updated `config.json`.\n")
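# A minimal sketch of re-loading the merged checkpoint later. The AutoModel route assumes
# `config.json` registers the custom Butler classes via `auto_map`; if it does not, load
# through LlamaButlerForCausalLM directly:
#
#   reloaded = LlamaButlerForCausalLM.from_pretrained(OUTPUT_DIR)
#
#   # or, from a fresh environment that has modeling_llama_butler.py alongside the weights:
#   # from transformers import AutoModelForCausalLM
#   # reloaded = AutoModelForCausalLM.from_pretrained(OUTPUT_DIR, trust_remote_code=True)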