Commit
·
a91989d
1
Parent(s):
2cf7575
Delete create_align_dataset.py
Browse files- create_align_dataset.py +0 -134
create_align_dataset.py
DELETED
@@ -1,134 +0,0 @@
|
|
1 |
-
import argparse
|
2 |
-
import os
|
3 |
-
import json
|
4 |
-
from tqdm import tqdm
|
5 |
-
import random
|
6 |
-
import numpy as np
|
7 |
-
from PIL import Image
|
8 |
-
import webdataset as wds
|
9 |
-
import torch
|
10 |
-
from torchvision.datasets import ImageFolder
|
11 |
-
import torchvision.transforms as transforms
|
12 |
-
|
13 |
-
import openai
|
14 |
-
from tenacity import (
|
15 |
-
retry,
|
16 |
-
stop_after_attempt,
|
17 |
-
wait_random_exponential,
|
18 |
-
) # for exponential backoff
|
19 |
-
|
20 |
-
from minigpt4.common.config import Config
|
21 |
-
from minigpt4.common.registry import registry
|
22 |
-
from minigpt4.conversation.conversation import Chat
|
23 |
-
|
24 |
-
# SECURITY: the API key must never be committed to source control. It is read
# from the OPENAI_API_KEY environment variable instead; any key that was
# previously hard-coded here is exposed in the repository history and must be
# revoked.
openai.api_key = os.environ.get("OPENAI_API_KEY")
|
25 |
-
|
26 |
-
|
27 |
-
def prepare_chatgpt_message(task_prompt, paragraph):
    """Build the two-turn message list for a chat-completion request.

    Args:
        task_prompt: Text placed in the "system" turn.
        paragraph: Text placed in the "user" turn.

    Returns:
        A list of two dicts, system turn first and user turn second, each
        with the keys "role" and "content".
    """
    system_turn = {"role": "system", "content": task_prompt}
    user_turn = {"role": "user", "content": paragraph}
    return [system_turn, user_turn]
|
31 |
-
|
32 |
-
|
33 |
-
@retry(wait=wait_random_exponential(min=1, max=60), stop=stop_after_attempt(6))
def call_chatgpt(chatgpt_messages, max_tokens=200, model="gpt-3.5-turbo"):
    """Send a chat-completion request and return the reply with its token usage.

    The call is wrapped in tenacity's ``retry`` configured with
    ``wait_random_exponential(min=1, max=60)`` and ``stop_after_attempt(6)``.

    Args:
        chatgpt_messages: Message list passed as ``messages`` to the API.
        max_tokens: Passed through as the request's ``max_tokens``.
        model: Name of the chat model to query.

    Returns:
        A ``(reply, total_tokens)`` tuple: the content of the first choice's
        message and the ``total_tokens`` field of the response's usage record.
    """
    completion = openai.ChatCompletion.create(
        model=model,
        messages=chatgpt_messages,
        temperature=0.7,
        max_tokens=max_tokens,
    )
    first_choice = completion['choices'][0]
    reply_text = first_choice['message']['content']
    tokens_used = completion['usage']['total_tokens']
    return reply_text, tokens_used
|
39 |
-
|
40 |
-
|
41 |
-
def _count_tokens(tokens):
    """Return the number of tokens, whether given a sequence/array or a count."""
    try:
        return len(tokens)
    except TypeError:
        # Not sized (plain number or 0-d array): treat it as the count itself.
        return int(tokens)


