File size: 9,705 Bytes
4c1f7f8 fdaaec7 4c1f7f8 fdaaec7 4c1f7f8 fdaaec7 4c1f7f8 fdaaec7 4c1f7f8 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 |
---
license: mit
tags:
- generated_from_trainer
datasets:
- funsd-layoutlmv3
model-index:
- name: lilt-en-funsd
results: []
---
<!-- This model card has been generated automatically according to the information the Trainer had access to. You
should probably proofread and complete it, then remove this comment. -->
# lilt-en-funsd
This model is a fine-tuned version of [SCUT-DLVCLab/lilt-roberta-en-base](https://huggingface.co/SCUT-DLVCLab/lilt-roberta-en-base) on the funsd-layoutlmv3 dataset.
It achieves the following results on the evaluation set:
- Loss: 1.6117
- Answer: {'precision': 0.8821428571428571, 'recall': 0.9069767441860465, 'f1': 0.8943874471937237, 'number': 817}
- Header: {'precision': 0.6126126126126126, 'recall': 0.5714285714285714, 'f1': 0.591304347826087, 'number': 119}
- Question: {'precision': 0.9045045045045045, 'recall': 0.9322191272051996, 'f1': 0.9181527206218564, 'number': 1077}
- Overall Precision: 0.8797
- Overall Recall: 0.9006
- Overall F1: 0.8900
- Overall Accuracy: 0.8204
## Model Usage
```python
from transformers import LiltForTokenClassification, LayoutLMv3Processor
from PIL import Image, ImageDraw, ImageFont
import torch
# load model and processor from huggingface hub
model = LiltForTokenClassification.from_pretrained("philschmid/lilt-en-funsd")
processor = LayoutLMv3Processor.from_pretrained("philschmid/lilt-en-funsd")
# helper function to unnormalize bboxes for drawing onto the image
def unnormalize_box(bbox, width, height):
return [
width * (bbox[0] / 1000),
height * (bbox[1] / 1000),
width * (bbox[2] / 1000),
height * (bbox[3] / 1000),
]
label2color = {
"B-HEADER": "blue",
"B-QUESTION": "red",
"B-ANSWER": "green",
"I-HEADER": "blue",
"I-QUESTION": "red",
"I-ANSWER": "green",
}
# draw results onto the image
def draw_boxes(image, boxes, predictions):
width, height = image.size
normalizes_boxes = [unnormalize_box(box, width, height) for box in boxes]
# draw predictions over the image
draw = ImageDraw.Draw(image)
font = ImageFont.load_default()
for prediction, box in zip(predictions, normalizes_boxes):
if prediction == "O":
continue
draw.rectangle(box, outline="black")
draw.rectangle(box, outline=label2color[prediction])
draw.text((box[0] + 10, box[1] - 10), text=prediction, fill=label2color[prediction], font=font)
return image
# run inference
def run_inference(image, model=model, processor=processor, output_image=True):
# create model input
encoding = processor(image, return_tensors="pt")
del encoding["pixel_values"]
# run inference
outputs = model(**encoding)
predictions = outputs.logits.argmax(-1).squeeze().tolist()
# get labels
labels = [model.config.id2label[prediction] for prediction in predictions]
if output_image:
return draw_boxes(image, encoding["bbox"][0], labels)
else:
return labels
run_inference(dataset["test"][34]["image"])
```
## Training procedure
### Training hyperparameters
The following hyperparameters were used during training:
- learning_rate: 5e-05
- train_batch_size: 8
- eval_batch_size: 8
- seed: 42
- optimizer: Adam with betas=(0.9,0.999) and epsilon=1e-08
- lr_scheduler_type: linear
- training_steps: 2500
- mixed_precision_training: Native AMP
### Training results
| Training Loss | Epoch | Step | Validation Loss | Answer | Header | Question | Overall Precision | Overall Recall | Overall F1 | Overall Accuracy |
|:-------------:|:------:|:----:|:---------------:|:--------------------------------------------------------------------------------------------------------:|:--------------------------------------------------------------------------------------------------------:|:---------------------------------------------------------------------------------------------------------:|:-----------------:|:--------------:|:----------:|:----------------:|
| 0.0211 | 10.53 | 200 | 1.5528 | {'precision': 0.8458904109589042, 'recall': 0.9069767441860465, 'f1': 0.8753691671588896, 'number': 817} | {'precision': 0.5684210526315789, 'recall': 0.453781512605042, 'f1': 0.5046728971962617, 'number': 119} | {'precision': 0.896551724137931, 'recall': 0.89322191272052, 'f1': 0.8948837209302325, 'number': 1077} | 0.8596 | 0.8728 | 0.8662 | 0.8011 |
| 0.0132 | 21.05 | 400 | 1.3143 | {'precision': 0.8447058823529412, 'recall': 0.8788249694002448, 'f1': 0.8614277144571085, 'number': 817} | {'precision': 0.6020408163265306, 'recall': 0.4957983193277311, 'f1': 0.543778801843318, 'number': 119} | {'precision': 0.8854262144821264, 'recall': 0.8969359331476323, 'f1': 0.8911439114391144, 'number': 1077} | 0.8548 | 0.8659 | 0.8603 | 0.8095 |
| 0.0052 | 31.58 | 600 | 1.5747 | {'precision': 0.8482446206115515, 'recall': 0.9167686658506732, 'f1': 0.8811764705882352, 'number': 817} | {'precision': 0.6283185840707964, 'recall': 0.5966386554621849, 'f1': 0.6120689655172413, 'number': 119} | {'precision': 0.8997161778618732, 'recall': 0.883008356545961, 'f1': 0.8912839737582005, 'number': 1077} | 0.8626 | 0.8798 | 0.8711 | 0.8030 |
| 0.0073 | 42.11 | 800 | 1.4848 | {'precision': 0.8487972508591065, 'recall': 0.9069767441860465, 'f1': 0.8769230769230769, 'number': 817} | {'precision': 0.5190839694656488, 'recall': 0.5714285714285714, 'f1': 0.5439999999999999, 'number': 119} | {'precision': 0.8941947565543071, 'recall': 0.8867223769730733, 'f1': 0.8904428904428905, 'number': 1077} | 0.8514 | 0.8763 | 0.8636 | 0.7969 |
| 0.0057 | 52.63 | 1000 | 1.3993 | {'precision': 0.8852071005917159, 'recall': 0.9155446756425949, 'f1': 0.9001203369434416, 'number': 817} | {'precision': 0.5454545454545454, 'recall': 0.6050420168067226, 'f1': 0.5737051792828685, 'number': 119} | {'precision': 0.899090909090909, 'recall': 0.9182915506035283, 'f1': 0.9085898024804776, 'number': 1077} | 0.8710 | 0.8987 | 0.8846 | 0.8198 |
| 0.0023 | 63.16 | 1200 | 1.6463 | {'precision': 0.8961201501877347, 'recall': 0.8763769889840881, 'f1': 0.886138613861386, 'number': 817} | {'precision': 0.5625, 'recall': 0.5294117647058824, 'f1': 0.5454545454545455, 'number': 119} | {'precision': 0.888, 'recall': 0.9275766016713092, 'f1': 0.9073569482288827, 'number': 1077} | 0.8733 | 0.8833 | 0.8782 | 0.8082 |
| 0.001 | 73.68 | 1400 | 1.6476 | {'precision': 0.8676814988290398, 'recall': 0.9069767441860465, 'f1': 0.8868940754039496, 'number': 817} | {'precision': 0.6571428571428571, 'recall': 0.5798319327731093, 'f1': 0.6160714285714286, 'number': 119} | {'precision': 0.908256880733945, 'recall': 0.9192200557103064, 'f1': 0.9137055837563451, 'number': 1077} | 0.8785 | 0.8942 | 0.8863 | 0.8137 |
| 0.0014 | 84.21 | 1600 | 1.6493 | {'precision': 0.8814814814814815, 'recall': 0.8739290085679314, 'f1': 0.8776889981561156, 'number': 817} | {'precision': 0.6194690265486725, 'recall': 0.5882352941176471, 'f1': 0.603448275862069, 'number': 119} | {'precision': 0.894404332129964, 'recall': 0.9201485608170845, 'f1': 0.9070938215102976, 'number': 1077} | 0.8740 | 0.8818 | 0.8778 | 0.8041 |
| 0.0006 | 94.74 | 1800 | 1.6193 | {'precision': 0.8766467065868263, 'recall': 0.8959608323133414, 'f1': 0.8861985472154963, 'number': 817} | {'precision': 0.6068376068376068, 'recall': 0.5966386554621849, 'f1': 0.6016949152542374, 'number': 119} | {'precision': 0.8946428571428572, 'recall': 0.9303621169916435, 'f1': 0.912152935821575, 'number': 1077} | 0.8711 | 0.8967 | 0.8837 | 0.8137 |
| 0.0001 | 105.26 | 2000 | 1.6048 | {'precision': 0.8751472320376914, 'recall': 0.9094247246022031, 'f1': 0.8919567827130852, 'number': 817} | {'precision': 0.6140350877192983, 'recall': 0.5882352941176471, 'f1': 0.6008583690987125, 'number': 119} | {'precision': 0.9062784349408554, 'recall': 0.924791086350975, 'f1': 0.9154411764705882, 'number': 1077} | 0.8773 | 0.8987 | 0.8879 | 0.8194 |
| 0.0001 | 115.79 | 2200 | 1.6117 | {'precision': 0.8821428571428571, 'recall': 0.9069767441860465, 'f1': 0.8943874471937237, 'number': 817} | {'precision': 0.6126126126126126, 'recall': 0.5714285714285714, 'f1': 0.591304347826087, 'number': 119} | {'precision': 0.9045045045045045, 'recall': 0.9322191272051996, 'f1': 0.9181527206218564, 'number': 1077} | 0.8797 | 0.9006 | 0.8900 | 0.8204 |
| 0.0001 | 126.32 | 2400 | 1.6163 | {'precision': 0.8799048751486326, 'recall': 0.9057527539779682, 'f1': 0.8926417370325694, 'number': 817} | {'precision': 0.6052631578947368, 'recall': 0.5798319327731093, 'f1': 0.5922746781115881, 'number': 119} | {'precision': 0.9062784349408554, 'recall': 0.924791086350975, 'f1': 0.9154411764705882, 'number': 1077} | 0.8788 | 0.8967 | 0.8876 | 0.8192 |
### Framework versions
- Transformers 4.24.0
- Pytorch 1.12.1+cu113
- Datasets 2.7.0
- Tokenizers 0.12.1
|