keldrenloy committed on
Commit
a1a6296
·
1 Parent(s): 17dcf46

Upload 2 files

Browse files
Files changed (2) hide show
  1. app.py +142 -0
  2. requirement.txt +263 -0
app.py ADDED
@@ -0,0 +1,142 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ from datasets import load_dataset, ClassLabel
4
+ import os
5
+ from transformers import LayoutLMv3ForTokenClassification, LayoutLMv3Processor,LayoutLMv3FeatureExtractor
6
+ import pytesseract
7
+ import numpy as np
8
+ from PIL import ImageDraw, ImageFont
9
+
10
# Runtime bootstrap: install the CPU build of PyTorch, the Tesseract OCR
# binary and its Python wrapper by shelling out when the script starts.
os.system('pip3 install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cpu')
# `-y` answers apt's confirmation prompt; without it the install aborts when
# the script runs with no interactive terminal attached.
os.system('sudo apt-get install -y tesseract-ocr')
os.system('pip install -q pytesseract')
print("pytesseract:",pytesseract.__version__)

# Sample images offered as clickable examples in the Gradio interface.
examples = [['./examples/example1.png'],['./examples/example2.png'],['./examples/example3.png']]
# Training split of the CORD dataset; `predict` passes it to
# `convert_l2n_n2l` to derive the id -> label mapping.
dataset = load_dataset("nielsr/cord-layoutlmv3")['train']
17
+
18
def get_label_list(labels):
    """Return the sorted list of distinct labels found across all sequences.

    Args:
        labels: Iterable of label sequences; each sequence is an iterable of
            hashable, mutually comparable labels.

    Returns:
        A new list containing every distinct label exactly once, in
        ascending order. Empty input yields an empty list.
    """
    unique_labels = set()
    for sequence in labels:
        # In-place update avoids allocating a brand-new set per sequence.
        unique_labels.update(sequence)
    return sorted(unique_labels)
25
+
26
def convert_l2n_n2l(dataset):
    """Build the label vocabulary and id <-> label tables for ``ner_tags``.

    Args:
        dataset: Object exposing ``features`` (a mapping whose ``"ner_tags"``
            entry has a ``feature`` attribute) and column access via
            ``dataset["ner_tags"]``.

    Returns:
        A 4-tuple ``(label_list, id2label, label2id, num_labels)`` where
        ``id2label`` maps integer ids to labels, ``label2id`` is the inverse
        mapping, and ``num_labels`` is ``len(label_list)``.
    """
    label_column_name = "ner_tags"
    label_feature = dataset.features[label_column_name].feature

    if isinstance(label_feature, ClassLabel):
        # ClassLabel already carries the label names, ordered by id.
        label_list = label_feature.names
    else:
        # Only ClassLabel is known to expose `.names`, so it must not be read
        # before the type check; otherwise this fallback could never run.
        label_list = get_label_list(dataset[label_column_name])

    id2label = dict(enumerate(label_list))
    label2id = {label: idx for idx, label in enumerate(label_list)}

    return label_list, id2label, label2id, len(label_list)
40
+
41
# Outline/text colour for each entity label that gets highlighted.
_LABEL_TO_COLOUR = {
    'MENU.PRICE': 'blue',
    'MENU.NM': 'green',
    'other': 'green',
    'MENU.TOTAL_PRICE': 'red',
}


def label_colour(label):
    """Return the drawing colour for ``label``.

    Args:
        label: Entity label with its IOB prefix already removed.

    Returns:
        The colour name as a string, or ``None`` when the label has no
        colour assigned.
    """
    # dict.get already yields None for unknown keys, so no membership test
    # is needed beforehand.
    return _LABEL_TO_COLOUR.get(label)
48
+
49
def iob_to_label(label):
    """Drop the first two characters of a tag and return what remains.

    Args:
        label: Tag string such as ``"B-MENU.NM"`` or ``"O"``.

    Returns:
        The tag without its two-character prefix, or ``'other'`` when
        nothing is left after removing it.
    """
    stripped = label[2:]
    return stripped if stripped else 'other'
