---
library_name: transformers
license: mit
datasets:
- SpursgoZmy/MMTab
- apoidea/pubtabnet-html
language:
- en
base_model: google/pix2struct-base
pipeline_tag: image-to-text
---

# pix2struct-base-table2html

*Turn table images into HTML!*

## Demo app

Try the [demo app](), which includes both table detection and recognition!

## About

This model takes an image of a table and outputs HTML. It parses the image, performing both optical character recognition (OCR) and table structure recognition, and emits the result as HTML.

The model expects an image containing only a table. If the table is embedded in a document, first use a table detection model to extract it.

The model is fine-tuned from the [Pix2Struct base model](https://huggingface.co/google/pix2struct-base) with a maximum of 1024 image patches (the processor's `max_patches` setting) and a maximum generation length of 1024 tokens. The patch count should likely be kept at 1024 for inference, but the generation length can be changed.

The model has been trained on two datasets: [MMTab](https://huggingface.co/datasets/SpursgoZmy/MMTab) and [PubTabNet](https://huggingface.co/datasets/apoidea/pubtabnet-html).
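Because the model expects a cropped table rather than a full page, the output of a table detector has to be turned into crop boxes first. The helper below is a minimal sketch of that step. It assumes the detections follow the format of the transformers `object-detection` pipeline (a list of dicts with `"score"` and a `"box"` of `xmin`/`ymin`/`xmax`/`ymax` pixel coordinates). The function name, the 0.9 score threshold and the 10-pixel padding are hypothetical choices, not part of this model.

```python
def table_crop_boxes(detections, image_size, score_threshold=0.9, padding=10):
    """Turn object-detection results into padded PIL-style crop boxes.

    Hypothetical helper. `detections` is assumed to look like the output of a
    transformers object-detection pipeline: a list of dicts with a "score" and
    a "box" of pixel coordinates {"xmin", "ymin", "xmax", "ymax"}.
    `image_size` is (width, height), e.g. PIL's `image.size`.
    """
    width, height = image_size
    boxes = []
    for det in detections:
        # Skip low-confidence detections (the threshold is an assumed default)
        if det["score"] < score_threshold:
            continue
        box = det["box"]
        # Pad each side so table borders are not clipped, clamped to the image bounds
        boxes.append((
            max(0, box["xmin"] - padding),
            max(0, box["ymin"] - padding),
            min(width, box["xmax"] + padding),
            min(height, box["ymax"] + padding),
        ))
    return boxes
```

The returned tuples use PIL's `(left, upper, right, lower)` order, so with a PIL `page` image and a detector such as `microsoft/table-transformer-detection` loaded through the transformers `object-detection` pipeline (an example choice, not necessarily what the demo app uses), `tables = [page.crop(b) for b in table_crop_boxes(detector(page), page.size)]` yields one image per table. Each of these can then be passed to this model.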
## Usage

Below is a complete example of loading the model and running inference on a table image (the example is from the [MMTab dataset](https://huggingface.co/datasets/SpursgoZmy/MMTab)):

```python
import torch
from transformers import AutoProcessor, Pix2StructForConditionalGeneration
from PIL import Image
import requests
from io import BytesIO

# Load model and processor
device = "cuda" if torch.cuda.is_available() else "cpu"
processor = AutoProcessor.from_pretrained("pix2struct-base-table2html")
model = Pix2StructForConditionalGeneration.from_pretrained("pix2struct-base-table2html")
model.to(device)
model.eval()

# Load example image from URL
url = "https://example.com/path_to_table_image.jpg"
response = requests.get(url)
response.raise_for_status()  # fail early on HTTP errors
image = Image.open(BytesIO(response.content)).convert("RGB")

# Preprocess with the same patch budget (1024) used during fine-tuning
encoding = processor(images=image, return_tensors="pt", max_patches=1024)

# Run model inference
with torch.inference_mode():
    flattened_patches = encoding.pop("flattened_patches").to(device)
    attention_mask = encoding.pop("attention_mask").to(device)
    predictions = model.generate(
        flattened_patches=flattened_patches,
        attention_mask=attention_mask,
        max_new_tokens=1024,  # maximum generation length; can be adjusted
    )

# Decode the generated token IDs into HTML strings
predictions_decoded = processor.tokenizer.batch_decode(predictions, skip_special_tokens=True)

# Show predictions as text
print(predictions_decoded[0])
```
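The prediction is an HTML string, so it usually needs to be parsed before the table data can be used. The sketch below uses only Python's standard `html.parser` to collect cell text row by row. It assumes the model's HTML uses ordinary `<tr>`, `<td>` and `<th>` tags. It ignores `rowspan`/`colspan`, and it drops a final cell that was never closed (for example when generation stops at the token limit). The class and function names are hypothetical.

```python
from html.parser import HTMLParser


class TableCellParser(HTMLParser):
    """Collect the text of each <td>/<th> cell, grouped by <tr> row (hypothetical helper)."""

    def __init__(self):
        super().__init__()  # convert_charrefs=True by default, so "&amp;" arrives as "&"
        self.rows = []      # list of rows, each a list of cell strings
        self._cell = None   # text fragments of the cell currently being read, or None

    def handle_starttag(self, tag, attrs):
        if tag == "tr":
            self.rows.append([])
        elif tag in ("td", "th"):
            self._cell = []

    def handle_endtag(self, tag):
        if tag in ("td", "th") and self._cell is not None:
            if not self.rows:  # tolerate a cell that appears outside any <tr>
                self.rows.append([])
            self.rows[-1].append("".join(self._cell).strip())
            self._cell = None

    def handle_data(self, data):
        # Keep text only while inside a cell; nested tags such as <b> are skipped
        if self._cell is not None:
            self._cell.append(data)


def html_table_to_rows(html):
    """Return the table as a list of rows of cell strings."""
    parser = TableCellParser()
    parser.feed(html)
    parser.close()
    return parser.rows
```

For example, `rows = html_table_to_rows(predictions_decoded[0])` gives one list of strings per table row. For tables with merged cells, a library such as pandas (`pandas.read_html`) may be a better fit.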