Vision-CAIR committed on
Commit
a91989d
·
1 Parent(s): 2cf7575

Delete create_align_dataset.py

Browse files
Files changed (1) hide show
  1. create_align_dataset.py +0 -134
create_align_dataset.py DELETED
@@ -1,134 +0,0 @@
1
- import argparse
2
- import os
3
- import json
4
- from tqdm import tqdm
5
- import random
6
- import numpy as np
7
- from PIL import Image
8
- import webdataset as wds
9
- import torch
10
- from torchvision.datasets import ImageFolder
11
- import torchvision.transforms as transforms
12
-
13
- import openai
14
- from tenacity import (
15
- retry,
16
- stop_after_attempt,
17
- wait_random_exponential,
18
- ) # for exponential backoff
19
-
20
- from minigpt4.common.config import Config
21
- from minigpt4.common.registry import registry
22
- from minigpt4.conversation.conversation import Chat
23
-
24
# SECURITY: never commit API keys to source control. The key is read from the
# OPENAI_API_KEY environment variable instead; any key that was previously
# hard-coded here must be considered leaked and should be revoked.
openai.api_key = os.environ.get("OPENAI_API_KEY")
25
-
26
-
27
def prepare_chatgpt_message(task_prompt, paragraph):
    """Build the two-turn chat payload sent to the ChatGPT API.

    Args:
        task_prompt: Instruction text placed in the ``system`` turn.
        paragraph: Text placed in the ``user`` turn.

    Returns:
        A list with the system message first and the user message second,
        each a dict with ``role`` and ``content`` keys.
    """
    system_turn = {"role": "system", "content": task_prompt}
    user_turn = {"role": "user", "content": paragraph}
    return [system_turn, user_turn]
31
-
32
-
33
@retry(wait=wait_random_exponential(min=1, max=60), stop=stop_after_attempt(6))
def call_chatgpt(chatgpt_messages, max_tokens=200, model="gpt-3.5-turbo"):
    """Send a chat request to OpenAI and return the reply with its token usage.

    The call is wrapped in a tenacity ``retry`` configured with randomized
    exponential waits (1 to 60) and a stop after 6 attempts.

    Args:
        chatgpt_messages: Message list passed as ``messages`` to the API.
        max_tokens: Passed as ``max_tokens`` to the API.
        model: Name of the chat model to query.

    Returns:
        A ``(reply, total_tokens)`` tuple: the content of the first choice's
        message and the ``total_tokens`` value from the response's usage.
    """
    completion = openai.ChatCompletion.create(
        model=model,
        messages=chatgpt_messages,
        temperature=0.7,
        max_tokens=max_tokens,
    )
    first_choice = completion['choices'][0]
    reply_text = first_choice['message']['content']
    used_tokens = completion['usage']['total_tokens']
    return reply_text, used_tokens
39
-
40
-
41
def _token_count(tokens):
    """Return the number of tokens, given either a sized sequence or a plain count."""
    try:
        return len(tokens)
    except TypeError:
        # Not sized (e.g. a plain number): treat the value itself as the count.
        return int(tokens)