54
+
55
def convert_results(words,tags):
    """Group tagged words into ``(entity_text, entity_type)`` pairs.

    Both IOB tags (``B-``/``I-``) and IOBES tags (``B-``/``I-``/``E-``/``S-``)
    are understood. A multi-word entity is emitted when it is closed by an
    ``E-`` tag, by the start of another entity, by an untagged word, or by
    the end of the input, so unterminated spans are neither lost nor merged
    into the following entity.

    Args:
        words: Iterable of word strings.
        tags: Iterable of tag strings, paired positionally with ``words``;
            surplus items on either side are ignored.

    Returns:
        A set of ``(text, entity_type)`` tuples, where ``text`` is the
        entity's words joined by single spaces.
    """
    ents = set()
    parts = []          # words of the entity currently being assembled
    current_type = None  # entity type of the span held in `parts`

    def flush():
        # Emit the pending span (if any) and start afresh.
        if parts:
            ents.add((" ".join(parts), current_type))
            parts.clear()

    for word, tag in zip(words, tags):
        # Split on the first hyphen only so entity types may contain hyphens.
        position, sep, ent_type = tag.partition("-")
        if not sep:
            # "O" (or any tag without a position prefix) ends the open span.
            flush()
        elif position == "S":
            flush()
            ents.add((word, ent_type))
        elif position == "B":
            flush()
            parts.append(word)
            current_type = ent_type
        elif position in ("I", "E"):
            if parts and ent_type != current_type:
                # A continuation of a different type starts a new span.
                flush()
            parts.append(word)
            current_type = ent_type
            if position == "E":
                flush()

    flush()
    return ents
74
+
75
def unnormalize_box(bbox, width, height):
    """Scale a box given in thousandths of the image size to pixel units.

    Args:
        bbox: Indexable with at least four numbers; items 0 and 2 are scaled
            by ``width``, items 1 and 3 by ``height``.
        width: Scale applied to the horizontal coordinates.
        height: Scale applied to the vertical coordinates.

    Returns:
        A four-element list of the scaled coordinates.
    """
    # Horizontal and vertical extents alternate in the same order as the
    # coordinates they scale.
    extents = (width, height, width, height)
    return [extents[i] * (bbox[i] / 1000) for i in range(4)]
82
+
83
# Model/processor pairs keyed by device name, so the weights are loaded once
# per process instead of on every request.
_PREDICT_CACHE = {}


def _get_model_and_processor(device):
    """Return the fine-tuned model (moved to ``device``) and its processor."""
    key = str(device)
    if key not in _PREDICT_CACHE:
        model = LayoutLMv3ForTokenClassification.from_pretrained("keldrenloy/layoutlmv3cordfinetuned").to(device) #add your model directory here
        processor = LayoutLMv3Processor.from_pretrained("microsoft/layoutlmv3-base")
        _PREDICT_CACHE[key] = (model, processor)
    return _PREDICT_CACHE[key]


