# modelling util module providing formatting and helper functions for the model functionality

# external imports
import torch
import gradio as gr
from transformers import BitsAndBytesConfig


# function that limits the prompt length to keep the model runtime manageable
# tries to keep as much of the input as possible, always keeping at least the message and system prompt
def prompt_limiter(
    tokenizer, message: str, history: list, system_prompt: str, knowledge: str = ""
):
    # initializing the new prompt history empty
    prompt_history = []
    # getting the current token count for the message, system prompt, and knowledge
    pre_count = (
        token_counter(tokenizer, message)
        + token_counter(tokenizer, system_prompt)
        + token_counter(tokenizer, knowledge)
    )

    # validating the token count against the threshold of 1024 tokens
    # check if token count already too high without history
    if pre_count > 1024:

        # check if token count too high even without knowledge and history
        if (
            token_counter(tokenizer, message) + token_counter(tokenizer, system_prompt)
            > 1024
        ):

            # show warning and raise error
            gr.Warning("Message and system prompt are too long. Please shorten them.")
            raise RuntimeError(
                "Message and system prompt are too long. Please shorten them."
            )

        # show warning and return with empty history and empty knowledge
        gr.Warning("""
                   Input too long.
                   Knowledge and conversation history have been removed to keep model running.
                   """)
        return message, prompt_history, system_prompt, ""

    # if the token count is small enough, adding the history bit by bit
    if pre_count < 800:
        # setting the count to the pre-count
        count = pre_count
        # reversing the history in place to prioritize recent conversations
        history.reverse()

        # iterating through the history
        for conversation in history:

            # checking the token count with the current conversation included
            count += token_counter(tokenizer, conversation[0]) + token_counter(
                tokenizer, conversation[1]
            )

            # add conversation or break loop depending on token count
            if count < 1024:
                prompt_history.append(conversation)
            else:
                break

    # return the message, the adapted history (most recent conversation first), the system prompt, and the knowledge
    return message, prompt_history, system_prompt, knowledge
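
# usage sketch (assumption, not part of the original module): prompt_limiter expects a
# Hugging Face tokenizer, the new user message, a gradio-style history of [user, assistant]
# pairs, the system prompt, and optional retrieved knowledge, e.g.:
#
#   from transformers import AutoTokenizer
#   tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.2")  # example model id
#   message, history, system_prompt, knowledge = prompt_limiter(
#       tokenizer, "What is attention?", [["Hi", "Hello!"]], "You are a helpful assistant."
#   )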


# token counter function using the model tokenizer
def token_counter(tokenizer, text: str):
    # tokenize the text
    tokens = tokenizer(text, return_tensors="pt").input_ids
    # return the token count
    return len(tokens[0])


# function to determine the device to use
def get_device():
    if torch.cuda.is_available():
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")
    return device


# function to get the GPU loading config (quantization and memory settings)
# CREDIT: Copied from captum llama 2 example
# see https://captum.ai/tutorials/Llama2_LLM_Attribution
def gpu_loading_config(max_memory: str = "15000MB"):
    n_gpus = torch.cuda.device_count()
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16,
    )
    return n_gpus, max_memory, bnb_config
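
# usage sketch (assumed wiring, mirroring the captum tutorial linked above): the returned
# values are meant to be forwarded to transformers' from_pretrained, e.g.:
#
#   from transformers import AutoModelForCausalLM
#   n_gpus, max_memory, bnb_config = gpu_loading_config()
#   model = AutoModelForCausalLM.from_pretrained(
#       "mistralai/Mistral-7B-Instruct-v0.2",  # hypothetical model id
#       quantization_config=bnb_config,
#       device_map="auto",
#       max_memory={i: max_memory for i in range(n_gpus)},
#   )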


# formatting Mistral attention values
# CREDIT: copied from BERTViz
# see https://github.com/jessevig/bertviz
def format_mistral_attention(attention_values, layers=None, heads=None):
    if layers:
        attention_values = [attention_values[layer_index] for layer_index in layers]
    squeezed = []
    for layer_attention in attention_values:
        layer_attention = layer_attention.squeeze(0)
        if heads:
            layer_attention = layer_attention[heads]
        squeezed.append(layer_attention)
    return torch.stack(squeezed)
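
# usage sketch (assumption): attention_values is expected to be the tuple of per-layer
# attention tensors returned by a forward pass with output_attentions=True, e.g.:
#
#   outputs = model(input_ids, output_attentions=True)
#   attention = format_mistral_attention(outputs.attentions, layers=[0, 1], heads=[0])
#   # attention now has shape (layers, heads, seq_len, seq_len), as expected by BERTViz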