# app.py — Table OCR demo (Hugging Face Space)
# Author: Dileep7729 — revision b443de4 (verified), 3.04 kB
import gradio as gr
import torch
from transformers import LayoutLMv3Processor, LayoutLMv3ForTokenClassification
import pytesseract
import os
# Explicitly set the Tesseract path for Hugging Face Spaces.
pytesseract.pytesseract.tesseract_cmd = "/usr/bin/tesseract"

# Debugging: report the Tesseract version and PATH so a misconfigured
# container shows up in the Space logs instead of failing later at inference.
try:
    tesseract_version = pytesseract.get_tesseract_version()
    print("Tesseract Version:", tesseract_version)
    print("Tesseract Path:", pytesseract.pytesseract.tesseract_cmd)
    # .get() avoids a KeyError (when PATH is unset) masking the real
    # Tesseract diagnostics this try-block is meant to surface.
    print("Environment PATH:", os.environ.get("PATH", "<unset>"))
except Exception as e:
    print("Tesseract Debugging Error:", e)

# For local development on Windows, point pytesseract at the installed binary:
# pytesseract.pytesseract.tesseract_cmd = r"C:\Program Files\Tesseract-OCR\tesseract.exe"

# Load the fine-tuned LayoutLMv3 token-classification model and its processor
# from the Hugging Face Hub (network I/O on first run; cached afterwards).
processor = LayoutLMv3Processor.from_pretrained("quadranttechnologies/Table_OCR")
model = LayoutLMv3ForTokenClassification.from_pretrained("quadranttechnologies/Table_OCR")
model.eval()  # inference only — disable dropout, etc.

# Run on GPU when available, otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
def process_image(image):
    """Run OCR + LayoutLMv3 token classification on a document image.

    Args:
        image: PIL image of a document/receipt. The processor's built-in
            Tesseract OCR extracts words and bounding boxes from it.

    Returns:
        On success, a list of dicts with keys "word", "bounding_box", and
        "label"; on failure, a dict {"error": <message>}.
    """
    try:
        # Tokenize + OCR; pad/truncate to a fixed 512-token sequence.
        encoding = processor(
            image,
            return_tensors="pt",
            truncation=True,
            padding="max_length",
            max_length=512,
        )

        # Move every input tensor to the same device as the model.
        encoding = {key: val.to(device) for key, val in encoding.items()}

        # Inference only — no gradient tracking needed.
        with torch.no_grad():
            outputs = model(**encoding)
        predictions = torch.argmax(outputs.logits, dim=-1)

        input_ids = encoding["input_ids"].squeeze().tolist()
        bboxes = encoding["bbox"].squeeze().tolist()
        labels = predictions.squeeze().tolist()

        # Skip [CLS]/[SEP]/<pad> and other special tokens: they carry no
        # document text, but decode to non-empty marker strings and would
        # otherwise pollute the output with dummy bounding boxes.
        special_ids = set(processor.tokenizer.all_special_ids)

        # Format output as JSON-serializable records.
        structured_output = []
        for token_id, bbox, label in zip(input_ids, bboxes, labels):
            if token_id in special_ids:
                continue
            word = processor.tokenizer.decode([token_id]).strip()
            if word:  # avoid adding empty words
                structured_output.append({
                    "word": word,
                    "bounding_box": bbox,
                    "label": model.config.id2label[label],  # label ID -> name
                })
        return structured_output
    except Exception as e:
        # Surface the failure to the UI as JSON instead of crashing the app.
        print("Error during processing:", str(e))
        return {"error": str(e)}
# Gradio front end: upload an image, get structured JSON back.
interface = gr.Interface(
    fn=process_image,
    inputs=gr.Image(type="pil"),  # PIL image upload widget
    outputs="json",               # render the returned list/dict as JSON
    title="Table OCR",
    description="Upload an image (e.g., receipt or document) to extract structured information in JSON format.",
)

# Launch only when executed as a script (not when imported).
if __name__ == "__main__":
    print("Starting Table OCR App...")  # startup marker for the logs
    interface.launch(share=True)