metadata
library_name: peft
base_model: Qwen/Qwen-VL-Chat
Model Card for Model ID
- LoRA adapter that converts wdtag (tag list) input into a long caption.
Model Details
- Finetuned.
Model Description
- Developed by: cella
- Model type: LoRA
- Language(s) (NLP): English
- License: Tongyi Qianwen LICENSE
- Finetuned from model [optional]: Qwen-VL-Chat
Uses
Model Load
# Load the Qwen-VL-Chat tokenizer and the LoRA-adapted model.
# OPTION_VLM_METHOD is not defined in this snippet; it is expected to be set
# earlier in the notebook/script.

# Directory that contains the trained LoRA adapter.
LoRA_DIR = "/path-to-LoRA-dir"

if OPTION_VLM_METHOD == 'qwen_chat_LoRA':
    from peft import AutoPeftModelForCausalLM
    from transformers import AutoModelForCausalLM, AutoTokenizer
    from transformers.generation import GenerationConfig
    import torch

    # Seed the torch RNG before loading and generating.
    torch.manual_seed(1234)

    # Note: The default behavior now has injection attack prevention off.
    tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen-VL-Chat", trust_remote_code=True)

    # device_map="auto" leaves device placement to the loader
    # (original note: "use cuda device").
    model = AutoPeftModelForCausalLM.from_pretrained(
        LoRA_DIR,  # path to the output directory
        device_map="auto",
        trust_remote_code=True
    ).eval()

    # Specify hyperparameters for generation (No need to do this if you are using transformers>=4.32.0)
    model.generation_config = GenerationConfig.from_pretrained("Qwen/Qwen-VL-Chat", trust_remote_code=True)
else:
    print("skipped.")
Captioning
# Caption every image in `dataset.images` with the loaded model and optionally
# write each caption to "<output_dir_VLM>/<image name>.txt".
#
# Names used here but not defined in this snippet (expected from earlier cells):
# OPTION_VLM_METHOD, OPTION_SKIP_EXISTING, OPTION_SAVE_TAGS, dataset,
# output_dir_VLM, tokenizer, model, clear_output, plt, np, is_ascii, is_english,
# is_sufficient_length, is_over_length, remove_fixed_patterns.
#
# The load snippet only loads the model for 'qwen_chat_LoRA', so that value is
# accepted here as well; 'qwen_chat' is still accepted for backward compatibility.
if OPTION_VLM_METHOD in ('qwen_chat', 'qwen_chat_LoRA'):
    import os

    from PIL import Image
    from langdetect import detect
    import string
    import re

    # Fallback prompts used when the first answer is rejected.
    # The strings are kept byte-for-byte because they are sent to the model.
    COMMON_QUERY = 'What is in tha image? Briefly describe the overall, in English'
    MORE_QUERY = 'What is in tha image? Describe the overall in detail, in English'
    LESS_QUERY = 'What is in tha image? Briefly summerize the description, in English'

    # Stop retrying after this many attempts, provided the answer is ASCII.
    MAX_RETRIES = 25
    # Every this many retries, drop the chat history and ask the common prompt.
    RESET_EVERY = 5

    def _image_query(image_path, text):
        """Build a query made of one image entry followed by one text entry."""
        return tokenizer.from_list_format([
            {'image': image_path},
            {'text': text},
        ])

    for image in dataset.images:
        img_name = os.path.splitext(os.path.basename(image.path))[0]
        caption_path = os.path.join(output_dir_VLM, img_name + '.txt')

        # Skip if a txt file with the same name already exists in the output folder.
        if OPTION_SKIP_EXISTING and os.path.exists(caption_path):
            clear_output(True)
            print("skipped: ", image.path)
            continue

        # First attempt: ask for a description built from the image's existing tags.
        # NOTE(review): there is no separator between "words" and the first tag;
        # left unchanged because the LoRA may have been trained on this exact
        # prompt - confirm before editing.
        query = _image_query(
            image.path,
            'Make description using following words' + ', '.join(image.captions).replace('_', ' '),
        )
        response, history = model.chat(tokenizer, query=query, history=None)

        # ASCII check, language check, length check.
        # NOTE(review): judging by the "Too long" branch below, is_over_length()
        # appears to return True when the response is NOT too long; the helper is
        # defined elsewhere - confirm.
        retry_count = 0
        while (
            not is_ascii(response)
            or not is_english(response)
            or not is_sufficient_length(response)
            or not is_over_length(response)
        ):
            clear_output(True)
            retry_count += 1
            print("Retry count:", retry_count)

            # Give up once the retry budget is spent, but only with an ASCII answer.
            if retry_count >= MAX_RETRIES and is_ascii(response):
                break

            # The checks are independent on purpose: a later match overrides the
            # query chosen by an earlier one. When none matches, the previous
            # query is sent again.
            if not is_sufficient_length(response):
                print("Too short. Retry...")
                query = _image_query(image.path, MORE_QUERY)
            if not is_over_length(response):
                print("Too long. Retry...")
                query = _image_query(image.path, LESS_QUERY)
            if retry_count % RESET_EVERY == 0:
                history = None
                query = _image_query(image.path, COMMON_QUERY)

            response, history = model.chat(tokenizer, query=query, history=history)

        response = remove_fixed_patterns(response)

        # NOTE(review): the source lost its indentation; everything after the
        # file write is assumed to run for every processed image - confirm.
        if OPTION_SAVE_TAGS:
            # Save the caption next to the other outputs.
            with open(caption_path, 'w', encoding='utf-8') as file:
                file.write(response)

        # Replace the image's captions with the generated text.
        image.captions = response
        clear_output(True)
        print("Saved for ", image.path, ": ", response)

        # Display the image; the context manager closes the file once the pixel
        # data has been copied into an array.
        with Image.open(image.path) as img:
            pixels = np.asarray(img)
        plt.imshow(pixels)
        plt.show()
else:
    print("skipped.")
Framework versions
- PEFT 0.7.1