Spaces:
Runtime error
Runtime error
File size: 3,964 Bytes
2492536 fe1089d 5d99c07 fe1089d 5d99c07 fe1089d 2492536 fe1089d 2492536 fe1089d 2492536 fe1089d 2492536 fe1089d 2492536 fe1089d 2492536 fe1089d 2492536 fe1089d a597c76 fe1089d 2492536 fe1089d 2492536 fe1089d 5d99c07 a597c76 5d99c07 a597c76 1f063be a597c76 5d99c07 67a34bd 517fd4c 67a34bd 517fd4c 67a34bd 517fd4c 67a34bd 517fd4c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 |
# modelling util module providing formatting functions for model functionalities
# external imports
import torch
import gradio as gr
from transformers import BitsAndBytesConfig
# function that limits the prompt to fit the model runtime
# tries to keep as much as possible, always keeping at least message and system prompt
def prompt_limiter(
    tokenizer, message: str, history: list, system_prompt: str, knowledge: str = ""
):
    """Trim prompt parts so the total token count stays within the model limit.

    Args:
        tokenizer: model tokenizer used for token counting
        message: current user message (always kept)
        history: list of (user, assistant) conversation pairs
        system_prompt: system prompt (always kept)
        knowledge: optional retrieved knowledge text

    Returns:
        Tuple of (message, prompt_history, system_prompt, knowledge), where
        prompt_history holds the most recent conversations that still fit
        (newest first) and knowledge may be dropped ("") when input is too long.

    Raises:
        RuntimeError: if message plus system prompt alone exceed the limit.
    """
    # initializing the new prompt history empty
    prompt_history = []
    # getting the current token count for the message, system prompt, and knowledge
    pre_count = (
        token_counter(tokenizer, message)
        + token_counter(tokenizer, system_prompt)
        + token_counter(tokenizer, knowledge)
    )
    # validating the token count against threshold of 1024
    # check if token count already too high without history
    if pre_count > 1024:
        # check if token count too high even without knowledge and history
        if (
            token_counter(tokenizer, message) + token_counter(tokenizer, system_prompt)
            > 1024
        ):
            # show warning and raise error
            gr.Warning("Message and system prompt are too long. Please shorten them.")
            raise RuntimeError(
                "Message and system prompt are too long. Please shorten them."
            )
        # show warning and return with empty history and empty knowledge
        gr.Warning("""
        Input too long.
        Knowledge and conversation history have been removed to keep model running.
        """)
        return message, prompt_history, system_prompt, ""
    # if token count small enough, adding history bit by bit
    if pre_count < 800:
        # setting the count to the pre-count
        count = pre_count
        # iterate newest-first WITHOUT mutating the caller's list
        # (fix: the previous history.reverse() reversed the argument in place
        # as a side effect visible to the caller)
        for conversation in reversed(history):
            # checking the token count with the current conversation included
            count += token_counter(tokenizer, conversation[0]) + token_counter(
                tokenizer, conversation[1]
            )
            # add conversation or break loop depending on token count
            if count < 1024:
                prompt_history.append(conversation)
            else:
                break
    # NOTE(review): for pre_count in [800, 1024] the history is silently dropped;
    # confirm this gap between the two thresholds is intentional
    # return the message, adapted history, system prompt, and knowledge
    return message, prompt_history, system_prompt, knowledge
# token counter function using the model tokenizer
def token_counter(tokenizer, text: str):
    """Return the number of tokens the tokenizer produces for the given text."""
    # encode the text; input_ids has shape (1, sequence_length)
    input_ids = tokenizer(text, return_tensors="pt").input_ids
    # the sequence length of the single batch entry is the token count
    return input_ids.shape[-1]
# function to determine the device to use
def get_device():
    """Return the torch device to run on: CUDA when available, otherwise CPU."""
    target = "cuda" if torch.cuda.is_available() else "cpu"
    return torch.device(target)
# function to set device config
# CREDIT: Copied from captum llama 2 example
# see https://captum.ai/tutorials/Llama2_LLM_Attribution
def gpu_loading_config(max_memory: str = "15000MB"):
    """Return the GPU count, the max-memory string, and a 4-bit quantization config."""
    # 4-bit NF4 quantization with double quantization and bfloat16 compute dtype
    quant_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16,
    )
    return torch.cuda.device_count(), max_memory, quant_config
# formatting mistral attention values
# CREDIT: copied from BERTViz
# see https://github.com/jessevig/bertviz
def format_mistral_attention(attention_values, layers=None, heads=None):
    """Stack per-layer attention tensors, optionally filtered to given layers/heads.

    Each layer tensor carries a leading batch dimension that is squeezed away
    (a no-op when it is not of size 1) before the layers are stacked.
    """
    # restrict to the requested layers (truthy check: empty list keeps all)
    if layers:
        attention_values = [attention_values[idx] for idx in layers]

    def _prepare(layer_attention):
        # drop the batch dimension, then keep only the requested heads when given
        squeezed = layer_attention.squeeze(0)
        return squeezed[heads] if heads else squeezed

    return torch.stack([_prepare(layer) for layer in attention_values])
|