def predict(image):
    """Run token classification on ``image`` and return boxes with labels.

    Args:
        image: Image object exposing ``size`` as ``(width, height)`` and
            accepted by ``LayoutLMv3Processor``.

    Returns:
        A tuple ``(true_boxes, true_predictions)``: the boxes rescaled to the
        image's pixel size and the predicted label names, both restricted to
        tokens whose offset starts at 0 (continuation pieces are dropped).
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model, processor = _get_model_and_processor(device)
    _, id2label, _, _ = convert_l2n_n2l(dataset)
    width, height = image.size

    encoding_inputs = processor(image,return_offsets_mapping=True, return_tensors="pt",truncation = True)
    # The offsets are only needed for filtering below; the model must not
    # receive them as an input.
    offset_mapping = encoding_inputs.pop('offset_mapping')
    for k,v in encoding_inputs.items():
        encoding_inputs[k] = v.to(device)

    with torch.no_grad():
        outputs = model(**encoding_inputs)

    predictions = outputs.logits.argmax(-1).squeeze().tolist()
    token_boxes = encoding_inputs.bbox.squeeze().tolist()

    # A non-zero start offset marks a token that continues a previous piece.
    is_subword = np.array(offset_mapping.squeeze().tolist())[:,0] != 0
    true_predictions = [id2label[pred] for idx, pred in enumerate(predictions) if not is_subword[idx]]
    true_boxes = [unnormalize_box(box, width, height) for idx, box in enumerate(token_boxes) if not is_subword[idx]]

    return true_boxes, true_predictions
106
+
107
def text_extraction(image):
    """Return the first entry of the ``words`` field produced for ``image``.

    Args:
        image: Image passed to a freshly constructed
            ``LayoutLMv3FeatureExtractor``.

    Returns:
        ``encoding['words'][0]`` from the feature extractor's output.
    """
    return LayoutLMv3FeatureExtractor()(image, return_tensors="pt")['words'][0]
111
+
112
def image_render(image):
    """Annotate ``image`` with predicted entity boxes and collect entities.

    Args:
        image: Image accepted by ``ImageDraw.Draw``, ``predict`` and
            ``text_extraction``. It is drawn on in place.

    Returns:
        A tuple ``(image, extracted_words)``: the annotated image and the set
        of ``(text, entity_type)`` pairs built by ``convert_results``.
    """
    # Read the words and run the model while the image is still pristine;
    # extracting text after drawing would also pick up the overlaid labels.
    words = text_extraction(image)
    print(words)
    true_boxes,true_predictions = predict(image)

    draw = ImageDraw.Draw(image)
    font = ImageFont.load_default()
    for prediction, box in zip(true_predictions, true_boxes):
        predicted_label = iob_to_label(prediction)
        colour = label_colour(predicted_label)
        draw.rectangle(box, outline=colour)
        draw.text((box[0]+10, box[1]-10), text=predicted_label, fill=colour, font=font)

    # NOTE(review): `words` and `true_predictions` are paired positionally;
    # if the predictions also contain entries for special tokens the pairing
    # is shifted - confirm against the processor's output.
    extracted_words = convert_results(words,true_predictions)

    return image,extracted_words
127
+
128
# Fix the height of the input and output image panes.
css = """.output_image, .input_image {height: 600px !important}"""

# Gradio UI: one image in, the annotated image plus the extracted entities
# (rendered as text) out. `gr.Image` and allow_flagging="manual" replace the
# deprecated `gr.inputs` / `gr.outputs` namespaces and the boolean flag.
demo = gr.Interface(
    fn=image_render,
    inputs=gr.Image(type="pil"),
    outputs=[gr.Image(type="pil", label="annotated image"), 'text'],
    css=css,
    examples=examples,
    allow_flagging="manual",
    flagging_options=["incorrect", "correct"],
    flagging_callback=gr.CSVLogger(),
    flagging_dir="flagged",
)

if __name__ == "__main__":
    demo.launch()
requirement.txt ADDED
@@ -0,0 +1,263 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ absl-py==1.2.0
2
+ accelerate==0.12.0
3
+ aiohttp==3.8.1
4
+ aiosignal==1.2.0
5
+ alembic==1.8.1
6
+ analytics-python==1.4.0
7
+ anyio==3.6.1
8
+ appdirs==1.4.4
9
+ argon2-cffi==21.3.0
10
+ argon2-cffi-bindings==21.2.0
11
+ asgiref==3.5.2
12
+ asttokens==2.0.8
13
+ async-timeout==4.0.2
14
+ attr==0.3.1
15
+ attrs==22.1.0
16
+ azure-core==1.25.1
17
+ azure-storage-blob==12.13.1
18
+ backcall==0.2.0
19
+ backoff==1.10.0
20
+ bcrypt==4.0.0
21
+ beautifulsoup4==4.11.1
22
+ bleach==5.0.1
23
+ boto==2.49.0
24
+ boto3==1.16.63
25
+ botocore==1.19.63
26
+ boxing==0.1.4
27
+ cachetools==4.2.4
28
+ certifi==2022.6.15.1
29
+ cffi==1.15.1
30
+ charset-normalizer==2.0.12
31
+ click==8.1.3
32
+ cloudpickle==2.2.0
33
+ colorama==0.4.5
34
+ contourpy==1.0.5
35
+ coreapi==2.3.3
36
+ coreschema==0.0.4
37
+ cryptography==38.0.1
38
+ cuda-python==11.7.1
39
+ cycler==0.11.0
40
+ Cython==0.29.32
41
+ databricks-cli==0.17.3
42
+ datasets==2.4.0
43
+ debugpy==1.6.3
44
+ decorator==5.1.1
45
+ defusedxml==0.7.1
46
+ Deprecated==1.2.13
47
+ dill==0.3.5.1
48
+ Django==3.1.14
49
+ django-annoying==0.10.6
50
+ django-cors-headers==3.6.0
51
+ django-debug-toolbar==3.2.1
52
+ django-extensions==3.1.0
53
+ django-filter==2.4.0
54
+ django-model-utils==4.1.1
55
+ django-ranged-fileresponse==0.1.2
56
+ django-rest-swagger==2.2.0
57
+ django-rq==2.5.1
58
+ django-user-agents==0.4.0
59
+ djangorestframework==3.13.1
60
+ docker==6.0.0
61
+ docker-pycreds==0.4.0
62
+ docopt==0.6.2
63
+ drf-dynamic-fields==0.3.0
64
+ drf-flex-fields==0.9.5
65
+ drf-generators==0.3.0
66
+ drf-yasg==1.20.0
67
+ entrypoints==0.4
68
+ executing==1.0.0
69
+ expiringdict==1.1.4
70
+ fastapi==0.85.0
71
+ fastjsonschema==2.16.2
72
+ ffmpy==0.3.0
73
+ filelock==3.8.0
74
+ Flask==2.2.2
75
+ fonttools==4.37.3
76
+ frozenlist==1.3.1
77
+ fsspec==2022.8.2
78
+ gitdb==4.0.9
79
+ GitPython==3.1.27
80
+ google-api-core==1.31.5
81
+ google-auth==1.35.0
82
+ google-auth-oauthlib==0.4.6
83
+ google-cloud-appengine-logging==1.1.0
84
+ google-cloud-audit-log==0.2.0
85
+ google-cloud-core==1.5.0
86
+ google-cloud-logging==2.7.2
87
+ google-cloud-storage==1.29.0
88
+ google-resumable-media==0.5.1
89
+ googleapis-common-protos==1.52.0
90
+ gradio==3.3.1
91
+ greenlet==1.1.3
92
+ grpc-google-iam-v1==0.12.3
93
+ grpcio==1.48.1
94
+ h11==0.12.0
95
+ htmlmin==0.1.12
96
+ httpcore==0.15.0
97
+ httpx==0.23.0
98
+ huggingface-hub==0.9.1
99
+ idna==3.3
100
+ importlib-metadata==4.12.0
101
+ inflection==0.5.1
102
+ ipykernel==6.15.2
103
+ ipython==8.5.0
104
+ ipython-genutils==0.2.0
105
+ ipywidgets==8.0.2
106
+ isodate==0.6.1
107
+ itsdangerous==2.1.2
108
+ itypes==1.2.0
109
+ jedi==0.18.1
110
+ Jinja2==3.1.2
111
+ jmespath==0.10.0
112
+ joblib==1.2.0
113
+ jsonschema==3.2.0
114
+ jupyter-core==4.11.1
115
+ jupyter_client==7.3.5
116
+ jupyterlab-pygments==0.2.2
117
+ jupyterlab-widgets==3.0.3
118
+ kiwisolver==1.4.4
119
+ label-studio==1.5.0.post0
120
+ label-studio-converter==0.0.40
121
+ label-studio-tools==0.0.0.dev14
122
+ launchdarkly-server-sdk==7.3.0
123
+ linkify-it-py==1.0.3
124
+ lockfile==0.12.2
125
+ lxml==4.9.1
126
+ Mako==1.2.3
127
+ Markdown==3.4.1
128
+ markdown-it-py==2.1.0
129
+ MarkupSafe==2.1.1
130
+ matplotlib==3.6.0
131
+ matplotlib-inline==0.1.6
132
+ mdit-py-plugins==0.3.0
133
+ mdurl==0.1.2
134
+ mistune==2.0.4
135
+ mlflow==1.29.0
136
+ monotonic==1.6
137
+ msrest==0.7.1
138
+ multidict==6.0.2
139
+ multiprocess==0.70.13
140
+ nbclient==0.6.8
141
+ nbconvert==7.0.0
142
+ nbformat==5.6.0
143
+ nest-asyncio==1.5.5
144
+ nltk==3.6.7
145
+ notebook==6.4.12
146
+ numpy==1.23.3
147
+ nvidia-ml-py3==7.352.0
148
+ oauthlib==3.2.1
149
+ openapi-codec==1.3.2
150
+ ordered-set==4.0.2
151
+ orjson==3.8.0
152
+ packaging==21.3
153
+ pandas==1.3.5
154
+ pandocfilters==1.5.0
155
+ paramiko==2.11.0
156
+ parso==0.8.3
157
+ pathtools==0.1.2
158
+ pickleshare==0.7.5
159
+ Pillow==9.0.1
160
+ pipreqs==0.4.11
161
+ prometheus-client==0.14.1
162
+ prometheus-flask-exporter==0.20.3
163
+ promise==2.3
164
+ prompt-toolkit==3.0.31
165
+ proto-plus==1.22.1
166
+ protobuf==3.19.4
167
+ psutil==5.9.2
168
+ psycopg2-binary==2.9.1
169
+ pure-eval==0.2.2
170
+ pyarrow==9.0.0
171
+ pyasn1==0.4.8
172
+ pyasn1-modules==0.2.8
173
+ pycparser==2.21
174
+ pycryptodome==3.15.0
175
+ pydantic==1.8.2
176
+ pyDeprecate==0.3.2
177
+ pydub==0.25.1
178
+ Pygments==2.13.0
179
+ PyJWT==2.5.0
180
+ PyNaCl==1.5.0
181
+ pyngrok==5.1.0
182
+ pyparsing==3.0.9
183
+ pyRFC3339==1.1
184
+ pyrsistent==0.18.1
185
+ pytesseract==0.3.10
186
+ python-dateutil==2.8.2
187
+ python-multipart==0.0.5
188
+ pytorch-lightning==1.7.5
189
+ pytz==2019.3
190
+ pywin32==304
191
+ pywinpty==2.0.8
192
+ PyYAML==6.0
193
+ pyzmq==23.2.1
194
+ querystring-parser==1.2.4
195
+ redis==4.3.4
196
+ regex==2022.9.11
197
+ requests==2.27.1
198
+ requests-oauthlib==1.3.1
199
+ responses==0.18.0
200
+ rfc3986==1.5.0
201
+ rq==1.10.1
202
+ rsa==4.9
203
+ ruamel.yaml==0.17.21
204
+ ruamel.yaml.clib==0.2.6
205
+ rules==2.2
206
+ s3transfer==0.3.7
207
+ scikit-learn==1.1.2
208
+ scipy==1.9.1
209
+ semver==2.13.0
210
+ Send2Trash==1.8.0
211
+ sentry-sdk==1.9.8
212
+ seqeval==1.2.2
213
+ setproctitle==1.3.2
214
+ shortuuid==1.0.9
215
+ simplejson==3.17.6
216
+ six==1.16.0
217
+ smmap==5.0.0
218
+ sniffio==1.3.0
219
+ soupsieve==2.3.2.post1
220
+ SQLAlchemy==1.4.41
221
+ sqlparse==0.4.2
222
+ stack-data==0.5.0
223
+ starlette==0.20.4
224
+ tabulate==0.8.10
225
+ tensorboard==2.10.0
226
+ tensorboard-data-server==0.6.1
227
+ tensorboard-plugin-wit==1.8.1
228
+ terminado==0.15.0
229
+ tesseract==0.1.3
230
+ threadpoolctl==3.1.0
231
+ tinycss2==1.1.1
232
+ tokenizers==0.12.1
233
+ torch==1.12.1+cu113
234
+ torchaudio==0.12.1+cu113
235
+ torchmetrics==0.9.3
236
+ torchvision==0.13.1+cu113
237
+ tornado==6.2
238
+ tqdm==4.64.1
239
+ traitlets==5.3.0
240
+ transformers==4.21.3
241
+ typing_extensions==4.3.0
242
+ tzdata==2022.2
243
+ ua-parser==0.16.1
244
+ uc-micro-py==1.0.1
245
+ ujson==5.5.0
246
+ uritemplate==4.1.1
247
+ urllib3==1.26.12
248
+ user-agents==2.2.0
249
+ uvicorn==0.18.3
250
+ waitress==2.1.2
251
+ wandb==0.13.3
252
+ wcwidth==0.2.5
253
+ webencodings==0.5.1
254
+ websocket-client==1.4.1
255
+ websockets==10.3
256
+ Werkzeug==2.2.2
257
+ widgetsnbextension==4.0.3
258
+ wrapt==1.14.1
259
+ xmljson==0.2.0
260
+ xxhash==3.0.0
261
+ yarg==0.1.9
262
+ yarl==1.8.1
263
+ zipp==3.8.1