# app.py (4.84 kB) — last commit a41f1ba "Update app.py" by danial0203.
import csv
import datetime
import os
import tempfile
import uuid

import boto3
import gradio as gr
import s3fs
from datasets.filesystems import S3FileSystem
from pdf2image import convert_from_path
from PIL import Image
from PIL import ImageSequence

# AWS and S3 initialization from environment variables.
# Each setting is read exactly once and shared by the s3fs filesystem and the
# Textract client, so both are always configured with the same credentials
# and region. An unset variable is passed through as None.
aws_access_key_id = os.getenv('AWS_ACCESS_KEY')
aws_secret_access_key = os.getenv('AWS_SECRET_KEY')
region_name = os.getenv('AWS_REGION')
s3_bucket = os.getenv('AWS_BUCKET')

# s3fs filesystem used by upload_file_to_s3 to copy page images into the bucket.
s3 = s3fs.S3FileSystem(
    key=aws_access_key_id,
    secret=aws_secret_access_key,
    client_kwargs={'region_name': region_name},
)

# Textract client used by process_image to analyze the uploaded S3 objects.
textract_client = boto3.client(
    'textract',
    aws_access_key_id=aws_access_key_id,
    aws_secret_access_key=aws_secret_access_key,
    region_name=region_name,
)
def upload_file_to_s3(file_path, bucket, object_name=None):
    """Copy a local file into ``bucket`` via the module-level ``s3`` filesystem.

    Args:
        file_path: Path of the local file to upload.
        bucket: Destination bucket name.
        object_name: Key to store the file under; defaults to the file's
            basename.

    Returns:
        The object key on success, or None when the local file is missing
        (a message is printed in that case).
    """
    key = os.path.basename(file_path) if object_name is None else object_name
    destination = f"{bucket}/{key}"
    try:
        s3.upload(file_path, destination)
    except FileNotFoundError:
        print("The file was not found")
        return None
    return key
def process_image(file_path, s3_bucket, textract_client):
    """Upload one page image to S3 and run Textract table analysis on it.

    Args:
        file_path: Local image to upload (its basename becomes the S3 key).
        s3_bucket: Bucket that receives the image and that Textract reads.
        textract_client: boto3 Textract client.

    Returns:
        The raw ``analyze_document`` response, or None if the upload failed.
    """
    key = upload_file_to_s3(file_path, s3_bucket)
    if not key:
        return None
    document = {'S3Object': {'Bucket': s3_bucket, 'Name': key}}
    return textract_client.analyze_document(
        Document=document,
        FeatureTypes=["TABLES"],
    )
def generate_table_csv(tables, blocks_map, csv_output_path):
    """Append every Textract table in ``tables`` to the CSV at ``csv_output_path``.

    Each table row becomes one CSV row. Rows are written in ascending
    ``RowIndex`` order. Every row of a table is padded with empty strings to
    the table's largest column index, so the CSV stays rectangular even when
    a row has no cell for some trailing column.

    Args:
        tables: TABLE blocks taken from a Textract response.
        blocks_map: Mapping of block ``Id`` -> block for the same response.
        csv_output_path: Destination CSV file. It is opened in append mode so
            successive pages accumulate in one file.
    """
    # Append (not overwrite) so tables from several pages end up in one file.
    # UTF-8 explicitly: OCR text may contain characters the platform default
    # encoding cannot represent.
    with open(csv_output_path, 'a', newline='', encoding='utf-8') as csvfile:
        writer = csv.writer(csvfile)
        for table in tables:
            rows = get_rows_columns_map(table, blocks_map)
            # Table width = largest 1-based column index seen in any row.
            width = max((max(cols) for cols in rows.values()), default=0)
            for row_index in sorted(rows):
                cols = rows[row_index]
                writer.writerow(
                    [cols.get(col_index, "") for col_index in range(1, width + 1)]
                )
def get_rows_columns_map(table_result, blocks_map):
    """Group a Textract TABLE block's cells into ``{row_index: {col_index: text}}``.

    Only CHILD relationships are followed. Child blocks lacking either
    ``RowIndex`` or ``ColumnIndex`` are skipped. Each cell's text comes from
    ``get_text``.

    Args:
        table_result: A TABLE block from a Textract response.
        blocks_map: Mapping of block ``Id`` -> block for the same response.

    Returns:
        Nested dict keyed by row index, then column index. It is empty when
        the table has no relationships.
    """
    rows = {}
    # A TABLE block without a 'Relationships' key yields an empty map instead
    # of raising KeyError.
    for relationship in table_result.get('Relationships', []):
        if relationship['Type'] != 'CHILD':
            continue
        for child_id in relationship['Ids']:
            cell = blocks_map[child_id]
            if 'RowIndex' in cell and 'ColumnIndex' in cell:
                row = rows.setdefault(cell['RowIndex'], {})
                row[cell['ColumnIndex']] = get_text(cell, blocks_map)
    return rows
