ruslanmv committed
Commit abfefcf · 1 Parent(s): b429171

First commit

Files changed (3):
  1. README.md +1 -1
  2. app.py +158 -0
  3. requirements.txt +7 -0
README.md CHANGED
@@ -7,7 +7,7 @@ sdk: gradio
  sdk_version: 4.44.1
  app_file: app.py
  pinned: false
- short_description: Extract tables from images
+ short_description: Extract tables from images to CSV
  ---

  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,158 @@
+ # Package installation (when running outside a Space)
+ #!pip install git+https://github.com/huggingface/transformers.git
+ #!pip install torch accelerate bitsandbytes sentencepiece pillow
+ #!pip install spaces
+ import gradio as gr
+ import os
+ import csv
+ import torch
+ import spaces  # Needed for the @spaces.GPU decorator on Hugging Face Spaces
+ from transformers import AutoProcessor, MllamaForConditionalGeneration, TextStreamer
+ from PIL import Image
+
+ # Check if we're running in a Hugging Face Space and if SPACES_ZERO_GPU is enabled
+ IS_SPACES_ZERO = os.environ.get("SPACES_ZERO_GPU", "0") == "1"
+ IS_SPACE = os.environ.get("SPACE_ID", None) is not None
+ IS_GDRIVE = False  # Set to True to load the model from Google Drive instead of the Hub
+
+ # Determine the device (GPU if available, else CPU)
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ LOW_MEMORY = os.getenv("LOW_MEMORY", "0") == "1"
+ print(f"Using device: {device}")
+ print(f"Low memory mode: {LOW_MEMORY}")
+
+ # Get Hugging Face token from environment variables
+ HF_TOKEN = os.environ.get('HF_TOKEN')
+
+ # Define the model name
+ model_name = "Llama-3.2-11B-Vision-Instruct"
+ if IS_GDRIVE:
+     # Define the path to the model directory in your Google Drive
+     model_path = "/content/drive/MyDrive/models/" + model_name
+     model = MllamaForConditionalGeneration.from_pretrained(
+         model_path,
+         torch_dtype=torch.bfloat16,
+         device_map="auto",
+     )
+     processor = AutoProcessor.from_pretrained(model_path)
+ else:
+     model_name = "ruslanmv/" + model_name
+     model = MllamaForConditionalGeneration.from_pretrained(
+         model_name,
+         token=HF_TOKEN,
+         torch_dtype=torch.bfloat16,
+         device_map="auto",
+     )
+     processor = AutoProcessor.from_pretrained(model_name, token=HF_TOKEN)
+
+ # Tie the model weights to ensure the model is properly loaded
+ if hasattr(model, "tie_weights"):
+     model.tie_weights()
+
+ example = '''Table 1:
+ header1,header2,header3
+ value1,value2,value3
+
+ Table 2:
+ header1,header2,header3
+ value1,value2,value3
+ '''
+
+ prompt_message = """Please extract all tables from the image and generate CSV files.
+ Each table should be separated using the format table_n.csv, where n is the table number.
+ You must use CSV format with commas as the delimiter. Do not use markdown format. Ensure you use the original table headers and content from the image.
+ Only answer with the CSV content. Don't explain the tables.
+ An example of the formatting output is as follows:
+ """ + example
+
+ # Stream LLM response generator
+ def stream_response(inputs):
+     streamer = TextStreamer(tokenizer=processor.tokenizer)
+     for token in model.generate(**inputs, max_new_tokens=2000, do_sample=True, streamer=streamer):
+         yield processor.decode(token, skip_special_tokens=True)
+
+ # Predict function for the Gradio app
+ @spaces.GPU  # Use the free GPU provided by Hugging Face Spaces
+ def predict(message, image):
+     # Prepare the input messages
+     messages = [
+         {"role": "user", "content": [
+             {"type": "image"},  # Specify that an image is provided
+             {"type": "text", "text": message}  # Add the user-provided text input
+         ]}
+     ]
+
+     # Create the input text using the processor's chat template
+     input_text = processor.apply_chat_template(messages, add_generation_prompt=True)
+
+     # Process the inputs and move to the appropriate device
+     inputs = processor(image, input_text, return_tensors="pt").to(device)
+
+     # Collect the streamed response, then extract and save the tables
+     full_response = ""
+     for response in stream_response(inputs):
+         full_response += response
+     return extract_and_save_tables(full_response)
+
+ # Extract tables and save them to CSV
+ def clean_full_response(full_response):
+     """Cleans the full response by removing the prompt input before the tables."""
+     # The part of the prompt input to remove
+     message_to_remove = prompt_message
+     # Remove the message and return only the tables
+     return full_response.replace(message_to_remove, "").strip()
+
+ def extract_and_save_tables(full_response):
+     """Extracts CSV tables from the cleaned response string and saves them as separate files."""
+     cleaned_response = clean_full_response(full_response)
+     files_list = []  # Initialize the list of file names
+     tables = cleaned_response.split("Table ")  # Split the response by table sections
+
+     for i, table in enumerate(tables[1:], start=1):  # Start with index 1 for "Table 1"
+         table_name = f"table_{i}.csv"  # File name for the current table
+         rows = table.strip().splitlines()[1:]  # Remove the "Table n:" line and split into rows
+         rows = [row.replace('"', '').split(",") for row in rows if row.strip()]  # Clean and split by commas
+
+         # Save the table as a CSV file
+         with open(table_name, mode="w", newline='') as file:
+             writer = csv.writer(file)
+             writer.writerows(rows)
+
+         files_list.append(table_name)  # Append the saved file to the list
+
+     return files_list
+
+ # Gradio interface
+ def gradio_app():
+     def process_image(image):
+         message = prompt_message
+         files = predict(message, image)
+         return "Tables extracted and saved as CSV files.", files
+
+     # Input and output components
+     image_input = gr.Image(type="pil", label="Upload Image")
+     # message_input = gr.Textbox(lines=2, placeholder="Enter your message", value=message)
+     output_text = gr.Textbox(label="Extraction Status")
+     file_output = gr.File(label="Download CSV files")
+
+     iface = gr.Interface(
+         fn=process_image,
+         inputs=[image_input],
+         outputs=[output_text, file_output],
+         title="Table Extractor and CSV Converter",
+         description="Upload an image to extract tables and download CSV files.",
+         allow_flagging="never"
+     )
+
+     iface.launch(debug=True)
+
+ # Call the Gradio app function to launch the app
+ gradio_app()
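
A note on stream_response: TextStreamer prints tokens to stdout as they are generated, while the loop itself yields each fully decoded output sequence (prompt included), which predict concatenates and hands to the table extractor. If true token-by-token streaming into the UI were wanted instead, the usual transformers pattern is TextIteratorStreamer with generation on a background thread; a minimal sketch under that assumption, reusing the model and processor loaded above (stream_tokens is a hypothetical replacement, not part of this commit):

from threading import Thread
from transformers import TextIteratorStreamer

def stream_tokens(inputs):
    # Hypothetical alternative to stream_response: yields decoded text
    # chunks as they arrive, skipping the prompt echo
    streamer = TextIteratorStreamer(processor.tokenizer, skip_prompt=True, skip_special_tokens=True)
    kwargs = dict(**inputs, max_new_tokens=2000, do_sample=True, streamer=streamer)
    Thread(target=model.generate, kwargs=kwargs).start()
    for chunk in streamer:
        yield chunk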
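
For reference, the table-splitting logic in extract_and_save_tables can be exercised without the model, since prompt_message pins down the response format. A minimal standalone sketch of that parsing step, run against the same example text app.py embeds in the prompt:

import csv

response = '''Table 1:
header1,header2,header3
value1,value2,value3

Table 2:
header1,header2,header3
value1,value2,value3
'''

# Same splitting logic as extract_and_save_tables: split on "Table ",
# drop the "n:" heading line, then write each table as table_n.csv
for i, table in enumerate(response.split("Table ")[1:], start=1):
    rows = [row.split(",") for row in table.strip().splitlines()[1:] if row.strip()]
    with open(f"table_{i}.csv", mode="w", newline="") as f:
        csv.writer(f).writerows(rows)

This writes table_1.csv and table_2.csv, each containing the header row and one value row.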
requirements.txt ADDED
@@ -0,0 +1,7 @@
+ gradio
+ git+https://github.com/huggingface/transformers.git
+ torch
+ accelerate
+ bitsandbytes
+ sentencepiece
+ Pillow