Remove the flash_attn dependency on macOS / non-GPU machines

#9
by ursnation - opened

I found a solution in this thread for removing the flash_attn dependency on macOS (non-GPU environments) and adapted it for this model.

from transformers import AutoModelForCausalLM, AutoTokenizer
from PIL import Image
import os
from unittest.mock import patch
from transformers.dynamic_module_utils import get_imports
import torch

# Pick the best available device: CUDA first, then Apple-silicon MPS, and
# finally plain CPU. Falling back to CPU (instead of assuming MPS) keeps the
# script usable on non-GPU machines that are not Macs.
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

def fixed_get_imports(filename: str | os.PathLike) -> list[str]:
    """Return a remote-code module's imports, minus ``flash_attn`` when CUDA is absent.

    Intended to be patched in for ``transformers.dynamic_module_utils.get_imports``
    while the MiniCPM 2.6 model is loaded. The ``get_imports`` called here is the
    original function, imported at module level before the patch is applied,
    so this wrapper does not recurse into itself.

    Args:
        filename: Path of the module file whose imports are being collected.

    Returns:
        The list produced by ``get_imports``. When CUDA is unavailable, the
        ``"flash_attn"`` entry is removed from it in place, presumably so that
        flash_attn does not need to be installed on such machines.
    """
    module_imports = get_imports(filename)
    # A CUDA machine may genuinely use flash_attn, so report imports unchanged.
    if torch.cuda.is_available():
        return module_imports
    if "flash_attn" in module_imports:
        module_imports.remove("flash_attn")
    return module_imports

model_name = 'openbmb/MiniCPM-V-2_6'

# Load onto the device chosen above rather than a hard-coded 'mps', so the
# same script works on CUDA, Apple-silicon and CPU-only machines.
# Half precision is often unsupported or very slow on CPU, so use float32
# there and keep float16 for accelerators.
dtype = torch.float32 if device.type == "cpu" else torch.float16

# create model; while the patch is active, fixed_get_imports hides the
# flash_attn import of the remote model code on non-CUDA machines
with patch("transformers.dynamic_module_utils.get_imports", fixed_get_imports):
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=dtype,
        device_map=str(device),
        trust_remote_code=True,
    )
    tokenizer = AutoTokenizer.from_pretrained(
        model_name,
        trust_remote_code=True,
    )

# Release the file handle as soon as the pixels have been converted.
with Image.open('xx.png') as source_image:
    image = source_image.convert('RGB')

question = 'Can you give me the text from this image into json format?'

# The image travels inside the message content, which is why image=None below.
msgs = [{'role': 'user', 'content': [image, question]}]

# Streaming generation. For a single non-streamed answer, call model.chat
# without `sampling`/`stream` and print the returned value instead.
res = model.chat(
    image=None,
    msgs=msgs,
    tokenizer=tokenizer,
    sampling=True,
    stream=True,
)

# Echo each chunk as it arrives and keep the pieces so the complete answer
# is available afterwards as `generated_text`.
chunks = []
for new_text in res:
    chunks.append(new_text)
    print(new_text, flush=True, end='')
print()
generated_text = "".join(chunks)

Cheers!

Thank you very much!

This comment has been hidden

Sign up or log in to comment