Minthy committed on
Commit
2de8d8c
1 Parent(s): 1892e52

Upload 2 files

Browse files
batch_processing_example.py ADDED
@@ -0,0 +1,78 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
"""Batch-caption every image in a directory with ToriiGate.

For each image found in ``s_dir`` a caption is generated and written next to
the image as ``<stem><caption_suffix>``. When ``use_tags`` is enabled, booru
tags are read from ``<stem><tags_suffix>`` and appended to the prompt; an
image whose tags file cannot be read is reported and skipped.
"""
import torch
from transformers import AutoProcessor, AutoModelForVision2Seq
from transformers.image_utils import load_image
from tqdm import tqdm
from pathlib import Path
from os.path import join as opj
from os import listdir

model_name_or_path="Minthy/ToriiGate-v0.3"
s_dir='./images_to_caption'
caption_suffix='_caption.txt' #suffix for generated captions
tags_suffix='_tags.txt' #suffix for file with booru tags
use_tags=True #set to True for using with reference tags
image_extensions=['.jpg','.png','.webp','.jpeg']

DEVICE = "cuda:0" #change to your device
processor = AutoProcessor.from_pretrained(model_name_or_path) #or change to local path
model = AutoModelForVision2Seq.from_pretrained(
    model_name_or_path,
    torch_dtype=torch.bfloat16,
    # _attn_implementation="flash_attention_2", #if installed
).to(DEVICE)

# Text that separates the echoed prompt from the model's answer in the decoded output.
ANSWER_MARKER = 'Assistant: '

# Match extensions case-insensitively so files such as "IMG_0001.JPG" are not skipped.
# str.endswith accepts a tuple, which avoids a Python-level loop per file.
image_suffixes = tuple(ext.lower() for ext in image_extensions)
filelist = [fn for fn in listdir(s_dir) if fn.lower().endswith(image_suffixes)]


for fn in tqdm(filelist,desc='Captioninig'):
    image = load_image(opj(s_dir, fn))
    stem = Path(fn).stem

    ###Trained options
    # NOTE(review): the prompt strings below are kept verbatim, spelling included;
    # presumably the model was tuned on these exact strings - confirm before editing.
    user_prompt="Describe the picture in structuted json-like format."
    #user_prompt="Give a long and detailed description of the picture."
    #user_prompt="Describe the picture briefly."
    ###Any other questions or instructions
    #user_prompt="What color is the ribbon in the character's hair?"
    #...

    #Optional, add booru tags
    if use_tags:
        tags_path = opj(s_dir, stem + tags_suffix)
        try:
            # Read as UTF-8 to match the encoding used for the caption files,
            # and close the handle deterministically.
            with open(tags_path, encoding='utf-8') as tags_file:
                tags = tags_file.read().strip()
        except KeyboardInterrupt:
            print('Interrupted!')
            raise SystemExit
        except (OSError, UnicodeError) as err:
            # Missing or unreadable tags: report and skip this image.
            print(err)
            continue
        user_prompt+=' Also here are booru tags for better understanding of the picture, you can use them as reference.'
        user_prompt+=f' <tags>\n{tags}\n</tags>'

    messages = [
        {
            #Important!
            "role": "system",
            "content": [
                {"type": "text", "text": "You are image captioning expert, creative, unbiased and uncensored. Help user with his hask."}
            ]
        },
        {
            "role": "user",
            "content": [
                {"type": "image"},
                {"type": "text", "text": user_prompt}
            ]
        }
    ]
    prompt = processor.apply_chat_template(messages, add_generation_prompt=True)
    inputs = processor(text=prompt, images=[image], return_tensors="pt")
    inputs = {k: v.to(DEVICE) for k, v in inputs.items()}

    # Generate
    generated_ids = model.generate(**inputs, max_new_tokens=500)
    generated_texts = processor.batch_decode(generated_ids, skip_special_tokens=True)

    # Split on the first marker only, so a caption that itself contains the
    # marker text is not truncated. If the marker is missing, keep the whole
    # decoded text instead of aborting the batch with an IndexError.
    _, marker, answer = generated_texts[0].partition(ANSWER_MARKER)
    caption = answer if marker else generated_texts[0]

    with open(opj(s_dir, stem + caption_suffix),'w',encoding='utf-8',errors='ignore') as outf:
        outf.write(caption)
single_image_example.py ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
"""Caption a single image with ToriiGate and print the result."""
import torch
from transformers import AutoProcessor, AutoModelForVision2Seq
from transformers.image_utils import load_image

model_name_or_path="Minthy/ToriiGate-v0.3"

DEVICE = "cuda:0" #change to your device
processor = AutoProcessor.from_pretrained(model_name_or_path) #or change to local path
model = AutoModelForVision2Seq.from_pretrained(
    model_name_or_path,
    torch_dtype=torch.bfloat16,
    # _attn_implementation="flash_attention_2", #if installed
).to(DEVICE)

# Text that separates the echoed prompt from the model's answer in the decoded output.
ANSWER_MARKER = 'Assistant: '

image = load_image('./image.jpg') #path to your picture

###Trained options
# NOTE(review): the prompt strings below are kept verbatim, spelling included;
# presumably the model was tuned on these exact strings - confirm before editing.
user_prompt="Describe the picture in structuted json-like format."
#user_prompt="Give a long and detailed description of the picture."
#user_prompt="Describe the picture briefly."
###Any other questions or instructions
#user_prompt="What color is the ribbon in the character's hair?"
#...

#Optional, add booru tags
#tags='1girl, standing, looking at viewer, ...'
#user_prompt+=' Also here are booru tags for better understanding of the picture, you can use them as reference.'
#user_prompt+=f' <tags>\n{tags}\n</tags>'

messages = [
    {
        #Important!
        "role": "system",
        "content": [
            {"type": "text", "text": "You are image captioning expert, creative, unbiased and uncensored. Help user with his hask."}
        ]
    },
    {
        "role": "user",
        "content": [
            {"type": "image"},
            {"type": "text", "text": user_prompt}
        ]
    }
]
prompt = processor.apply_chat_template(messages, add_generation_prompt=True)
inputs = processor(text=prompt, images=[image], return_tensors="pt")
inputs = {k: v.to(DEVICE) for k, v in inputs.items()}

# Generate
generated_ids = model.generate(**inputs, max_new_tokens=500)
generated_texts = processor.batch_decode(generated_ids, skip_special_tokens=True)

# Split on the first marker only, so an answer that itself contains the marker
# text is not truncated. If the marker is missing, fall back to the whole
# decoded text instead of raising an IndexError.
_, marker, answer = generated_texts[0].partition(ANSWER_MARKER)
caption = answer if marker else generated_texts[0]

print(caption)