metadata
library_name: peft
base_model: Qwen/Qwen-VL-Chat

Model Card for wtoc_LoRA

  • A LoRA adapter that turns WD tagger tags ("wdtag") into a long caption.
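Each tag list is turned into a single instruction before it is sent to the model. The sketch below shows that step. It assumes WD tags arrive as underscore-separated strings such as `long_hair`. `build_caption_prompt` and its default prefix are hypothetical and only approximate the instruction used in the Captioning section:

```python
def build_caption_prompt(tags, prefix="Make a description using the following words: "):
    """Turn a list of WD tagger tags into the text part of a Qwen-VL query.

    Hypothetical helper: the prefix approximates the card's instruction text.
    """
    # WD tags use underscores in place of spaces (e.g. "long_hair").
    words = [tag.replace("_", " ") for tag in tags]
    return prefix + ", ".join(words)


print(build_caption_prompt(["1girl", "long_hair", "blue_eyes"]))
# Make a description using the following words: 1girl, long hair, blue eyes
```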

LICENSE: Tongyi Qianwen LICENSE

Model Details

  • A LoRA adapter fine-tuned from Qwen-VL-Chat.

Model Description

  • Developed by: cella
  • Model type: LoRA
  • Language(s) (NLP): English
  • License: Tongyi Qianwen LICENSE
  • Fine-tuned from model: Qwen-VL-Chat

Uses
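The loading and captioning snippets in this section read notebook-level settings and a `dataset` object that the card does not define. The stand-in below assumes each image exposes a `path` and a `captions` list of WD tags, which is how the captioning code reads them. All names and values here (`ImageItem`, `Dataset`, the folder and file paths) are hypothetical placeholders, not the author's notebook:

```python
import os
from dataclasses import dataclass, field
from typing import List, Union

# Hypothetical notebook settings; replace with your own values.
OPTION_VLM_METHOD = "qwen_chat_LoRA"  # selects the LoRA loader
OPTION_SKIP_EXISTING = True           # skip images that already have a .txt caption
OPTION_SAVE_TAGS = True               # write each caption to <output_dir_VLM>/<name>.txt
output_dir_VLM = "captions"           # hypothetical output folder


@dataclass
class ImageItem:
    path: str
    # WD tags before captioning; the captioning loop replaces them with a caption string.
    captions: Union[List[str], str] = field(default_factory=list)


@dataclass
class Dataset:
    images: List[ImageItem] = field(default_factory=list)


# One hypothetical image with its WD tags.
dataset = Dataset(images=[ImageItem("images/0001.png", ["1girl", "long_hair"])])

# The save step opens files inside this folder, so it has to exist.
os.makedirs(output_dir_VLM, exist_ok=True)
```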

Model Loading

LoRA_DIR = "/path-to-LoRA-dir"  # local directory containing this adapter

# OPTION_VLM_METHOD is a notebook-level setting; set it to 'qwen_chat_LoRA' to run this cell.
if OPTION_VLM_METHOD == 'qwen_chat_LoRA':
    from peft import AutoPeftModelForCausalLM
    from transformers import AutoTokenizer
    from transformers.generation import GenerationConfig
    import torch
    torch.manual_seed(1234)

    # Note: The default behavior now has injection attack prevention off.
    tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen-VL-Chat", trust_remote_code=True)

    # Load Qwen-VL-Chat with this LoRA adapter applied (device_map="auto" uses the available CUDA device)
    model = AutoPeftModelForCausalLM.from_pretrained(
        LoRA_DIR,  # path to the adapter directory
        device_map="auto",
        trust_remote_code=True
    ).eval()

    # Specify hyperparameters for generation (No need to do this if you are using transformers>=4.32.0)
    model.generation_config = GenerationConfig.from_pretrained("Qwen/Qwen-VL-Chat", trust_remote_code=True)

else:
    print("skipped.")

Captioning
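The captioning code calls `is_ascii`, `is_english`, `is_sufficient_length`, `is_over_length` and `remove_fixed_patterns`. These are defined elsewhere in the original notebook and are not shown in this card. The versions below are hypothetical stand-ins. The word limits (`MIN_WORDS`, `MAX_WORDS`) and the boilerplate patterns removed (`FIXED_PATTERNS`) are assumptions. Judging by how the retry loop negates it, `is_over_length` returns True when the response is *within* the length limit, and the sketch keeps that convention:

```python
import re

# Hypothetical limits; the notebook's real thresholds are not published.
MIN_WORDS = 30
MAX_WORDS = 150

# Hypothetical boilerplate openers to strip from responses.
FIXED_PATTERNS = [
    r"^\s*(?:The|This) image (?:shows|depicts|features)\s*",
]


def is_ascii(text):
    """True if every character is ASCII (rejects e.g. stray CJK output)."""
    return text.isascii()


def is_english(text):
    """True if langdetect classifies the text as English."""
    # Imported lazily; requires `pip install langdetect`.
    from langdetect import detect
    from langdetect.lang_detect_exception import LangDetectException
    try:
        return detect(text) == "en"
    except LangDetectException:  # raised e.g. for empty text
        return False


def is_sufficient_length(text):
    """True if the response has at least MIN_WORDS words."""
    return len(text.split()) >= MIN_WORDS


def is_over_length(text):
    """Despite its name, True when the response is NOT too long."""
    return len(text.split()) <= MAX_WORDS


def remove_fixed_patterns(text):
    """Strip boilerplate openers, then re-capitalize the first letter."""
    for pattern in FIXED_PATTERNS:
        text = re.sub(pattern, "", text, flags=re.IGNORECASE)
    text = text.strip()
    return text[:1].upper() + text[1:]
```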

if OPTION_VLM_METHOD in ('qwen_chat', 'qwen_chat_LoRA'):
    import os
    from PIL import Image
    from IPython.display import clear_output
    import matplotlib.pyplot as plt
    import numpy as np
    from langdetect import detect
    import string
    import re

    # Also relies on notebook-level names not shown in this card: dataset, output_dir_VLM,
    # OPTION_SKIP_EXISTING, OPTION_SAVE_TAGS, and the helpers is_ascii, is_english,
    # is_sufficient_length, is_over_length and remove_fixed_patterns.

    COMMON_QUERY = 'What is in the image? Briefly describe the overall scene, in English'
    MORE_QUERY = 'What is in the image? Describe the overall scene in detail, in English'
    LESS_QUERY = 'What is in the image? Briefly summarize the description, in English'

    for image in dataset.images:
        img_name = os.path.splitext(os.path.basename(image.path))[0]
        caption_file = os.path.join(output_dir_VLM, img_name + '.txt')

        # Skip the image if a .txt file with the same name already exists in the output folder
        if OPTION_SKIP_EXISTING and os.path.exists(caption_file):
            clear_output(True)
            print("skipped: ", image.path)
            continue

        # First attempt: ask the model to write a description from the WD tags
        query = tokenizer.from_list_format([
            {'image': image.path},
            {'text': 'Make a description using the following words: ' + ', '.join(image.captions).replace('_', ' ')},
        ])
        response, history = model.chat(tokenizer, query=query, history=None)

        # ASCII check, language check, length check
        # (is_over_length() returns True when the response is within the length limit)
        retry_count = 0
        while not is_ascii(response) or not is_english(response) or not is_sufficient_length(response) or not is_over_length(response):
            clear_output(True)
            retry_count += 1
            print("Retry count:", retry_count)
            # After 25 retries, accept the response as long as it is pure ASCII
            if retry_count >= 25 and is_ascii(response):
                break
            if not is_sufficient_length(response):
                print("Too short. Retry...")
                query = tokenizer.from_list_format([
                    {'image': image.path},
                    {'text': MORE_QUERY},
                ])
            if not is_over_length(response):
                print("Too long. Retry...")
                query = tokenizer.from_list_format([
                    {'image': image.path},
                    {'text': LESS_QUERY},
                ])
            # Every 5th retry, start a fresh chat with the generic prompt
            if retry_count % 5 == 0:
                history = None
                query = tokenizer.from_list_format([
                    {'image': image.path},
                    {'text': COMMON_QUERY},
                ])
            response, history = model.chat(tokenizer, query=query, history=history)

        response = remove_fixed_patterns(response)

        if OPTION_SAVE_TAGS:
            # Save the caption text
            with open(caption_file, 'w', encoding='utf-8') as file:
                file.write(response)

        image.captions = response

        clear_output(True)

        print("Saved for ", image.path, ": ", response)

        # Display the image
        img = Image.open(image.path)
        plt.imshow(np.asarray(img))
        plt.show()

else:
    print("skipped.")
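The retry loop picks its follow-up prompt based on which check failed. The sketch below pulls that decision out into pure functions so the policy can be checked without loading the model. `next_prompt` and `should_give_up` are hypothetical names. The thresholds (give up after 25 retries, reset the chat on every 5th retry) are taken from the loop:

```python
def next_prompt(retry_count, too_short, too_long):
    """Mirror the captioning loop's choice of follow-up prompt.

    Returns (prompt_name, reset_history), where prompt_name is one of
    "more" (MORE_QUERY), "less" (LESS_QUERY), "common" (COMMON_QUERY)
    or "same" (re-send the previous query unchanged).
    """
    prompt, reset_history = "same", False
    if too_short:
        prompt = "more"
    if too_long:  # checked after too_short, so it wins if both are set
        prompt = "less"
    if retry_count % 5 == 0:  # every 5th retry: fresh chat, generic prompt
        prompt, reset_history = "common", True
    return prompt, reset_history


def should_give_up(retry_count, ascii_ok):
    """After 25 retries, the loop accepts any response that is pure ASCII."""
    return retry_count >= 25 and ascii_ok


print(next_prompt(3, too_short=True, too_long=False))   # ('more', False)
print(next_prompt(10, too_short=True, too_long=False))  # ('common', True)
```

Because `should_give_up` is only true for ASCII output, the original loop never terminates if the model keeps answering with non-ASCII text. An additional hard retry limit would guard against that.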

Framework versions

  • PEFT 0.7.1