File size: 3,328 Bytes
72268ee
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
from model import ExLlama, ExLlamaCache, ExLlamaConfig
from tokenizer import ExLlamaTokenizer
from generator import ExLlamaGenerator
import torch
import torch.nn.functional as F
import os, glob
import cuda_ext

# Directory containing the model config, tokenizer model and weight files

model_directory = "/mnt/str/models/_test_models/TheBloke_Llama-2-13B-chat-GPTQ/"

# Locate files we need within that directory

tokenizer_path = os.path.join(model_directory, "tokenizer.model")
model_config_path = os.path.join(model_directory, "config.json")
st_pattern = os.path.join(model_directory, "*.safetensors")

# Collect every file matching *.safetensors. glob() returns matches in
# arbitrary (directory-listing) order, so sort for a deterministic list.
# An empty match would otherwise be handed to the model loader and fail
# later with an unrelated-looking error; stop here with a clear message.
model_path = sorted(glob.glob(st_pattern))
if not model_path:
    raise FileNotFoundError(f"No .safetensors files matching {st_pattern!r}")

# Build the model: the config is read from config.json, then pointed at the
# weight file list collected in model_path before ExLlama loads it.
config = ExLlamaConfig(model_config_path)
config.model_path = model_path
model = ExLlama(config)

# Tokenizer is created from the tokenizer model file.
tokenizer = ExLlamaTokenizer(tokenizer_path)

# Batch size 2: one cache row for each of the two prompts that are mixed.
cache = ExLlamaCache(model, batch_size = 2)
generator = ExLlamaGenerator(model, tokenizer, cache)

# Sampling settings held on the generator (repetition penalty, temperature,
# top-k, top-p).
generator.settings.token_repetition_penalty_max = 1.15
generator.settings.temperature = 0.95
generator.settings.top_k = 40
generator.settings.top_p = 0.75
# Uncomment to also set typical sampling:
# generator.settings.typical = 0.95

# Two Llama-2 chat style templates whose next-token distributions are mixed:
# f1 carries a conventional helpful/safe system prompt, f2 a rude persona.
# "{prompt}" is a placeholder substituted with the user question below.
# NOTE(review): f2 leaves the <<SYS>> block empty and puts the persona text
# after <</SYS>>, i.e. in the user turn; confirm this is intentional.

f1 = """[INST] <<SYS>>
You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe.  Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature. If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.
<</SYS>>
{prompt}[/INST]"""

f2 = """[INST] <<SYS>>
<</SYS>>
You are a rude and obnoxious assistant. You hate everything and everyone.
{prompt}[/INST]"""


# Same question under both templates; order matters (row 0 = f1, row 1 = f2).
prompts = [template.replace("{prompt}", "Tell me about Homer Simpson") for template in (f1, f2)]

def generate_cfg(prompts, alpha, max_new_tokens):
    """Generate one continuation by mixing the token distributions of two prompts.

    Both prompts are encoded as a single batch of two rows. At every step the
    log-probabilities of the two rows are blended linearly,

        mixed = (1 - alpha) * log_probs[0] + alpha * log_probs[1]

    one token is sampled from the blend, and that same token is appended to
    both rows, so the two rows share a single continuation. ``alpha = 0``
    uses only ``prompts[0]``, ``alpha = 1`` uses only ``prompts[1]``, and
    values outside [0, 1] extrapolate away from one of the prompts.

    Relies on the module-level ``tokenizer``, ``generator``, ``model`` and
    ``cache`` objects.

    Args:
        prompts: Exactly two prompt strings. The mixing step reads rows 0 and
            1, and the module-level cache is created with ``batch_size = 2``.
        alpha: Mixing weight given to ``prompts[1]``.
        max_new_tokens: Maximum number of tokens to sample; generation stops
            earlier if the tokenizer's EOS token is sampled.

    Returns:
        The decoded text of row 0 of ``generator.sequence``. Callers strip the
        first ``len(prompts[0])`` characters, i.e. this is expected to start
        with the first prompt.

    Raises:
        ValueError: If ``prompts`` does not contain exactly two entries.
    """
    # Validate before touching the model: any other count would fail deep in
    # the forward pass (cache batch is 2) or with an opaque IndexError below.
    if len(prompts) != 2:
        raise ValueError(f"generate_cfg expects exactly 2 prompts, got {len(prompts)}")

    ids, mask = tokenizer.encode(prompts, return_mask = True)
    generator.gen_begin(ids, mask = mask)

    # Sampling loop

    for _ in range(max_new_tokens):

        # Only the newest token of each row is fed; earlier positions are
        # expected to be held in `cache` from gen_begin / previous steps.
        logits = model.forward(generator.sequence[:, -1:], cache, input_mask = mask)
        generator.apply_rep_penalty(logits)

        # Blend in log space so alpha interpolates (or extrapolates) between
        # the two prompts' distributions.
        log_probs = F.log_softmax(logits, dim = -1)
        mixed = (1 - alpha) * log_probs[0] + alpha * log_probs[1]

        sampled_token, _ = generator.sample_current(mixed)
        if sampled_token.item() == tokenizer.eos_token_id:
            break

        # Append the same token to every row so both stay in lockstep.
        batch_token = sampled_token.repeat(len(prompts), 1)
        generator.gen_accept_token(batch_token)

    return tokenizer.decode(generator.sequence[0])

# Sweep the mixing weight from -0.4 to 1.4 in steps of 0.2 and print the
# continuation generated (up to 200 new tokens) at each setting.
for alpha in [n / 5.0 - 0.4 for n in range(10)]:

    # Blank line, rule, alpha header, rule; emitted in one print call.
    print(f"\n--------------------------------------\nalpha = {alpha:.1f}\n--------------------------------------")

    output = generate_cfg(prompts, alpha, 200)
    # Drop the leading characters corresponding to the first prompt's text.
    print(output[len(prompts[0]):].strip())