# Package installation
#!pip install git+https://github.com/huggingface/transformers.git
#!pip install torch accelerate bitsandbytes sentencepiece pillow
#!pip install spaces
import gradio as gr
import os
import torch
from threading import Thread
from transformers import AutoProcessor, MllamaForConditionalGeneration, TextIteratorStreamer
from PIL import Image
import csv
import spaces
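# The `spaces` package supplies the @spaces.GPU decorator used below; per its
# docs the decorator is a no-op when the app is not running on ZeroGPU hardware.
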
# Check if we're running in a Hugging Face Space and if SPACES_ZERO_GPU is enabled
IS_SPACES_ZERO = os.environ.get("SPACES_ZERO_GPU", "0") == "1"
IS_SPACE = os.environ.get("SPACE_ID", None) is not None
IS_GDRIVE = False  # Set to True to load the model from a mounted Google Drive (e.g. in Colab)

# Determine the device (GPU if available, else CPU)
device = "cuda" if torch.cuda.is_available() else "cpu"
LOW_MEMORY = os.getenv("LOW_MEMORY", "0") == "1"
print(f"Using device: {device}")
print(f"Low memory mode: {LOW_MEMORY}")

# Get Hugging Face token from environment variables
HF_TOKEN = os.environ.get('HF_TOKEN')

# Define the model name
model_name = "Llama-3.2-11B-Vision-Instruct"
if IS_GDRIVE:
    # Define the path to the model directory in your Google Drive
    model_path = "/content/drive/MyDrive/models/" + model_name
    model = MllamaForConditionalGeneration.from_pretrained(
        model_path,
        torch_dtype=torch.bfloat16,
        device_map="auto",
    )
    processor = AutoProcessor.from_pretrained(model_path)
else:
    # Load the model and processor from the Hugging Face Hub
    # (HF_TOKEN was already read from the environment above)
    model_name = "ruslanmv/Llama-3.2-11B-Vision-Instruct"
    model = MllamaForConditionalGeneration.from_pretrained(
        model_name,
        token=HF_TOKEN,
        torch_dtype=torch.bfloat16 if device == "cuda" else torch.float32,
        device_map="auto" if device == "cuda" else None,  # Use device mapping if CUDA is available
    )

    # With device_map="auto", accelerate has already placed the weights;
    # only move the model manually in the CPU case
    if device != "cuda":
        model.to(device)
    processor = AutoProcessor.from_pretrained(model_name, token=HF_TOKEN)


# Tie the input/output embedding weights if the model supports it
if hasattr(model, "tie_weights"):
    model.tie_weights()
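
# LOW_MEMORY is currently informational only. One way to act on it would be
# 4-bit quantization via bitsandbytes (a sketch, assuming the optional
# bitsandbytes dependency from the install notes above; not tested here):
#
#   from transformers import BitsAndBytesConfig
#   bnb_config = BitsAndBytesConfig(
#       load_in_4bit=True,
#       bnb_4bit_compute_dtype=torch.bfloat16,
#   )
#   model = MllamaForConditionalGeneration.from_pretrained(
#       model_name, quantization_config=bnb_config, device_map="auto")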

example = '''Table 1:
header1,header2,header3
value1,value2,value3

Table 2:
header1,header2,header3
value1,value2,value3
'''

prompt_message = """Please extract all tables from the image and generate CSV files.
Each table should be separated using the format table_n.csv, where n is the table number.
You must use CSV format with commas as the delimiter. Do not use markdown format. Ensure you use the original table headers and content from the image.
Only answer with the CSV content. Dont explain the tables.
An example of the formatting output is as follows:
""" + example


# Stream the LLM response as it is generated
def stream_response(inputs):
    # TextIteratorStreamer yields decoded text chunks from a queue while generation
    # runs in a background thread; the prompt is kept so it can be stripped later
    streamer = TextIteratorStreamer(processor.tokenizer, skip_special_tokens=True)
    generation_kwargs = dict(**inputs, max_new_tokens=2000, do_sample=True, streamer=streamer)
    Thread(target=model.generate, kwargs=generation_kwargs).start()
    for text_chunk in streamer:
        yield text_chunk


# Predict function for the Gradio app
@spaces.GPU  # Run on the GPU provided by Hugging Face ZeroGPU Spaces
def predict(message, image):
    # Prepare the input messages
    messages = [
        {"role": "user", "content": [
            {"type": "image"},  # Specify that an image is provided
            {"type": "text", "text": message}  # Add the user-provided text input
        ]}
    ]

    # Create the input text using the processor's chat template
    input_text = processor.apply_chat_template(messages, add_generation_prompt=True)

    # Process the inputs and move them to the target device
    inputs = processor(images=image, text=input_text, return_tensors="pt").to(device)

    # Accumulate the streamed response, then parse it into CSV files
    full_response = ""
    for response in stream_response(inputs):
        full_response += response
    return extract_and_save_tables(full_response)

# Extract tables and save them to CSV

def clean_full_response(full_response):
    """Cleans the full response by removing the prompt input before the tables."""
    # The part of the prompt input to remove
    message_to_remove = prompt_message
    # Remove the message and return only the tables
    return full_response.replace(message_to_remove, "").strip()

def extract_and_save_tables(full_response):
    """Extracts CSV tables from the model response and saves each table as a separate CSV file."""
    cleaned_response = clean_full_response(full_response)
    files_list = []  # Initialize the list of file names
    tables = cleaned_response.split("Table ")  # Split the response by table sections

    for i, table in enumerate(tables[1:], start=1):  # Start with index 1 for "Table 1"
        table_name = f"table_{i}.csv"  # File name for the current table
        rows = table.strip().splitlines()[1:]  # Drop the "n:" heading line and split the table into rows
        rows = [row.replace('"', '').split(",") for row in rows if row.strip()]  # Clean and split by commas

        # Save the table as a CSV file
        with open(table_name, mode="w", newline='') as file:
            writer = csv.writer(file)
            writer.writerows(rows)
        
        files_list.append(table_name)  # Append the saved file to the list

    return files_list
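
# Quick offline sanity check (a sketch; not executed as part of the app):
#
#   extract_and_save_tables(example)
#   # -> ["table_1.csv", "table_2.csv"], each with a header row and one data row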


# Gradio interface
def gradio_app():
    def process_image(image):
        message = prompt_message
        files = predict(message, image)
        return "Tables extracted and saved as CSV files.", files

    # Input and output components
    image_input = gr.Image(type="pil", label="Upload Image")
    output_text = gr.Textbox(label="Extraction Status")
    file_output = gr.File(label="Download CSV files")

    # Gradio interface
    iface = gr.Interface(
        fn=process_image,
        inputs=[image_input],
        outputs=[output_text, file_output],
        title="Table Extractor and CSV Converter",
        description="Upload an image to extract tables and download CSV files.",
        allow_flagging="never"
    )

    iface.launch(debug=True)
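
    # When running outside of Spaces (e.g. locally or in Colab), launching with
    # share=True would expose a temporary public URL (optional):
    #   iface.launch(debug=True, share=True)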


# Call the Gradio app function to launch the app
gradio_app()