DDingcheol committed on
Commit
a775df9
·
1 Parent(s): 5466d27

Upload app.py.py

Files changed (1)
  1. app.py.py +211 -0
app.py.py ADDED
@@ -0,0 +1,211 @@
+ # -*- coding: utf-8 -*-
+ """Untitled35.ipynb
+
+ Automatically generated by Colaboratory.
+
+ Original file is located at
+ https://colab.research.google.com/drive/1o8BEsLXWGF91Q1MOvzj5ZRaEHgUp-kOM
+
+ # 0. Download and import the required modules
+ """
+
+ !pip install datasets
+ !pip install huggingface_hub
+ !python -c "from huggingface_hub.hf_api import HfFolder; HfFolder.save_token('hf_...')"  # replace with your own token; never commit a real one
+ !pip install "accelerate>=0.20.1"
+ !pip install accelerate -U
+
+ import torch
+ from transformers import BertTokenizerFast, BertForQuestionAnswering, Trainer, TrainingArguments
+ from datasets import load_dataset, Dataset
+ from collections import defaultdict
+
+ """# 1. Load the data"""
+
+ dataset_load = load_dataset('Multimodal-Fatima/OK-VQA_train')  # load Multimodal-Fatima/OK-VQA_train
+ train_dataset = dataset_load['train'].select(range(300))  # take 200-300 examples -> the author uses 300
+
+ """### 1-1. Check the result"""
+
+ train_dataset
+
+ """# 2. Drop unnecessary features"""
+
+ selected_features = ['image', 'answers', 'question']
+ selected_dataset = Dataset.from_dict({feature: train_dataset[feature] for feature in selected_features})
+
+ """### 2-1. Check the result"""
+
+ selected_dataset
+
+ """# 3. Soft encoding (label encoding)"""
+
+ # Build a dictionary that maps each answer to a unique ID
+ answers_to_id = defaultdict(lambda: len(answers_to_id))
+ selected_dataset = selected_dataset.map(lambda ex: {'answers': [answers_to_id[ans] for ans in ex['answers']],
+                                                     'question': ex['question'],
+                                                     'image': ex['image']})
+
+ # Build the reverse dictionary that maps IDs back to answers
+ id_to_answers = {v: k for k, v in answers_to_id.items()}
+
+ # Build the dictionary used to map IDs to labels
+ id_to_labels = {k: ex['answers'] for k, ex in enumerate(selected_dataset)}
+
+ # Convert the ID-mapped 'answers' back into labels
+ selected_dataset = selected_dataset.map(lambda ex: {'answers': id_to_labels.get(ex['answers'][0]),
+                                                     'question': ex['question'],
+                                                     'image': ex['image']})
+ # Flatten the data
+ flattened_features = []
+
+ # Flatten each example and append it to flattened_features
+ for ex in selected_dataset:
+     flattened_example = {
+         'answers': ex['answers'],
+         'question': ex['question'],
+         'image': ex['image'],
+     }
+     flattened_features.append(flattened_example)
+
+ """### 3-1. Check the result"""
+
+ selected_dataset
+
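+ # A minimal sketch of how the defaultdict above hands out IDs: the first time
+ # an answer string is seen it receives the next free integer, and repeats
+ # reuse the same ID. The answer strings here are illustrative only.
+ demo_ids = defaultdict(lambda: len(demo_ids))
+ print([demo_ids[a] for a in ['umbrella', 'rain', 'umbrella']])  # -> [0, 1, 0]
+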
+ """# 4. Load the model"""
+
+ ## Load the model
+ from huggingface_hub import notebook_login
+ notebook_login()  # notebook_login() takes no token argument; it prompts for the token interactively
+
+ # Use a pipeline as a high-level helper
+ from transformers import pipeline
+ pipe = pipeline("visual-question-answering", model="microsoft/git-base-vqav2")
+
+ # Load model directly
+ from transformers import AutoProcessor, AutoModelForCausalLM
+
+ processor = AutoProcessor.from_pretrained("microsoft/git-base-vqav2")
+ model = AutoModelForCausalLM.from_pretrained("microsoft/git-base-vqav2")
+ # Push the model to your namespace with the name "my-finetuned-bert".
+ model.push_to_hub("my-finetuned-bert")
+
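+ # A hedged usage sketch for the pipeline built above: a visual-question-
+ # answering pipeline takes an image plus a question and returns scored
+ # answers. The inputs are taken from the first dataset row; whether this
+ # checkpoint works in the pipeline, and the exact output format, depend on
+ # the installed transformers version.
+ example = selected_dataset[0]
+ print(pipe(image=example['image'], question=example['question']))
+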
+ """# 5. Data preprocessing"""
+
+ # Load the BERT tokenizer
+ tokenizer = BertTokenizerFast.from_pretrained('bert-base-multilingual-cased')
+
+ # Load the dataset
+ ok_vqa_dataset = load_dataset("Multimodal-Fatima/OK-VQA_train")
+
+ # Select only the first 300 examples
+ ok_vqa_dataset = ok_vqa_dataset['train'].select(range(300))
+
+ # Define the preprocessing function
+ def preprocess_function(examples):
+     # Tokenize the questions
+     tokenized_inputs = tokenizer(examples['question'], truncation=True, padding=True)
+
+     # Fill 'pixel_values' and 'pixel_mask' with one entry per example in the batch
+     batch_size = len(examples['question'])
+     examples['pixel_values'] = [(4, 3, 244, 244)] * batch_size  # placeholder; replace with the actual pixel values
+     examples['pixel_mask'] = [1] * batch_size  # placeholder; replace with the actual pixel masks
+
+     return {
+         'input_ids': tokenized_inputs['input_ids'],
+         'attention_mask': tokenized_inputs['attention_mask'],
+         'pixel_values': examples['pixel_values'],
+         'pixel_mask': examples['pixel_mask'],
+         'labels': [[label] for label in examples['answers']]  # wrap each answer so 'labels' is a 2-D array
+     }
+
+ # Apply the preprocessing to the dataset
+ ok_vqa_dataset = ok_vqa_dataset.map(preprocess_function, batched=True)
+
+ # Tidy up the features of 'ok_vqa_dataset'
+ ok_vqa_dataset.set_format(type='torch', columns=['input_ids', 'attention_mask', 'pixel_values', 'pixel_mask', 'labels'])
+
+ # Working directly on ok_vqa_dataset was awkward, so rebuild it as a fresh new_ok_vqa_dataset
+ new_ok_vqa_dataset = Dataset.from_dict({
+     'input_ids': ok_vqa_dataset['input_ids'],
+     'attention_mask': ok_vqa_dataset['attention_mask'],
+     'pixel_values': ok_vqa_dataset['pixel_values'],
+     'pixel_mask': ok_vqa_dataset['pixel_mask'],
+     'labels': ok_vqa_dataset['labels']
+ })
+
+ """### 5-1. Check the result"""
+
+ new_ok_vqa_dataset
+
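+ # The placeholder tuples above stand in for real image tensors. A minimal
+ # sketch of producing real ones, assuming the GIT processor loaded in
+ # section 4 is still in scope: it resizes and normalizes a PIL image into a
+ # pixel_values tensor.
+ real_pixel_values = processor(images=selected_dataset[0]['image'],
+                               return_tensors='pt').pixel_values
+ print(real_pixel_values.shape)  # e.g. torch.Size([1, 3, 224, 224])
+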
+ """# 6. Create batches and initialize the model"""
+
+ from transformers import BertForSequenceClassification, BertTokenizer
+
+ # Set the number of output labels
+ num_labels = len(id_to_labels)  # the number of labels equals the number of entries mapped from the IDs
+
+ # Initialize the model and load the weights
+ model_name = 'microsoft/git-base-vqav2'  # name of the model to use
+ # num_labels must be passed at load time so the classification head has the right size;
+ # weights in the GIT checkpoint that do not fit BertForSequenceClassification are reinitialized
+ model = BertForSequenceClassification.from_pretrained(model_name, num_labels=num_labels)
+
+ # Mapping from label IDs to labels
+ id_to_labels = {}
+
+ for k, ex in enumerate(selected_dataset):
+     if ex['answers'] is not None and len(ex['answers']) > 0:
+         id_to_labels[k] = ex['answers'][0]
+
+ label_to_id = {v: k for k, v in id_to_labels.items()}
+
+ # Convert a predicted ID into its label
+ def id_to_label_fn(pred_id):
+     return id_to_labels[pred_id]
+
+ # Convert an actual label into an ID matching the model's output format
+ def label_to_id_fn(label):
+     return label_to_id[label]
+
+ # Input sentence to run a prediction on
+ input_text = "Your input text goes here..."
+
+ # Tokenize the input sentence into the form the model expects
+ tokenizer = BertTokenizer.from_pretrained(model_name)
+ encoded_input = tokenizer(input_text, return_tensors='pt')
+
+ # Feed the input to the model and run the prediction
+ with torch.no_grad():
+     outputs = model(**encoded_input)
+
+ # Take the label ID with the highest probability from the prediction
+ predicted_label_id = torch.argmax(outputs.logits).item()
+
+ # Convert the predicted label ID into its label
+ predicted_label = id_to_label_fn(predicted_label_id)
+
+ """### 6-1. Check the result"""
+
+ print("Predicted Label:", predicted_label)
+
+ """# 7. Finetuning"""
+
+ # Configure the TrainingArguments
+ training_args = TrainingArguments(
+     output_dir='./results',            # model output directory
+     num_train_epochs=20,               # number of training epochs
+     per_device_train_batch_size=4,     # batch size
+     logging_steps=500,                 # logging interval
+ )
+
+ # Initialize the Trainer
+ trainer = Trainer(
+     model=model,                       # model to train
+     args=training_args,                # TrainingArguments
+     train_dataset=new_ok_vqa_dataset   # training dataset
+ )
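+
+ # The script stops after constructing the Trainer, so the training call below
+ # is a minimal sketch: it assumes the placeholder pixel values above have been
+ # replaced with real image tensors the model can actually consume.
+ trainer.train()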
+
+ """### 7-1. Check the result"""