Training in progress, step 1000
- adapter_config.json +3 -3
- adapter_model.safetensors +1 -1
- config.json +1 -2
- finetune_phi3_vision.py +151 -120
- idefics2/adapter_model.safetensors +1 -1
- idefics2/training_args.bin +1 -1
- inference.py +70 -10
- model.safetensors +1 -1
- trainer_lora.py +8 -5
- training_args.bin +1 -1
adapter_config.json CHANGED
@@ -23,11 +23,11 @@
     "rank_pattern": {},
     "revision": null,
     "target_modules": [
-        "
+        "output.dense",
+        "intermediate.dense",
         "value",
         "key",
-        "
-        "intermediate.dense"
+        "query"
     ],
     "task_type": null,
     "use_dora": false,
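Note: the hunk above only changes the LoRA target modules. Below is a minimal peft sketch that would produce the new target_modules list; the rank, alpha, and dropout values are assumptions, since they sit outside this hunk.

# Hypothetical sketch: only target_modules is taken from the hunk above;
# r, lora_alpha, and lora_dropout are placeholder assumptions.
from peft import LoraConfig

lora_config = LoraConfig(
    r=8,               # assumed rank, not shown in this diff
    lora_alpha=16,     # assumed scaling factor, not shown in this diff
    lora_dropout=0.1,  # assumed dropout, not shown in this diff
    target_modules=["output.dense", "intermediate.dense", "value", "key", "query"],
)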
adapter_model.safetensors CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
+oid sha256:ace69a40479cb7818005942179ac01520eee1b042f570b77c58aa179b9f20d6d
 size 10637752
config.json CHANGED
@@ -173,6 +173,5 @@
     "processor_class": "TrOCRProcessor",
     "tie_word_embeddings": false,
     "torch_dtype": "float32",
-    "transformers_version": "4.
-    "vocab_size": 50265
+    "transformers_version": "4.44.2"
 }
finetune_phi3_vision.py CHANGED
@@ -8,10 +8,59 @@ from transformers import AutoProcessor, BitsAndBytesConfig
 from transformers import AutoModelForCausalLM, AutoModelForVision2Seq
 from datetime import datetime
 import evaluate
+from transformers import TrainingArguments, Trainer, Seq2SeqTrainer, Seq2SeqTrainingArguments
+from sklearn.model_selection import train_test_split
+
+import random
+
+class MyDataCollator:
+    def __init__(self, processor):
+        self.processor = processor
+        self.image_token_id = processor.tokenizer.additional_special_tokens_ids[
+            processor.tokenizer.additional_special_tokens.index("<image>")
+        ]
+
+    def __call__(self, examples):
+        texts = []
+        images = []
+        for example in examples:
+            image = example["image"]
+            # print(example["query"])
+            question = example["query"]
+            answer = example["answers"]
+            messages = [
+                {
+                    "role": "user",
+                    "content": [
+                        {"type": "text", "text": "OCR the text in the image."},
+                        {"type": "image"},
+                        {"type": "text", "text": question}
+                    ]
+                },
+                {
+                    "role": "assistant",
+                    "content": [
+                        {"type": "text", "text": answer}
+                    ]
+                }
+            ]
+            text = processor.apply_chat_template(messages, add_generation_prompt=False)
+            texts.append(text.strip())
+            images.append([image])
+
+        batch = processor(text=texts, images=images, return_tensors="pt", padding=True)
+
+        labels = batch["input_ids"].clone()
+        # labels[labels == processor.tokenizer.pad_token_id] = self.image_token_id
+        batch["labels"] = labels
+
+        return batch
+
 # Define train and test size.
 TRAIN_SAMPLES = 1000
 TEST_SAMPLES = 200
 TEST_SIZE = 0.166 #
+samp_list = [1, 15000, 30000, 45000, 60000, 70000]

 # Define the directory containing the images.
 df_path = "/mnt/data1/Datasets/AlphaPen/" + "training_data.csv"
@@ -19,9 +68,11 @@ df = pd.read_csv(df_path)
 df.dropna(inplace=True)
 df["id"] = range(df.shape[0])
 df["query"] = "What is shown in this image?"
+train_df, test_df = train_test_split(df, test_size=0.02, random_state=0)

 root_dir = "/mnt/data1/Datasets/OCR/Alphapen/clean_data/final_cropped_rotated_"
-
+image_paths_train = [root_dir + img for img in train_df.filename]
+image_paths_test = [root_dir + img for img in test_df.filename]

 # New batch
 df_path_2 = "/mnt/data1/Datasets/AlphaPen/" + "training_b2.csv"
@@ -29,38 +80,44 @@ df_2 = pd.read_csv(df_path_2)
 df_2.dropna(inplace=True)
 df_2["id"] = range(df_2.shape[0])
 df_2["query"] = "What is shown in this image?"
+train_df_b2, test_df_b2 = train_test_split(df_2, test_size=0.01, random_state=0)

 root_dir_2 = "/mnt/data1/Datasets/OCR/Alphapen/DataBatch2/clean_data/cropped_data/cropped_"
-
-
-
-
-
+image_paths_2_train = [root_dir_2 + img for img in train_df_b2.filename]
+image_paths_2_test = [root_dir_2 + img for img in test_df_b2.filename]
+
+
+ids_test = range(test_df.shape[0] + test_df_b2.shape[0])
+queries_test = test_df['query'].tolist() + test_df_b2['query'].tolist()
+answers_test = test_df['text'].tolist() + test_df_b2['text'].tolist()

 # Create the dataset dictionary.
-
-
-
-    '
-    '
+
+
+eval_dataset_dict = {
+    'id': ids_test,
+    'image': image_paths_test + image_paths_2_test,
+    'query': queries_test,
+    'answers': answers_test
 }

 # Create the dataset.
-
+
+eval_dataset = Dataset.from_dict(eval_dataset_dict)

 # Cast the 'image' column to Image type.
-
+
+eval_dataset = eval_dataset.cast_column("image", Image())

 # Split the dataset into train and test.
-split_dataset = dataset.train_test_split(test_size=TEST_SIZE, shuffle=False)
+# split_dataset = dataset.train_test_split(test_size=TEST_SIZE, shuffle=False)

-train_dataset = split_dataset["train"]
-eval_dataset = split_dataset["test"]
-print(len(
+# train_dataset = split_dataset["train"]
+# eval_dataset = split_dataset["test"]
+print(len(eval_dataset))
 # Push the dataset on Hugging Face Hub.
 # split_dataset.push_to_hub("NSTiwari/DocumentIDEFICS_QA")

-os.environ["WANDB_PROJECT"]="Alphapen"

 # Define model ID
 # model_id = "microsoft/Phi-3-vision-128k-instruct"
@@ -121,112 +178,86 @@ else:



-import random
-
-class MyDataCollator:
-    def __init__(self, processor):
-        self.processor = processor
-        self.image_token_id = processor.tokenizer.additional_special_tokens_ids[
-            processor.tokenizer.additional_special_tokens.index("<image>")
-        ]
-
-    def __call__(self, examples):
-        texts = []
-        images = []
-        for example in examples:
-            image = example["image"]
-            # print(example["query"])
-            question = example["query"]
-            answer = example["answers"]
-            messages = [
-                {
-                    "role": "user",
-                    "content": [
-                        {"type": "text", "text": "OCR the text in the image."},
-                        {"type": "image"},
-                        {"type": "text", "text": question}
-                    ]
-                },
-                {
-                    "role": "assistant",
-                    "content": [
-                        {"type": "text", "text": answer}
-                    ]
-                }
-            ]
-            text = processor.apply_chat_template(messages, add_generation_prompt=False)
-            texts.append(text.strip())
-            images.append([image])
-
-        batch = processor(text=texts, images=images, return_tensors="pt", padding=True)
-
-        labels = batch["input_ids"].clone()
-        # labels[labels == processor.tokenizer.pad_token_id] = self.image_token_id
-        batch["labels"] = labels

-        return batch

 data_collator = MyDataCollator(processor)

-from transformers import TrainingArguments, Trainer, Seq2SeqTrainer, Seq2SeqTrainingArguments

-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+for samp in samp_list:
+    os.environ["WANDB_PROJECT"]="Alphapen"
+    # Create a list of other columns such as id, query, and answer.
+    ids_train = range(train_df.shape[0] + train_df_b2.shape[0])
+    queries_train = train_df['query'].tolist() + train_df_b2['query'].tolist()
+    answers_train = train_df['text'].tolist() + train_df_b2['text'].tolist()
+
+    train_dataset_dict = {
+        'id': ids_train,
+        'image': image_paths_train + image_paths_2_train,
+        'query': queries_train,
+        'answers': answers_train
+    }
+
+    train_dataset = Dataset.from_dict(train_dataset_dict)
+    train_dataset = train_dataset.cast_column("image", Image())
+
+    training_args = Seq2SeqTrainingArguments(
+        predict_with_generate=True,
+        output_dir = "idefics2",
+        learning_rate = 2e-4,
+        fp16 = True,
+        per_device_train_batch_size = 8,
+        per_device_eval_batch_size = 8,
+        gradient_accumulation_steps = 2,
+        dataloader_pin_memory = False,
+        save_total_limit = 3,
+        eval_strategy ="steps",
+        save_strategy = "steps",
+        eval_steps = 500,
+        save_steps = 1000,
+        max_steps = 5000,
+        logging_steps = 10,
+        remove_unused_columns = False,
+        push_to_hub=True,
+        label_names = ["labels"],
+        load_best_model_at_end = False,
+        report_to = "wandb",
+        optim = "paged_adamw_8bit",
+        # run_name=f"idefics2-vision-LoRA-{datetime.now().strftime('%Y-%m-%d-%H-%M-%s')}",
+        run_name="idefics2-vision-LoRA-" + str(samp),
+        hub_model_id="hadrakey/alphapen_idefics2_" + str(samp),
+    )

-def compute_metrics(pred):
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-trainer = Seq2SeqTrainer(
-
-
-
-
-
-
-)
+    def compute_metrics(pred):
+        # accuracy_metric = evaluate.load("precision")
+        cer_metric = evaluate.load("cer")
+
+        labels_ids = pred.label_ids
+        pred_ids = pred.predictions
+        # print(pred_ids)
+        # print(labels_ids)
+        # max_length = max(pred_ids.shape[1], labels_ids.shape[1])
+        # generated_texts = processor.batch_decode(generated_ids[:, inputs["input_ids"].size(1):], skip_special_tokens=True)
+        pred_str = processor.batch_decode(pred_ids, skip_special_tokens=True)
+        pred_str = [word.lower() for word in pred_str]
+        # print(pred_str)
+        # pred_str = processor.batch_decode(pred_ids, skip_special_tokens=True)
+        labels_ids[labels_ids == -100] = processor.tokenizer.pad_token_id
+        label_str = processor.batch_decode(labels_ids, skip_special_tokens=True)
+        label_str = [word.lower() for word in label_str]
+        # print(label_str)
+        cer = cer_metric.compute(predictions=pred_str, references=label_str)
+        # accuracy = accuracy_metric.compute(predictions=pred_ids.tolist(), references=labels_ids.tolist())
+
+        return {"cer": cer}
+
+
+    trainer = Seq2SeqTrainer(
+        model = model,
+        args = training_args,
+        data_collator = data_collator,
+        train_dataset = train_dataset,
+        eval_dataset = eval_dataset,
+        compute_metrics=compute_metrics,
+    )

-trainer.train()
+    trainer.train()
idefics2/adapter_model.safetensors CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
+oid sha256:7e38855b7b26c79a86d6bc42985348143f602714b85923b6fcf6793830f400de
 size 746528304
idefics2/training_args.bin CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
+oid sha256:c886ec66a448f0680d0a46cd28b697b6899ecc0627e105de6d1eac26f3c78140
 size 5368
inference.py CHANGED
@@ -1,9 +1,15 @@
 from transformers import TrOCRProcessor, VisionEncoderDecoderModel
 import pandas as pd
 from PIL import Image
+from torchmetrics.text import CharErrorRate

 # Finetuned model
-
+model_finetune_1 = VisionEncoderDecoderModel.from_pretrained("hadrakey/alphapen_new_large_1")
+model_finetune_2 = VisionEncoderDecoderModel.from_pretrained("hadrakey/alphapen_new_large_15000")
+model_finetune_3 = VisionEncoderDecoderModel.from_pretrained("hadrakey/alphapen_new_large_30000")
+model_finetune_4 = VisionEncoderDecoderModel.from_pretrained("hadrakey/alphapen_new_large_45000")
+model_finetune_5 = VisionEncoderDecoderModel.from_pretrained("hadrakey/alphapen_new_large_60000")
+model_finetune_6 = VisionEncoderDecoderModel.from_pretrained("hadrakey/alphapen_new_large_70000")

 #Baseline
 model_base = VisionEncoderDecoderModel.from_pretrained("microsoft/trocr-base-handwritten")
@@ -15,24 +21,78 @@ df_path = "/mnt/data1/Datasets/AlphaPen/" + "testing_data.csv"
 data = pd.read_csv(df_path)
 data.dropna(inplace=True)
 data.reset_index(inplace=True)
+sample = data.iloc[:50,:]

 root_dir = "/mnt/data1/Datasets/OCR/Alphapen/clean_data/"

 inf_baseline = []
-
-
-
+inf_finetune_1 = []
+inf_finetune_2 = []
+inf_finetune_3 = []
+inf_finetune_4 = []
+inf_finetune_5 = []
+inf_finetune_6 = []
+
+cer_fine_1 = []
+cer_fine_2 = []
+cer_fine_3 = []
+cer_fine_4 = []
+cer_fine_5 = []
+cer_fine_6 = []
+cer_base = []
+
+cer_metric = CharErrorRate()
+
+for idx in range(len(sample)):
+    image = Image.open(root_dir + "final_cropped_rotated_" + sample.filename[idx]).convert("RGB")

     pixel_values = processor(image, return_tensors="pt").pixel_values
     generated_ids_base = model_base.generate(pixel_values)
-
+    generated_ids_fine_1 = model_finetune_1.generate(pixel_values)
+    generated_ids_fine_2= model_finetune_2.generate(pixel_values)
+    generated_ids_fine_3 = model_finetune_3.generate(pixel_values)
+    generated_ids_fine_4 = model_finetune_4.generate(pixel_values)
+    generated_ids_fine_5 = model_finetune_5.generate(pixel_values)
+    generated_ids_fine_6 = model_finetune_6.generate(pixel_values)
+
     generated_text_base = processor.batch_decode(generated_ids_base, skip_special_tokens=True)[0]
-
+    generated_text_fine_1= processor.batch_decode(generated_ids_fine_1, skip_special_tokens=True)[0]
+    generated_text_fine_2= processor.batch_decode(generated_ids_fine_2, skip_special_tokens=True)[0]
+    generated_text_fine_3= processor.batch_decode(generated_ids_fine_3, skip_special_tokens=True)[0]
+    generated_text_fine_4= processor.batch_decode(generated_ids_fine_4, skip_special_tokens=True)[0]
+    generated_text_fine_5= processor.batch_decode(generated_ids_fine_5, skip_special_tokens=True)[0]
+    generated_text_fine_6= processor.batch_decode(generated_ids_fine_6, skip_special_tokens=True)[0]
+
+    cer_fine_1.append(cer_metric(generated_text_fine_1.lower(), sample.text[idx].lower()).detach().numpy())
+    cer_fine_2.append(cer_metric(generated_text_fine_2.lower(), sample.text[idx].lower()).detach().numpy())
+    cer_fine_3.append(cer_metric(generated_text_fine_3.lower(), sample.text[idx].lower()).detach().numpy())
+    cer_fine_4.append(cer_metric(generated_text_fine_4.lower(), sample.text[idx].lower()).detach().numpy())
+    cer_fine_5.append(cer_metric(generated_text_fine_5.lower(), sample.text[idx].lower()).detach().numpy())
+    cer_fine_6.append(cer_metric(generated_text_fine_6.lower(), sample.text[idx].lower()).detach().numpy())
+    cer_base.append(cer_metric(generated_text_base.lower(), sample.text[idx].lower()).detach().numpy())

     inf_baseline.append(generated_text_base)
-
+    inf_finetune_1.append(generated_text_fine_1)
+    inf_finetune_2.append(generated_text_fine_2)
+    inf_finetune_3.append(generated_text_fine_3)
+    inf_finetune_4.append(generated_text_fine_4)
+    inf_finetune_5.append(generated_text_fine_5)
+    inf_finetune_6.append(generated_text_fine_6)
+
+sample["Baseline"]=inf_baseline
+sample["Finetune_1"]=inf_finetune_1
+sample["Finetune_2"]=inf_finetune_2
+sample["Finetune_3"]=inf_finetune_3
+sample["Finetune_4"]=inf_finetune_4
+sample["Finetune_5"]=inf_finetune_5
+sample["Finetune_6"]=inf_finetune_6

-
-
+sample["cer_1"]=cer_fine_1
+sample["cer_2"]=cer_fine_2
+sample["cer_3"]=cer_fine_3
+sample["cer_4"]=cer_fine_4
+sample["cer_5"]=cer_fine_5
+sample["cer_6"]=cer_fine_6
+sample["cer_base"]=cer_base

-
+sample.to_csv("/mnt/data1/Datasets/AlphaPen/" + "inference_results.csv")
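Note: torchmetrics' CharErrorRate is called directly on (prediction, reference) string pairs and returns a tensor, which is why each score above is detached and converted to NumPy before being stored in the dataframe. A minimal sketch with made-up strings:

# Minimal sketch of the torchmetrics CER call pattern used in inference.py.
# The strings are invented for illustration only.
from torchmetrics.text import CharErrorRate

cer_metric = CharErrorRate()
score = cer_metric("hello wrld", "hello world")  # returns a torch.Tensor
print(float(score))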
model.safetensors CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
+oid sha256:b14472ca382e9d96ea7efd3c778cbf0b73a412e31bc41cfec8d97e8988e6063d
 size 1335747032
trainer_lora.py CHANGED
@@ -18,8 +18,9 @@ from src.loss import MarginLoss, KLRegularization
 from src.similarity import CERSimilarity
 from datetime import datetime
 from torch.utils.data import ConcatDataset
+import wandb
+

-os.environ["WANDB_PROJECT"] = "Alphapen-TrOCR"

 # @dataclass
 # class ScriptArguments:
@@ -56,10 +57,10 @@ root_dir = "/mnt/data1/Datasets/OCR/Alphapen/clean_data/final_cropped_rotated_"
 root_dir_b2 = "/mnt/data1/Datasets/OCR/Alphapen/DataBatch2/clean_data/cropped_data/cropped_"
 processor = TrOCRProcessor.from_pretrained(model_name)

-train_dataset_b1 = AphaPenDataset(root_dir=root_dir, df=train_df, processor=processor)
-eval_dataset_b1 = AphaPenDataset(root_dir=root_dir, df=test_df, processor=processor)
+train_dataset_b1 = AphaPenDataset(root_dir=root_dir, df=train_df.iloc[:100,:], processor=processor)
+eval_dataset_b1 = AphaPenDataset(root_dir=root_dir, df=test_df.iloc[:100,:], processor=processor)

-eval_dataset_b2 = AphaPenDataset(root_dir=root_dir_b2, df=test_df_b2, processor=processor)
+eval_dataset_b2 = AphaPenDataset(root_dir=root_dir_b2, df=test_df_b2.iloc[:100,:], processor=processor)

 # train_dataset = ConcatDataset([train_dataset_b1, train_dataset_b2])
 eval_dataset = ConcatDataset([eval_dataset_b1, eval_dataset_b2])
@@ -119,6 +120,7 @@ model = get_peft_model(model, lora_config)

 # from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
 for samp in samp_list:
+    os.environ["WANDB_PROJECT"] = "Alphapen-TrOCR"
     train_dataset_b2 = AphaPenDataset(root_dir=root_dir_b2, df=train_df_b2.iloc[:samp,:], processor=processor)

     train_dataset = ConcatDataset([train_dataset_b1, train_dataset_b2])
@@ -179,4 +181,5 @@ for samp in samp_list:
     # callbacks=[SavePeftModelCallback]
     )

-    trainer.train()
+    trainer.train()
+    wandb.finish()
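Note: the loop in trainer_lora.py now runs one fine-tuning job per training-set size and closes each W&B run before the next iteration starts. A schematic sketch of that sweep pattern; train_one_run is a hypothetical placeholder, not a function from this repo:

# Sketch of the per-sample-size sweep pattern; not the repo's actual code.
import os
import wandb

samp_list = [1, 15000, 30000, 45000, 60000, 70000]

def train_one_run(samp):
    # placeholder for slicing the batch-2 dataframe to `samp` rows and calling trainer.train()
    print(f"training on the first {samp} batch-2 samples")

for samp in samp_list:
    os.environ["WANDB_PROJECT"] = "Alphapen-TrOCR"  # project picked up when the Trainer starts a run
    train_one_run(samp)
    wandb.finish()  # end the active run so the next iteration logs separately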
training_args.bin CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
+oid sha256:0dee27664eb5dbd1e3cc935d19708874101ea117d0baf6647d382db3446b7c24
 size 5368