kryman27 commited on
Commit
5c7e49b
·
verified ·
1 Parent(s): bf3bfc2

Create train_model.py

Browse files
Files changed (1) hide show
  1. train_model.py +29 -0
train_model.py ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import LayoutLMForTokenClassification, Trainer, TrainingArguments
2
+ from datasets import load_dataset
3
+
4
+ # Wczytanie przygotowanego zbioru danych
5
+ dataset = load_dataset("json", data_files="training_data.json")
6
+
7
+ # Ładowanie modelu LayoutLM do dostrajania
8
+ model = LayoutLMForTokenClassification.from_pretrained("microsoft/layoutlmv3-base", num_labels=5)
9
+
10
+ training_args = TrainingArguments(
11
+ output_dir="./layoutlmv3_finetuned",
12
+ per_device_train_batch_size=4,
13
+ per_device_eval_batch_size=4,
14
+ num_train_epochs=5,
15
+ evaluation_strategy="epoch",
16
+ save_strategy="epoch"
17
+ )
18
+
19
+ trainer = Trainer(
20
+ model=model,
21
+ args=training_args,
22
+ train_dataset=dataset["train"],
23
+ eval_dataset=dataset["test"]
24
+ )
25
+
26
+ trainer.train()
27
+
28
+ # Zapisanie modelu
29
+ model.save_pretrained("./layoutlmv3_finetuned")