Several issues loading and using the model with transformers==4.39.2

#7 opened by csegalin
import re

import torch


class LlavaMistralCaptioner:
    def __init__(self, device='cuda',
                 hf_model="llava-hf/llava-v1.6-mistral-7b-hf",
                 bf16=False,
                 quant_force=True):
        from transformers import LlavaNextProcessor, LlavaNextForConditionalGeneration, BitsAndBytesConfig, AutoProcessor

        self.device = device
        # Compute dtype: bf16 if requested, otherwise fp16.
        self.torch_type = torch.bfloat16 if bf16 else torch.float16

        # Enable 4-bit quantization automatically when the GPU has less than 40 GB.
        with torch.cuda.device(self.device):
            _, total_bytes = torch.cuda.mem_get_info()
            total_gb = total_bytes / (1 << 30)
        quant = total_gb < 40

        self.quantization_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=self.torch_type,
        )
        print("========Use torch type as:{} with device:{}========\n".format(self.torch_type, self.device))

        self.model = LlavaNextForConditionalGeneration.from_pretrained(
            pretrained_model_name_or_path=hf_model,
            torch_dtype=self.torch_type,
            low_cpu_mem_usage=True,
            attn_implementation="flash_attention_2",
            quantization_config=self.quantization_config if quant or quant_force else None,
            # device_map='auto'
        ).eval()
        self.model.tie_weights()
        # self.processor = AutoProcessor.from_pretrained(hf_model)
        self.processor = LlavaNextProcessor.from_pretrained(hf_model)

    def caption(self, image,
                prompt,
                max_tokens=225,
                top_k=1,
                top_p=0.1,
                num_beams=1,
                do_sample=True,
                temperature=0.1,
                use_cache=True):
        # LLaVA-NeXT Mistral chat format: "[INST] <image>\n{prompt} [/INST]"
        prompt = f"[INST] <image>\n {prompt} [/INST]"

        inputs = self.processor(text=prompt, images=image, return_tensors="pt").to(self.device, self.torch_type)
        outputs = self.model.generate(
            **inputs,
            max_new_tokens=max_tokens,
            top_k=top_k,
            top_p=top_p,
            num_beams=num_beams,
            do_sample=True if temperature > 0 else do_sample,
            temperature=temperature,
            use_cache=use_cache,
            # pad_token_id=2,
            # num_return_sequences=1
        )
        response = self.processor.decode(
            outputs[0],
            skip_special_tokens=True,
            clean_up_tokenization_spaces=False,
        )
        # Keep only the text after the instruction marker and normalize whitespace.
        response = response.split('[/INST]')[-1].strip()
        response = re.sub(r'\n+', ' ', response)
        response = response.strip().replace("</s>", "").replace("<s>", "").replace("*", " ")
        return response
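For reference, this is roughly how the class above is invoked (the image path and prompt here are placeholders, not part of the original script):

from PIL import Image

captioner = LlavaMistralCaptioner(device='cuda')
image = Image.open("example.jpg")  # placeholder path
print(captioner.caption(image, "Describe this image in one sentence."))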

1. When loading the model I get:

You are attempting to use Flash Attention 2.0 without specifying a torch dtype. This might lead to unexpected behaviour
The model weights are not tied. Please use the tie_weights method before using the infer_auto_device function.
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.

2. When generating a caption, I get the same caption repeated three times.

Any help on this?

I'm having the same issues with this model: llava-hf/llama3-llava-next-8b-hf

Llava Hugging Face org

Hey!

I am not 100% sure which arguments are used to run the script, but here is some common advice on LLaVA and FA2:

  • FA2 should be loaded in half precision, which I am not sure is happening in your script. Also, for LLaVA specifically the recommended precision is fp16, which is the one used in the original LLaVA implementation (see the sketch after this list).
  • Mixing FA2 with quantization might result in weird/unexpected results; try using only one of the two.
  • Hmm, the tie_weights message actually shouldn't be raised, and usually you don't have to tie weights manually, as you do in the example script.

Let me know if any of the above helps. If not, can you share fully reproducible code that doesn't rely on external args/hardware limits?
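A minimal sketch of the two loading paths for the same llava-hf/llava-v1.6-mistral-7b-hf checkpoint; pick one option, not both:

import torch
from transformers import BitsAndBytesConfig, LlavaNextForConditionalGeneration

model_id = "llava-hf/llava-v1.6-mistral-7b-hf"

# Option A: FA2 in half precision (fp16), no quantization.
model = LlavaNextForConditionalGeneration.from_pretrained(
    model_id,
    torch_dtype=torch.float16,
    attn_implementation="flash_attention_2",
    low_cpu_mem_usage=True,
)

# Option B: 4-bit nf4 quantization with the default attention implementation (no FA2).
quant_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
)
model = LlavaNextForConditionalGeneration.from_pretrained(
    model_id,
    quantization_config=quant_config,
    low_cpu_mem_usage=True,
)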