def main(args):
    """Describe a range of images with the chat model and clean the text via ChatGPT.

    For every index ``i`` in ``range(args.begin_id, args.end_id)`` the image
    ``<save_dir>/image/<i>.jpg`` is opened, the chat model is asked for a
    detailed description (with one "Continue" follow-up when the first answer
    has fewer than 80 tokens), and the joined answer is sent to ChatGPT with a
    fixed clean-up prompt. Accepted captions are collected per image index and
    written to ``<save_dir>/cap_<begin_id>_<end_id>.json``.

    Args:
        args: Parsed command-line namespace. This function reads ``device``,
            ``begin_id``, ``end_id`` and ``save_dir`` from it and also passes
            the whole namespace to ``Config``.
    """
    print('Initializing Chat')
    cfg = Config(args)

    model_config = cfg.model_cfg
    model_cls = registry.get_model_class(model_config.arch)
    model = model_cls.from_config(model_config).to('cuda:{}'.format(args.device))

    ckpt_path = '/ibex/project/c2133/vicuna_ckpt_test/Vicuna_pretrain_stage2_cc/20230405233_3GPU40kSTEP_MAIN/checkpoint_3.pth'
    ckpt = torch.load(ckpt_path)
    # strict=False: keys missing from or extra in the checkpoint are tolerated.
    model.load_state_dict(ckpt['model'], strict=False)

    vis_processor_cfg = cfg.datasets_cfg.cc_combine.vis_processor.train
    vis_processor = registry.get_processor_class(vis_processor_cfg.name).from_config(vis_processor_cfg)

    # NOTE(review): the text processor is built but not used further in this function.
    text_processor_cfg = cfg.datasets_cfg.laion.text_processor.train
    text_processor = registry.get_processor_class(text_processor_cfg.name).from_config(text_processor_cfg)

    chat = Chat(model, vis_processor, args.device)
    print('Initialization Finished')

    # The clean-up instruction is the same for every image, so build it once.
    fix_prompt = \
        "Fix the error in the given paragraph. " \
        "Remove any repeating sentences, meanless characters, not English sentences, and so on." \
        "Remove unnecessary repetition." \
        "Rewrite any incomplete sentences." \
        "Return directly the results WITHOUT explanation." \
        "Return directly the input paragraph if it is already correct WITHOUT explanation."

    texts = {}
    # Rejected captions are collected here; they are not written out by this function.
    negative_list = []

    for i in tqdm(range(args.begin_id, args.end_id)):
        image = Image.open(os.path.join(args.save_dir, 'image/{}.jpg'.format(i))).convert('RGB')

        answers = []
        chat.reset()
        chat.upload_img(image)
        chat.ask("Describe this image in detail. Give as many details as possible. Say everything you see.")
        answer, tokens = chat.answer()
        answers.append(answer)
        # Count tokens in a way that works whether chat.answer() hands back a
        # token sequence or a plain number (the previous code mixed an int
        # accumulator with len(), which only works for array-like tokens).
        if _token_count(tokens) < 80:
            # Short first answer: ask the model to keep going once.
            chat.ask("Continue")
            answer, _ = chat.answer()
            answers.append(answer)
        answer = ' '.join(answers)

        chatgpt_message = prepare_chatgpt_message(fix_prompt, answer)
        improved_answer, num_token = call_chatgpt(chatgpt_message)

        if 'already correct' in improved_answer:
            if 'repetition' in improved_answer:
                # ChatGPT replied with commentary mentioning repetition; skip this image.
                continue
            # ChatGPT said the paragraph was fine, so keep the model's own text.
            improved_answer = answer
        if 'incomplete' in improved_answer or len(improved_answer) < 50:
            negative_list.append(improved_answer)
        else:
            texts[i] = improved_answer

    with open(os.path.join(args.save_dir, "cap_{}_{}.json".format(args.begin_id, args.end_id)), "w") as outfile:
        # write the dictionary to the file in JSON format
        json.dump(texts, outfile)
110
-
111
-
112
if __name__ == "__main__":
    # Command-line entry point: parse the options, echo the id range and
    # device, then hand the parsed namespace to main().
    cli = argparse.ArgumentParser(description="Create Alignment")

    cli.add_argument("--cfg-path", default='train_config/minigpt4_stage2_align.yaml')
    cli.add_argument("--save-dir", default="/ibex/project/c2133/blip_dataset/image_alignment")
    # The three integer options share the same shape (no default is given).
    for int_flag in ("--begin-id", "--end-id", "--device"):
        cli.add_argument(int_flag, type=int)
    options_help = (
        "override some settings in the used config, the key-value pair "
        "in xxx=yyy format will be merged into config file (deprecate), "
        "change to --cfg-options instead."
    )
    cli.add_argument("--options", nargs="+", help=options_help)

    args = cli.parse_args()

    print("begin_id: ", args.begin_id)
    print("end_id: ", args.end_id)
    print("device:", args.device)

    main(args)