def get_text(result, blocks_map):
    """Return the text of a Textract block assembled from its CHILD blocks.

    - WORD children contribute their ``Text``.
    - SELECTION_ELEMENT children contribute ``"X"`` when their
      ``SelectionStatus`` is ``"SELECTED"``.
    - Parts are joined with single spaces and the result is stripped.
    - A block without ``Relationships`` yields ``""``.

    Args:
        result: The block whose text is wanted (e.g. a table CELL).
        blocks_map: Mapping of block ``Id`` -> block for the same response.
    """
    parts = []
    for relationship in result.get('Relationships', []):
        if relationship['Type'] != 'CHILD':
            continue
        for child_id in relationship['Ids']:
            child = blocks_map[child_id]
            block_type = child['BlockType']
            if block_type == 'WORD':
                parts.append(child['Text'])
            elif block_type == 'SELECTION_ELEMENT' and child['SelectionStatus'] == 'SELECTED':
                parts.append('X')
    return ' '.join(parts).strip()
def _load_page_images(file_path):
    """Return the pages of ``file_path`` as a list of PIL images.

    PNG/JPEG/TIFF files are read with PIL. Every frame is read, so a
    multi-page TIFF yields one image per page, each converted to RGB.
    Anything else is handed to pdf2image's ``convert_from_path``.
    """
    if file_path.lower().endswith(('.png', '.jpg', '.jpeg', '.tif', '.tiff')):
        # The with-block closes the source file. convert() returns loaded,
        # independent copies, so the frames stay usable afterwards.
        with Image.open(file_path) as img:
            return [frame.convert('RGB') for frame in ImageSequence.Iterator(img)]
    # Convert PDF (or other formats pdf2image supports) to one image per page.
    return convert_from_path(file_path)


def process_file_and_generate_csv(file_path):
    """Extract all tables from an uploaded document into a single CSV file.

    Each page is saved as a JPEG, uploaded to S3 and analyzed by Textract.
    Every detected table is appended to one CSV file.

    Args:
        file_path: The value Gradio passes for the ``gr.File`` input. This is
            a path string or an object exposing the path as ``.name``, or
            None when nothing was uploaded.

    Returns:
        ``(csv_path, status_message)``, or ``(None, message)`` when no file
        was uploaded.
    """
    if file_path is None:
        return None, "No file uploaded."
    # Accept both a plain path and a tempfile-like object carrying `.name`.
    file_path = getattr(file_path, 'name', file_path)

    timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
    # A random per-request id keeps concurrent requests (even within the same
    # second) from sharing an output CSV or page-image names.
    run_id = uuid.uuid4().hex
    tmp_dir = tempfile.gettempdir()
    csv_output_path = os.path.join(tmp_dir, f"output_{timestamp}_{run_id[:8]}.csv")
    # Create the CSV up front so the returned path exists even if no page
    # yields a response.
    with open(csv_output_path, 'w', newline='', encoding='utf-8'):
        pass

    tables_found = 0
    for page_number, image in enumerate(_load_page_images(file_path)):
        # process_image uploads without an explicit object name, so this
        # file's basename becomes the S3 key: it must be unique per request,
        # or concurrent users would overwrite each other's pages.
        image_path = os.path.join(tmp_dir, f"image_{run_id}_{page_number}.jpg")
        if image.mode != 'RGB':
            # JPEG cannot store alpha or palette modes (e.g. RGBA PNGs).
            image = image.convert('RGB')
        image.save(image_path, 'JPEG')
        try:
            response = process_image(image_path, s3_bucket, textract_client)
        finally:
            # The local page copy is no longer needed once the upload has been attempted.
            os.remove(image_path)
        if response:
            blocks = response['Blocks']
            blocks_map = {block['Id']: block for block in blocks}
            tables = [block for block in blocks if block['BlockType'] == "TABLE"]
            tables_found += len(tables)
            generate_table_csv(tables, blocks_map, csv_output_path)

    # The uploaded file itself is left for Gradio's temporary-file cleanup.
    if not tables_found:
        return csv_output_path, "Processing completed, but no tables were found."
    return csv_output_path, "Processing completed successfully!"
# Gradio UI: one uploaded document in; a CSV file plus a status message out.
upload_input = gr.File(label="Upload your file (PDF, PNG, JPG, TIFF)")
csv_output = gr.File(label="Download Generated CSV")

iface = gr.Interface(
    fn=process_file_and_generate_csv,
    inputs=upload_input,
    outputs=[csv_output, "text"],
    description="Upload a document to extract tables into a CSV file.",
)

if __name__ == "__main__":
    iface.launch()