lilt-en-funsd
This model is a fine-tuned version of SCUT-DLVCLab/lilt-roberta-en-base on the funsd-layoutlmv3 dataset. It achieves the following results on the evaluation set:
- Loss: 1.6117
- Answer: precision 0.8821, recall 0.9070, F1 0.8944 (817 entities)
- Header: precision 0.6126, recall 0.5714, F1 0.5913 (119 entities)
- Question: precision 0.9045, recall 0.9322, F1 0.9182 (1077 entities)
- Overall Precision: 0.8797
- Overall Recall: 0.9006
- Overall F1: 0.8900
- Overall Accuracy: 0.8204
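The overall precision, recall and F1 are consistent with micro-averaging over the three entity types, which is the convention seqeval uses for its overall scores. The sketch below assumes per-type true-positive and predicted-entity counts (741/840, 68/111 and 1004/1110 for Answer, Header and Question) that were back-computed from the reported precision, recall and entity counts. They are inferred, not taken from the training logs. Overall Accuracy is a token-level score and cannot be recovered from these counts.

```python
# Per-entity counts: (true positives, predicted entities, gold entities).
# Assumption: TP and predicted counts are back-computed from the reported
# precision/recall; gold counts are the reported entity counts.
counts = {
    "ANSWER": (741, 840, 817),
    "HEADER": (68, 111, 119),
    "QUESTION": (1004, 1110, 1077),
}


def prf(tp, pred, gold):
    """Precision, recall and F1 from raw entity counts."""
    precision = tp / pred
    recall = tp / gold
    f1 = 2 * precision * recall / (precision + recall)
    return precision, recall, f1


# Micro-average: pool the counts across entity types before dividing.
tp = sum(c[0] for c in counts.values())
pred = sum(c[1] for c in counts.values())
gold = sum(c[2] for c in counts.values())
overall_precision, overall_recall, overall_f1 = prf(tp, pred, gold)

for name, c in counts.items():
    p, r, f = prf(*c)
    print(f"{name:<9} P={p:.4f} R={r:.4f} F1={f:.4f}")
print(f"overall   P={overall_precision:.4f} R={overall_recall:.4f} F1={overall_f1:.4f}")
```

Pooling counts means the large Answer and Question classes dominate the overall score, which is why the weak Header F1 barely moves Overall F1.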
Model Usage
```python
from datasets import load_dataset
from PIL import Image, ImageDraw, ImageFont
import torch
from transformers import LiltForTokenClassification, LayoutLMv3Processor

# load model and processor from huggingface hub
model = LiltForTokenClassification.from_pretrained("philschmid/lilt-en-funsd")
processor = LayoutLMv3Processor.from_pretrained("philschmid/lilt-en-funsd")
model.eval()

# helper function to unnormalize bboxes (0-1000 scale) to pixel coordinates for drawing
def unnormalize_box(bbox, width, height):
    return [
        width * (bbox[0] / 1000),
        height * (bbox[1] / 1000),
        width * (bbox[2] / 1000),
        height * (bbox[3] / 1000),
    ]

label2color = {
    "B-HEADER": "blue",
    "B-QUESTION": "red",
    "B-ANSWER": "green",
    "I-HEADER": "blue",
    "I-QUESTION": "red",
    "I-ANSWER": "green",
}

# draw results onto the image
def draw_boxes(image, boxes, predictions):
    width, height = image.size
    unnormalized_boxes = [unnormalize_box(box, width, height) for box in boxes]

    # draw predictions over the image, skipping tokens labeled "O"
    draw = ImageDraw.Draw(image)
    font = ImageFont.load_default()
    for prediction, box in zip(predictions, unnormalized_boxes):
        if prediction == "O":
            continue
        draw.rectangle(box, outline=label2color[prediction])
        draw.text((box[0] + 10, box[1] - 10), text=prediction, fill=label2color[prediction], font=font)
    return image

# run inference
def run_inference(image, model=model, processor=processor, output_image=True):
    # create model input; with only an image given, the processor runs OCR
    # (requires Tesseract/pytesseract) to get words and normalized boxes
    encoding = processor(image, return_tensors="pt")
    # LiLT uses text and layout only, so the pixel values are not a model input
    del encoding["pixel_values"]
    # run inference without tracking gradients
    with torch.no_grad():
        outputs = model(**encoding)
    predictions = outputs.logits.argmax(-1).squeeze().tolist()
    # map class ids to label strings
    labels = [model.config.id2label[prediction] for prediction in predictions]
    if output_image:
        return draw_boxes(image, encoding["bbox"][0].tolist(), labels)
    else:
        return labels

# load FUNSD in LayoutLMv3 format (assumed Hub id for the funsd-layoutlmv3 dataset)
dataset = load_dataset("nielsr/funsd-layoutlmv3")
run_inference(dataset["test"][34]["image"])
```
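With `output_image=False`, `run_inference` returns one BIO label per token (`B-QUESTION`, `I-QUESTION`, `O`, and so on). The following is a minimal sketch of grouping such a sequence into labeled spans. The helper `group_bio` is hypothetical and not part of transformers; it works on token indices only and leaves out mapping back to words or boxes.

```python
from typing import List, Tuple


def group_bio(labels: List[str]) -> List[Tuple[str, int, int]]:
    """Group BIO tags into (entity_type, start, end) spans, end exclusive.

    Hypothetical helper: an I- tag that does not continue an open span of
    the same type starts a new span (a lenient convention).
    """
    spans = []
    current_type, start = None, 0
    for i, label in enumerate(labels):
        if label == "O":
            # close any open span
            if current_type is not None:
                spans.append((current_type, start, i))
            current_type = None
            continue
        prefix, entity_type = label.split("-", 1)
        if prefix == "B" or entity_type != current_type:
            # a new entity begins here; close the previous one first
            if current_type is not None:
                spans.append((current_type, start, i))
            current_type, start = entity_type, i
    if current_type is not None:
        spans.append((current_type, start, len(labels)))
    return spans


labels = ["O", "B-QUESTION", "I-QUESTION", "B-ANSWER", "O", "I-HEADER"]
print(group_bio(labels))
```

On real model output the label list also covers special tokens and sub-word pieces, so spans are in token positions rather than OCR words.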
Training procedure
Training hyperparameters
The following hyperparameters were used during training:
- learning_rate: 5e-05
- train_batch_size: 8
- eval_batch_size: 8
- seed: 42
- optimizer: Adam with betas=(0.9,0.999) and epsilon=1e-08
- lr_scheduler_type: linear
- training_steps: 2500
- mixed_precision_training: Native AMP
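With `lr_scheduler_type: linear`, the learning rate decays linearly from 5e-05 toward zero over the 2500 training steps. The sketch below assumes no warmup, since none is listed above. `linear_lr` is a hypothetical helper that mirrors the warmup-then-linear-decay rule; it is not the Trainer's own code.

```python
def linear_lr(step: int, base_lr: float = 5e-05, total_steps: int = 2500, warmup_steps: int = 0) -> float:
    """Learning rate at `step` for linear warmup followed by linear decay to zero.

    Assumption: warmup_steps=0, since no warmup is listed in the hyperparameters.
    """
    if step < warmup_steps:
        # ramp up linearly during warmup
        return base_lr * step / max(1, warmup_steps)
    remaining = total_steps - step
    # decay linearly, clamped at zero after the last step
    return base_lr * max(0.0, remaining / max(1, total_steps - warmup_steps))


for step in (0, 1000, 2200, 2500):
    print(step, linear_lr(step))
```

Under this assumption, the evaluated checkpoint at step 2200 was trained with a learning rate already down to 12% of its starting value.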
Training results
Per-entity columns give precision / recall / F1, rounded to four decimals; the evaluation set contains 817 Answer, 119 Header and 1077 Question entities.

Training Loss | Epoch | Step | Validation Loss | Answer (P / R / F1) | Header (P / R / F1) | Question (P / R / F1) | Overall Precision | Overall Recall | Overall F1 | Overall Accuracy |
---|---|---|---|---|---|---|---|---|---|---|
0.0211 | 10.53 | 200 | 1.5528 | 0.8459 / 0.9070 / 0.8754 | 0.5684 / 0.4538 / 0.5047 | 0.8966 / 0.8932 / 0.8949 | 0.8596 | 0.8728 | 0.8662 | 0.8011 |
0.0132 | 21.05 | 400 | 1.3143 | 0.8447 / 0.8788 / 0.8614 | 0.6020 / 0.4958 / 0.5438 | 0.8854 / 0.8969 / 0.8911 | 0.8548 | 0.8659 | 0.8603 | 0.8095 |
0.0052 | 31.58 | 600 | 1.5747 | 0.8482 / 0.9168 / 0.8812 | 0.6283 / 0.5966 / 0.6121 | 0.8997 / 0.8830 / 0.8913 | 0.8626 | 0.8798 | 0.8711 | 0.8030 |
0.0073 | 42.11 | 800 | 1.4848 | 0.8488 / 0.9070 / 0.8769 | 0.5191 / 0.5714 / 0.5440 | 0.8942 / 0.8867 / 0.8904 | 0.8514 | 0.8763 | 0.8636 | 0.7969 |
0.0057 | 52.63 | 1000 | 1.3993 | 0.8852 / 0.9155 / 0.9001 | 0.5455 / 0.6050 / 0.5737 | 0.8991 / 0.9183 / 0.9086 | 0.8710 | 0.8987 | 0.8846 | 0.8198 |
0.0023 | 63.16 | 1200 | 1.6463 | 0.8961 / 0.8764 / 0.8861 | 0.5625 / 0.5294 / 0.5455 | 0.8880 / 0.9276 / 0.9074 | 0.8733 | 0.8833 | 0.8782 | 0.8082 |
0.001 | 73.68 | 1400 | 1.6476 | 0.8677 / 0.9070 / 0.8869 | 0.6571 / 0.5798 / 0.6161 | 0.9083 / 0.9192 / 0.9137 | 0.8785 | 0.8942 | 0.8863 | 0.8137 |
0.0014 | 84.21 | 1600 | 1.6493 | 0.8815 / 0.8739 / 0.8777 | 0.6195 / 0.5882 / 0.6034 | 0.8944 / 0.9201 / 0.9071 | 0.8740 | 0.8818 | 0.8778 | 0.8041 |
0.0006 | 94.74 | 1800 | 1.6193 | 0.8766 / 0.8960 / 0.8862 | 0.6068 / 0.5966 / 0.6017 | 0.8946 / 0.9304 / 0.9122 | 0.8711 | 0.8967 | 0.8837 | 0.8137 |
0.0001 | 105.26 | 2000 | 1.6048 | 0.8751 / 0.9094 / 0.8920 | 0.6140 / 0.5882 / 0.6009 | 0.9063 / 0.9248 / 0.9154 | 0.8773 | 0.8987 | 0.8879 | 0.8194 |
0.0001 | 115.79 | 2200 | 1.6117 | 0.8821 / 0.9070 / 0.8944 | 0.6126 / 0.5714 / 0.5913 | 0.9045 / 0.9322 / 0.9182 | 0.8797 | 0.9006 | 0.8900 | 0.8204 |
0.0001 | 126.32 | 2400 | 1.6163 | 0.8799 / 0.9058 / 0.8926 | 0.6053 / 0.5798 / 0.5923 | 0.9063 / 0.9248 / 0.9154 | 0.8788 | 0.8967 | 0.8876 | 0.8192 |
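The Epoch column follows from the step count and the number of batches per epoch. Assuming the FUNSD training split of 149 documents (one example per document, no overflow windows) and `train_batch_size` 8, each epoch is ceil(149 / 8) = 19 steps, so step 200 corresponds to 200 / 19 ≈ 10.53 epochs. The sketch reproduces the column under those assumptions; `TRAIN_EXAMPLES` and `epoch_at` are illustrative names, not part of the training code.

```python
import math

TRAIN_EXAMPLES = 149  # assumption: size of FUNSD's training split
BATCH_SIZE = 8        # train_batch_size from the hyperparameters

# the last, partial batch still counts as an optimizer step
steps_per_epoch = math.ceil(TRAIN_EXAMPLES / BATCH_SIZE)


def epoch_at(step: int) -> float:
    """Fractional epoch reached after `step` optimizer steps."""
    return step / steps_per_epoch


for step in range(200, 2600, 200):
    print(step, round(epoch_at(step), 2))
```

By the same arithmetic, the full 2500 training steps amount to roughly 131.6 passes over the training set.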
Framework versions
- Transformers 4.24.0
- Pytorch 1.12.1+cu113
- Datasets 2.7.0
- Tokenizers 0.12.1