Output breaks above 16k context length.
I have been testing MistralLite on various platforms, and whether I use Text Generation Inference or Transformers directly, the output starts to break once the context length exceeds a certain threshold, roughly 16k tokens. Is there any reason why this happens, given that the model should theoretically support a 32k context length?
Hi @krecceg
Thanks for the feedback! Can you clarify whether the break was an out-of-memory error or just gibberish output? Cheers!
I am also getting this issue. I am instantiating and running my model using the snippet below:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline

# model_config is defined elsewhere in my code and holds max_seq_len
model = AutoModelForCausalLM.from_pretrained(
    "amazon/MistralLite",
    torch_dtype=torch.bfloat16,
    use_flash_attention_2=True,
    device_map="auto",
)
tokeniser = AutoTokenizer.from_pretrained(
    "amazon/MistralLite",
    max_length=model_config.max_seq_len,
)
pipe = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokeniser,
    max_length=model_config.max_seq_len,
)
pipe.padding_side = "left"
And then I am generating my output using this snippet:
with torch.no_grad():
    output = pipe(
        input,
        return_full_text=False,
        do_sample=False,
        use_cache=True,
        eos_token_id=tokeniser.eos_token_id,
        num_return_sequences=1,
    )
And when the input token length is above 16k, I get this error:
File ~/.local/lib/python3.11/site-packages/transformers/models/mistral/modeling_mistral.py:374, in MistralFlashAttention2.forward(self, hidden_states, attention_mask, position_ids, past_key_value, output_attentions, use_cache, **kwargs)
371 past_value = past_value[:, :, slicing_tokens:, :].contiguous()
373 if past_key.shape[-2] != self.config.sliding_window - 1:
--> 374 raise ValueError(
375 f"past key much have a shape of (`batch_size, num_heads, self.config.sliding_window-1, head_dim`), got"
376 f" {past_key.shape}"
377 )
379 past_key_value = (past_key, past_value)
381 if attention_mask is not None:
ValueError: past key much have a shape of (`batch_size, num_heads, self.config.sliding_window-1, head_dim`), got torch.Size([1, 8, 8318, 128])
The error goes away only if I set this parameter:
pipe.model.config.sliding_window = 32000
So the question is: is this the intended behaviour? What is the ideal value for config.sliding_window? And why does it crash when it is not set explicitly?
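For reference, here is roughly how I am applying that override at load time rather than patching the pipeline afterwards (just a sketch of my workaround; I am not sure it is the intended fix):

from transformers import AutoConfig, AutoModelForCausalLM
import torch

# Workaround sketch: raise (or disable) the sliding window before the model is built,
# instead of setting pipe.model.config.sliding_window after the fact.
config = AutoConfig.from_pretrained("amazon/MistralLite")
config.sliding_window = 32000  # or None to skip the sliding-window cache slicing entirely

model = AutoModelForCausalLM.from_pretrained(
    "amazon/MistralLite",
    config=config,
    torch_dtype=torch.bfloat16,
    use_flash_attention_2=True,
    device_map="auto",
)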
Hi @nicklikets Thanks for sharing this!
I was using the code snippet below:
from transformers import AutoModelForCausalLM, AutoTokenizer
import transformers
import torch

model_id = "amazon/MistralLite"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    use_flash_attention_2=True,
    device_map="auto",
)
pipeline = transformers.pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
)

# replace with your long context 16K
prompt = "....................."

sequences = pipeline(
    prompt,
    max_new_tokens=400,
    do_sample=False,
    return_full_text=False,
    num_return_sequences=1,
    eos_token_id=tokenizer.eos_token_id,
)
for seq in sequences:
    print(f"{seq['generated_text']}")
It seems to be okay for inputs >16K; the transformers version I was using is 4.35.2.
Can you double check whether your code is configured correctly, or paste a more comprehensive code snippet here? Thank you!
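For example, a quick check like this (just a debugging aid, assuming model is the one loaded in your snippet) would let us compare environments:

import torch
import transformers

# Debugging aid: print the library versions and the sliding_window actually in use.
print("transformers:", transformers.__version__)
print("torch:", torch.__version__)
print("sliding_window:", model.config.sliding_window)
print("max_position_embeddings:", model.config.max_position_embeddings)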
I am still getting that error when using your snippet on contexts longer than 16k tokens (specifically 24455 tokens).
I am using these package versions:
torch 2.1.1
transformers 4.35.2
flash-attn 2.3.6
Below is my code snippet, followed by the console output:
from transformers import AutoModelForCausalLM, AutoTokenizer
import transformers
import torch

model_id = "amazon/MistralLite"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    use_flash_attention_2=True,
    device_map="auto",
)
pipeline = transformers.pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
)

with open("./data/longcontext.txt", "r") as f:
    doc = f.read()

prompt = f"<|prompter|>{doc}, Summarise this document</s><|assistant|>"
tokens = tokenizer(prompt)
print(len(tokens[0]))

sequences = pipeline(
    prompt,
    max_new_tokens=400,
    do_sample=False,
    return_full_text=False,
    num_return_sequences=1,
    eos_token_id=tokenizer.eos_token_id,
)
for seq in sequences:
    print(f"{seq['generated_text']}")
----- CONSOLE OUTPUT -----
Loading checkpoint shards: 100%
24455
Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
Cell In[4], line 26
23 tokens = tokenizer(prompt)
24 print(len(tokens[0]))
---> 26 sequences = pipeline(
27 prompt,
28 max_new_tokens=400,
29 do_sample=False,
30 return_full_text=False,
31 num_return_sequences=1,
32 eos_token_id=tokenizer.eos_token_id,
33 )
34 for seq in sequences:
35 print(f"{seq['generated_text']}")
...
ValueError: past key much have a shape of (`batch_size, num_heads, self.config.sliding_window-1, head_dim`), got torch.Size([1, 8, 8312, 128])
Yes, I have the same error for contexts above the 16k sliding window length.
The same error occurs on the base Mistral models too when the sliding window length is exceeded.
Hi @nicklikets and @RonanMcGovern
Thank you for the feedback! If possible, can you share the example long-context file with me so I can reproduce the error?
Cheers!
The text file is here.
And the code to test is below.
FWIW:
- The same issue happens on Mistral v0.1 when running a context longer than 4k input tokens (the sliding window size for that model).
- I notice that for Mistral v0.2 the sliding window length is set to null (?) (see the quick config check after the test code below).
# Installing required packages
!pip install -U -q transformers
!pip install -q -U accelerate
!pip install -U flash-attn -q
!pip install scipy -q

# Importing libraries
import transformers
import torch
import json
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig
from IPython.display import display, HTML
import gc

# Make sure to define these variables before using them
model_name_A = "amazon/MistralLite"  # Replace with your actual model name
cache_dir = ""  # Replace with your cache directory

# Model loading
model_A = AutoModelForCausalLM.from_pretrained(
    model_name_A,
    device_map='auto',
    torch_dtype=torch.bfloat16,
    use_flash_attention_2=True,  # works with Llama models and reduces memory reqs
    cache_dir=cache_dir)
tokenizer_A = AutoTokenizer.from_pretrained(model_name_A, cache_dir=cache_dir, use_fast=True)

# Function to generate text
def generate(model, tokenizer, user_prompt, system_prompt):
    B_INST, E_INST = "[INST]", "[/INST]"
    B_SYS, E_SYS = "<<SYS>>", "<</SYS>>"
    prompt = f"{B_INST} {B_SYS}{system_prompt.strip()}{E_SYS}{user_prompt.strip()} {E_INST}\n\n"
    inputs = tokenizer([prompt], return_tensors="pt").to('cuda')
    shape = inputs.input_ids.shape
    print(f"Length of input is {shape[1]}")
    result = model.generate(**inputs, max_new_tokens=750, pad_token_id=tokenizer.eos_token_id, do_sample=False)
    result_str = tokenizer.decode(result[0], skip_special_tokens=False)
    torch.cuda.empty_cache()
    gc.collect()
    return result_str

# Reading from a text file
text_file = 'berkshire23.txt'
len_limit = int(16000*1*4)  # Set the limit in characters

try:
    with open(text_file, 'r') as file:
        text = file.read()[:len_limit]
except FileNotFoundError:
    print(f"File {text_file} not found.")

# Test with Model A
system_prompt = ""  # Define your system prompt, or leave blank
user_prompt = f'Respond only with a brief summary of the below text.\n\n{text}\n\nRespond only with a brief summary of the above text.'
result = generate(model_A, tokenizer_A, user_prompt, system_prompt)
display(HTML(f"<b>{model_name_A}:</b><br>"))
print(result)
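And here is the quick config check mentioned above (the model ids are the public Hugging Face ones; purely illustrative):

from transformers import AutoConfig

# Illustrative: compare the configured sliding_window across the models discussed above.
for model_id in ["amazon/MistralLite",
                 "mistralai/Mistral-7B-v0.1",
                 "mistralai/Mistral-7B-Instruct-v0.2"]:
    cfg = AutoConfig.from_pretrained(model_id)
    print(model_id, "sliding_window =", cfg.sliding_window)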
Hi @krecceg and @nicklikets
I have updated sliding_window in the config.json to null, and it should fix the issue you mentioned.
Since MistralLite was fine-tuned with prompts of the form f"<|prompter|>{instruction}</s><|assistant|>", it is recommended to use a prompt like the one below to get a valid summary:
with open("berkshire23.txt", "r") as fin:
text = fin.read()[:80000]
prompt = f"<|prompter|>{text}\n\nRespond only with a brief summary of the above text.</s><|assistant|>"
I tested it, and it should give valid results even when the token count is above 16,000. Please give it a try and let me know if you still have further issues. Thank you!
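For completeness, feeding that prompt through the same text-generation pipeline as in my earlier snippet looks like this (a sketch; the generation parameters are the same as before):

# Sketch: pass the prompt built above through the pipeline defined in my earlier snippet.
sequences = pipeline(
    prompt,
    max_new_tokens=400,
    do_sample=False,
    return_full_text=False,
    num_return_sequences=1,
    eos_token_id=tokenizer.eos_token_id,
)
print(sequences[0]["generated_text"])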
Thanks, yeah, that makes sense. Setting it to null would have broken older versions of transformers, but with the latest version it should work. And thanks for clarifying the prompt.
As an aside, you're probably aware that Mistral v0.2 now supports a 32k context.