|
--- |
|
library_name: peft |
|
base_model: Qwen/Qwen-VL-Chat |
|
--- |
|
|
|
# Model Card for Model ID |
|
|
|
<!-- Provide a quick summary of what the model is/does. --> |
|
|
|
- LoRA adapter that converts WD-style tags (wdtag) into a long natural-language caption.
|
|
|
## Model Details |
|
|
|
- Finetuned from Qwen-VL-Chat using a LoRA adapter.
|
|
|
### Model Description |
|
|
|
<!-- Provide a longer summary of what this model is. --> |
|
|
|
- **Developed by:** cella
|
- **Model type:** LoRA |
|
- **Language(s) (NLP):** English
|
- **License:** Tongyi Qianwen LICENSE |
|
- **Finetuned from model [optional]:** Qwen-VL-Chat |
|
|
|
## Uses |
|
|
|
|
|
### Model Load |
|
```python
|
LoRA_DIR = "/path-to-LoRA-dir"  # directory containing the trained LoRA adapter weights

if OPTION_VLM_METHOD == 'qwen_chat_LoRA':
    from peft import AutoPeftModelForCausalLM
    from transformers import AutoModelForCausalLM, AutoTokenizer
    from transformers.generation import GenerationConfig
    import torch

    # Fix the RNG seed so generations are reproducible across runs.
    torch.manual_seed(1234)

    # Note: The default behavior now has injection attack prevention off.
    tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen-VL-Chat", trust_remote_code=True)

    # Load the base model with the LoRA adapter applied; device_map="auto"
    # places the weights on the available CUDA device(s).
    model = AutoPeftModelForCausalLM.from_pretrained(
        LoRA_DIR,  # path to the output directory
        device_map="auto",
        trust_remote_code=True
    ).eval()

    # Specify hyperparameters for generation (No need to do this if you are using transformers>=4.32.0)
    model.generation_config = GenerationConfig.from_pretrained("Qwen/Qwen-VL-Chat", trust_remote_code=True)

else:
    print("skipped.")
|
``` |
|
|
|
### Captioning |
|
```python
|
# Accept both option values so this section also runs after the LoRA model is
# loaded via the "Model Load" section above (which uses 'qwen_chat_LoRA').
if OPTION_VLM_METHOD in ('qwen_chat', 'qwen_chat_LoRA'):
    from PIL import Image
    from langdetect import detect
    import string
    import re

    # Prompts for the initial/fallback query and the too-short / too-long retries.
    COMMON_QUERY = 'What is in the image? Briefly describe the overall, in English'
    MORE_QUERY = 'What is in the image? Describe the overall in detail, in English'
    LESS_QUERY = 'What is in the image? Briefly summarize the description, in English'

    for image in dataset.images:
        # Derive the output file stem from the image filename.
        img_name = os.path.basename(image.path)
        img_name = os.path.splitext(img_name)[0]

        # Skip if a .txt file with the same name already exists in the output folder.
        if OPTION_SKIP_EXISTING and os.path.exists(os.path.join(output_dir_VLM, img_name + '.txt')):
            clear_output(True)
            print("skipped: ", image.path)
            continue

        # First attempt: ask for a caption built from the image's existing tags
        # (underscores in tags replaced with spaces for natural language).
        query = tokenizer.from_list_format([
            {'image': image.path},
            {'text': 'Make description using following words: ' + ', '.join(image.captions).replace('_', ' ')},
        ])
        response, history = model.chat(tokenizer, query=query, history=None)

        # Retry until the response passes the ASCII, language and length checks.
        # NOTE(review): is_over_length appears to return True when the text is
        # NOT over the limit (the retry fires on `not is_over_length`) — confirm
        # the helper's semantics against its definition.
        retry_count = 0
        while not is_ascii(response) or not is_english(response) or not is_sufficient_length(response) or not is_over_length(response):
            clear_output(True)
            retry_count += 1
            print("Retry count:", retry_count)
            # Give up after 25 attempts as long as the text is at least ASCII.
            if retry_count >= 25 and is_ascii(response):
                break
            if not is_sufficient_length(response):
                print("Too short. Retry...")
                query = tokenizer.from_list_format([
                    {'image': image.path},
                    {'text': MORE_QUERY},
                ])
            if not is_over_length(response):
                print("Too long. Retry...")
                query = tokenizer.from_list_format([
                    {'image': image.path},
                    {'text': LESS_QUERY},
                ])
            # Every 5th retry, drop the chat history and fall back to the
            # generic prompt to escape a repeating failure mode.
            if retry_count % 5 == 0:
                history = None
                query = tokenizer.from_list_format([
                    {'image': image.path},
                    {'text': COMMON_QUERY},
                ])
            response, history = model.chat(tokenizer, query=query, history=history)

        # Strip boilerplate phrases the model tends to repeat.
        response = remove_fixed_patterns(response)

        if OPTION_SAVE_TAGS:
            # Save the caption as a sidecar .txt file.
            with open(os.path.join(output_dir_VLM, img_name + '.txt'), 'w') as file:
                file.write(response)

        # Replace the tag list with the generated caption in the in-memory dataset.
        image.captions = response

        clear_output(True)

        print("Saved for ", image.path, ": ", response)

        # Display the image alongside its new caption.
        img = Image.open(image.path)
        plt.imshow(np.asarray(img))
        plt.show()

else:
    print("skipped.")
|
``` |
|
|
|
### Framework versions |
|
|
|
- PEFT 0.7.1 |