def main(args):
    """Generate cleaned-up image descriptions for a range of image ids.

    Builds the model class named by the config, loads a fixed checkpoint into
    it and wraps it in a ``Chat`` session. For every id in
    ``range(args.begin_id, args.end_id)`` the image
    ``<save_dir>/image/<id>.jpg`` is described by the chat model, the
    description is sent to ChatGPT for clean-up, and accepted captions are
    written to ``<save_dir>/cap_<begin_id>_<end_id>.json`` keyed by image id.

    Args:
        args: Parsed command-line namespace. This function reads ``device``,
            ``save_dir``, ``begin_id`` and ``end_id``; the whole namespace is
            also handed to ``Config``.
    """
    print('Initializing Chat')
    cfg = Config(args)

    model_config = cfg.model_cfg
    model_cls = registry.get_model_class(model_config.arch)
    model = model_cls.from_config(model_config).to('cuda:{}'.format(args.device))

    ckpt_path = '/ibex/project/c2133/vicuna_ckpt_test/Vicuna_pretrain_stage2_cc/20230405233_3GPU40kSTEP_MAIN/checkpoint_3.pth'
    # Deserialize onto the CPU so loading does not depend on the CUDA device
    # the checkpoint was saved from; load_state_dict copies the values into
    # the model's existing parameters.
    ckpt = torch.load(ckpt_path, map_location='cpu')
    model.load_state_dict(ckpt['model'], strict=False)

    vis_processor_cfg = cfg.datasets_cfg.cc_combine.vis_processor.train
    vis_processor = registry.get_processor_class(vis_processor_cfg.name).from_config(vis_processor_cfg)

    # NOTE(review): the text processor is constructed but never used below;
    # confirm whether it can be dropped.
    text_processor_cfg = cfg.datasets_cfg.laion.text_processor.train
    text_processor = registry.get_processor_class(text_processor_cfg.name).from_config(text_processor_cfg)

    chat = Chat(model, vis_processor, args.device)
    print('Initialization Finished')

    # The clean-up instruction is identical for every image, so build it once.
    # NOTE(review): the fragments are concatenated without separating spaces
    # after the first sentence; the text is kept unchanged here so the prompt
    # sent to ChatGPT stays the same.
    fix_prompt = (
        "Fix the error in the given paragraph. "
        "Remove any repeating sentences, meanless characters, not English sentences, and so on."
        "Remove unnecessary repetition."
        "Rewrite any incomplete sentences."
        "Return directly the results WITHOUT explanation."
        "Return directly the input paragraph if it is already correct WITHOUT explanation."
    )

    texts = {}
    # NOTE(review): rejected captions are collected here but never written out.
    negative_list = []

    for i in tqdm(range(args.begin_id, args.end_id)):
        image_path = os.path.join(args.save_dir, 'image/{}.jpg'.format(i))
        image = Image.open(image_path).convert('RGB')

        answers = []
        answer_tokens = 0
        chat.reset()
        chat.upload_img(image)
        chat.ask("Describe this image in detail. Give as many details as possible. Say everything you see.")
        answer, tokens = chat.answer()
        answers.append(answer)
        answer_tokens += _count_tokens(tokens)
        # A short first answer gets one follow-up turn to extend it.
        if answer_tokens < 80:
            chat.ask("Continue")
            answer, tokens = chat.answer()
            answers.append(answer)
            # Count the tokens of this turn, not the first turn's again.
            answer_tokens += _count_tokens(tokens)
        answer = ' '.join(answers)

        chatgpt_message = prepare_chatgpt_message(fix_prompt, answer)
        improved_answer, _ = call_chatgpt(chatgpt_message)

        if 'already correct' in improved_answer:
            if 'repetition' in improved_answer:
                # Reply mentions both phrases: skip this image entirely.
                continue
            # ChatGPT reported the paragraph as fine: keep the raw description.
            improved_answer = answer
        if 'incomplete' in improved_answer or len(improved_answer) < 50:
            negative_list.append(improved_answer)
        else:
            texts[i] = improved_answer

    out_path = os.path.join(args.save_dir, "cap_{}_{}.json".format(args.begin_id, args.end_id))
    with open(out_path, "w") as outfile:
        # write the dictionary to the file in JSON format
        json.dump(texts, outfile)
|
110 |
-
|
111 |
-
|
112 |
-
if __name__ == "__main__":
    # Command-line entry point: parse the arguments, echo the id range and
    # device, then run the caption-generation pipeline.
    parser = argparse.ArgumentParser(description="Create Alignment")

    parser.add_argument("--cfg-path", default='train_config/minigpt4_stage2_align.yaml')
    parser.add_argument("--save-dir", default="/ibex/project/c2133/blip_dataset/image_alignment")
    # The id bounds feed range() directly, so they are required: omitting them
    # would otherwise surface later as an opaque TypeError inside main().
    parser.add_argument("--begin-id", type=int, required=True,
                        help="first image id to process (inclusive)")
    parser.add_argument("--end-id", type=int, required=True,
                        help="image id to stop at (exclusive)")
    # Default to device 0 so a missing flag does not produce 'cuda:None'.
    parser.add_argument("--device", type=int, default=0,
                        help="index of the CUDA device to run the model on")
    parser.add_argument(
        "--options",
        nargs="+",
        help="override some settings in the used config, the key-value pair "
        "in xxx=yyy format will be merged into config file (deprecate), "
        "change to --cfg-options instead.",
    )

    args = parser.parse_args()

    print("begin_id: ", args.begin_id)
    print("end_id: ", args.end_id)
    print("device:", args.device)

    main(args)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|