krishnapal2308 committed on
Commit
0bd5bed
·
1 Parent(s): c66bff2

Initial Commit

Browse files
.idea/.gitignore ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ # Default ignored files
2
+ /shelf/
3
+ /workspace.xml
.idea/DocVQA-Sanctum.iml ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ <?xml version="1.0" encoding="UTF-8"?>
2
+ <module type="PYTHON_MODULE" version="4">
3
+ <component name="NewModuleRootManager">
4
+ <content url="file://$MODULE_DIR$" />
5
+ <orderEntry type="inheritedJdk" />
6
+ <orderEntry type="sourceFolder" forTests="false" />
7
+ </component>
8
+ </module>
.idea/inspectionProfiles/Project_Default.xml ADDED
@@ -0,0 +1,104 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <component name="InspectionProjectProfileManager">
2
+ <profile version="1.0">
3
+ <option name="myName" value="Project Default" />
4
+ <inspection_tool class="HtmlUnknownTag" enabled="true" level="WARNING" enabled_by_default="true">
5
+ <option name="myValues">
6
+ <value>
7
+ <list size="7">
8
+ <item index="0" class="java.lang.String" itemvalue="nobr" />
9
+ <item index="1" class="java.lang.String" itemvalue="noembed" />
10
+ <item index="2" class="java.lang.String" itemvalue="comment" />
11
+ <item index="3" class="java.lang.String" itemvalue="noscript" />
12
+ <item index="4" class="java.lang.String" itemvalue="embed" />
13
+ <item index="5" class="java.lang.String" itemvalue="script" />
14
+ <item index="6" class="java.lang.String" itemvalue="style" />
15
+ </list>
16
+ </value>
17
+ </option>
18
+ <option name="myCustomValuesEnabled" value="true" />
19
+ </inspection_tool>
20
+ <inspection_tool class="PyPackageRequirementsInspection" enabled="true" level="WARNING" enabled_by_default="true">
21
+ <option name="ignoredPackages">
22
+ <value>
23
+ <list size="67">
24
+ <item index="0" class="java.lang.String" itemvalue="absl-py" />
25
+ <item index="1" class="java.lang.String" itemvalue="networkx" />
26
+ <item index="2" class="java.lang.String" itemvalue="huggingface-hub" />
27
+ <item index="3" class="java.lang.String" itemvalue="PyYAML" />
28
+ <item index="4" class="java.lang.String" itemvalue="gast" />
29
+ <item index="5" class="java.lang.String" itemvalue="MarkupSafe" />
30
+ <item index="6" class="java.lang.String" itemvalue="numpy" />
31
+ <item index="7" class="java.lang.String" itemvalue="pyasn1" />
32
+ <item index="8" class="java.lang.String" itemvalue="requests" />
33
+ <item index="9" class="java.lang.String" itemvalue="Jinja2" />
34
+ <item index="10" class="java.lang.String" itemvalue="fsspec" />
35
+ <item index="11" class="java.lang.String" itemvalue="pyasn1-modules" />
36
+ <item index="12" class="java.lang.String" itemvalue="safetensors" />
37
+ <item index="13" class="java.lang.String" itemvalue="certifi" />
38
+ <item index="14" class="java.lang.String" itemvalue="keras" />
39
+ <item index="15" class="java.lang.String" itemvalue="urllib3" />
40
+ <item index="16" class="java.lang.String" itemvalue="itsdangerous" />
41
+ <item index="17" class="java.lang.String" itemvalue="Markdown" />
42
+ <item index="18" class="java.lang.String" itemvalue="sympy" />
43
+ <item index="19" class="java.lang.String" itemvalue="Flask" />
44
+ <item index="20" class="java.lang.String" itemvalue="blinker" />
45
+ <item index="21" class="java.lang.String" itemvalue="tokenizers" />
46
+ <item index="22" class="java.lang.String" itemvalue="libclang" />
47
+ <item index="23" class="java.lang.String" itemvalue="transformers" />
48
+ <item index="24" class="java.lang.String" itemvalue="google-auth-oauthlib" />
49
+ <item index="25" class="java.lang.String" itemvalue="Werkzeug" />
50
+ <item index="26" class="java.lang.String" itemvalue="h5py" />
51
+ <item index="27" class="java.lang.String" itemvalue="tensorboard-data-server" />
52
+ <item index="28" class="java.lang.String" itemvalue="packaging" />
53
+ <item index="29" class="java.lang.String" itemvalue="torch" />
54
+ <item index="30" class="java.lang.String" itemvalue="click" />
55
+ <item index="31" class="java.lang.String" itemvalue="tqdm" />
56
+ <item index="32" class="java.lang.String" itemvalue="termcolor" />
57
+ <item index="33" class="java.lang.String" itemvalue="regex" />
58
+ <item index="34" class="java.lang.String" itemvalue="mpmath" />
59
+ <item index="35" class="java.lang.String" itemvalue="typing_extensions" />
60
+ <item index="36" class="java.lang.String" itemvalue="cachetools" />
61
+ <item index="37" class="java.lang.String" itemvalue="charset-normalizer" />
62
+ <item index="38" class="java.lang.String" itemvalue="grpcio" />
63
+ <item index="39" class="java.lang.String" itemvalue="gTTS" />
64
+ <item index="40" class="java.lang.String" itemvalue="google-auth" />
65
+ <item index="41" class="java.lang.String" itemvalue="idna" />
66
+ <item index="42" class="java.lang.String" itemvalue="referencing" />
67
+ <item index="43" class="java.lang.String" itemvalue="tzdata" />
68
+ <item index="44" class="java.lang.String" itemvalue="kiwisolver" />
69
+ <item index="45" class="java.lang.String" itemvalue="rich" />
70
+ <item index="46" class="java.lang.String" itemvalue="cycler" />
71
+ <item index="47" class="java.lang.String" itemvalue="sniffio" />
72
+ <item index="48" class="java.lang.String" itemvalue="markdown-it-py" />
73
+ <item index="49" class="java.lang.String" itemvalue="attrs" />
74
+ <item index="50" class="java.lang.String" itemvalue="contourpy" />
75
+ <item index="51" class="java.lang.String" itemvalue="jsonschema-specifications" />
76
+ <item index="52" class="java.lang.String" itemvalue="pandas" />
77
+ <item index="53" class="java.lang.String" itemvalue="exceptiongroup" />
78
+ <item index="54" class="java.lang.String" itemvalue="fonttools" />
79
+ <item index="55" class="java.lang.String" itemvalue="rpds-py" />
80
+ <item index="56" class="java.lang.String" itemvalue="toolz" />
81
+ <item index="57" class="java.lang.String" itemvalue="Pygments" />
82
+ <item index="58" class="java.lang.String" itemvalue="matplotlib" />
83
+ <item index="59" class="java.lang.String" itemvalue="anyio" />
84
+ <item index="60" class="java.lang.String" itemvalue="pillow" />
85
+ <item index="61" class="java.lang.String" itemvalue="pytz" />
86
+ <item index="62" class="java.lang.String" itemvalue="uvicorn" />
87
+ <item index="63" class="java.lang.String" itemvalue="pyparsing" />
88
+ <item index="64" class="java.lang.String" itemvalue="jsonschema" />
89
+ <item index="65" class="java.lang.String" itemvalue="pydantic_core" />
90
+ <item index="66" class="java.lang.String" itemvalue="gradio_client" />
91
+ </list>
92
+ </value>
93
+ </option>
94
+ </inspection_tool>
95
+ <inspection_tool class="PyPep8NamingInspection" enabled="true" level="WEAK WARNING" enabled_by_default="true">
96
+ <option name="ignoredErrors">
97
+ <list>
98
+ <option value="N801" />
99
+ <option value="N813" />
100
+ </list>
101
+ </option>
102
+ </inspection_tool>
103
+ </profile>
104
+ </component>
.idea/inspectionProfiles/profiles_settings.xml ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ <component name="InspectionProjectProfileManager">
2
+ <settings>
3
+ <option name="USE_PROJECT_PROFILE" value="false" />
4
+ <version value="1.0" />
5
+ </settings>
6
+ </component>
.idea/misc.xml ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ <?xml version="1.0" encoding="UTF-8"?>
2
+ <project version="4">
3
+ <component name="ProjectRootManager" version="2" project-jdk-name="Python 3.12 (docvqa_venv)" project-jdk-type="Python SDK" />
4
+ </project>
.idea/modules.xml ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ <?xml version="1.0" encoding="UTF-8"?>
2
+ <project version="4">
3
+ <component name="ProjectModuleManager">
4
+ <modules>
5
+ <module fileurl="file://$PROJECT_DIR$/.idea/DocVQA-Sanctum.iml" filepath="$PROJECT_DIR$/.idea/DocVQA-Sanctum.iml" />
6
+ </modules>
7
+ </component>
8
+ </project>
.idea/vcs.xml ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ <?xml version="1.0" encoding="UTF-8"?>
2
+ <project version="4">
3
+ <component name="VcsDirectoryMappings">
4
+ <mapping directory="" vcs="Git" />
5
+ </component>
6
+ </project>
app.py ADDED
@@ -0,0 +1,47 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import warnings
3
+ import os
4
+ import pix2struct, layoutlm, donut
5
+ warnings.filterwarnings('ignore')
6
+
7
+
8
def process_image_and_generate_output(image, model_selection, question):
    """Answer ``question`` about ``image`` with the selected DocVQA model.

    Args:
        image: Path of the uploaded image (the image input is declared with
            ``type='filepath'``), or ``None`` when nothing was uploaded.
        model_selection: ``"LayoutLM"``, ``"Pix2Struct"`` or ``"Donut"``.
        question: Question text, forwarded unchanged to the model module.

    Returns:
        Whatever the chosen module's ``get_result`` returns, or a short
        prompt string when the image or the model choice is missing.
    """
    # The interface declares a single text output, so every path must return
    # exactly one value (the old guard returned a 2-tuple here).
    if image is None:
        return "Please select an image"

    if model_selection == "LayoutLM":
        return layoutlm.get_result(image, question)
    if model_selection == 'Pix2Struct':
        return pix2struct.get_result(image, question)
    if model_selection == 'Donut':
        return donut.get_result(image, question)

    # No radio option picked (or an unknown value): tell the user instead of
    # returning an empty result.
    return "Please choose a model"
24
+
25
+
26
# Example rows shown under the form: the bundled sample image and the same
# question, once for each selectable model.
sample_images = [
    [os.path.join(os.path.dirname(__file__), "images/1.png"), model_name, "What is the NIC Code?"]
    for model_name in ("LayoutLM", "Pix2Struct", "Donut")
]

# Image upload; the callback receives the file path of the uploaded image.
image_input = gr.Image(type='filepath', label="Upload Image")

# Radio buttons for picking the model, plus a free-text question box.
model_selection_input = gr.Radio(
    ["LayoutLM", "Pix2Struct", "Donut"],
    label="Choose Model",
)
question_input = gr.Text(label="Question")

# Wire the three inputs to the handler; its return value fills the single
# "Result" text box.
iface = gr.Interface(
    fn=process_image_and_generate_output,
    inputs=[image_input, model_selection_input, question_input],
    outputs=gr.Text(label="Result"),
    title="DocVQA Sanctum",
    examples=sample_images,
    allow_flagging='never',
)

iface.launch()
donut.py ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re
2
+ from PIL import Image
3
+ from transformers import DonutProcessor, VisionEncoderDecoderModel
4
+
5
+
6
_DONUT_CHECKPOINT = "naver-clova-ix/donut-base-finetuned-docvqa"
# Filled on first use so the checkpoint is loaded once per process instead of
# on every question.
_DONUT_CACHE = {}


def _load_donut():
    """Return the ``(processor, model)`` pair, loading it on first use only."""
    if not _DONUT_CACHE:
        _DONUT_CACHE["processor"] = DonutProcessor.from_pretrained(_DONUT_CHECKPOINT)
        _DONUT_CACHE["model"] = VisionEncoderDecoderModel.from_pretrained(_DONUT_CHECKPOINT)
    return _DONUT_CACHE["processor"], _DONUT_CACHE["model"]


def get_result(image_path, question):
    """Answer ``question`` about the document image using Donut.

    Args:
        image_path: Path of the image file to open.
        question: Question text inserted into the DocVQA task prompt.

    Returns:
        The ``'answer'`` entry of the parsed model output.

    Raises:
        KeyError: If the parsed output has no ``'answer'`` entry.
    """
    processor, model = _load_donut()

    # Release the file handle as soon as the RGB copy has been made
    with Image.open(image_path) as source_image:
        image = source_image.convert('RGB')

    # prepare decoder inputs
    task_prompt = "<s_docvqa><s_question>{user_input}</s_question><s_answer>"
    prompt = task_prompt.replace("{user_input}", question)
    decoder_input_ids = processor.tokenizer(prompt, add_special_tokens=False, return_tensors="pt").input_ids

    pixel_values = processor(image, return_tensors="pt").pixel_values

    outputs = model.generate(
        pixel_values,
        decoder_input_ids=decoder_input_ids,
        max_length=model.decoder.config.max_position_embeddings,
        pad_token_id=processor.tokenizer.pad_token_id,
        eos_token_id=processor.tokenizer.eos_token_id,
        use_cache=True,
        bad_words_ids=[[processor.tokenizer.unk_token_id]],
        return_dict_in_generate=True,
    )

    sequence = processor.batch_decode(outputs.sequences)[0]
    sequence = sequence.replace(processor.tokenizer.eos_token, "").replace(processor.tokenizer.pad_token, "")
    sequence = re.sub(r"<.*?>", "", sequence, count=1).strip()  # remove first task start token

    # Parse once (the old code parsed twice, once only for a debug print)
    parsed = processor.token2json(sequence)
    return parsed['answer']
images/1.png ADDED
layoutlm.py ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import pipeline
2
+
3
+
4
# Filled on first use so the pipeline is built once per process instead of on
# every question.
_LAYOUTLM_CACHE = {}


def _load_pipeline():
    """Return the document-QA pipeline, building it on first use only."""
    if "nlp" not in _LAYOUTLM_CACHE:
        _LAYOUTLM_CACHE["nlp"] = pipeline(
            "document-question-answering",
            model="impira/layoutlm-document-qa",
        )
    return _LAYOUTLM_CACHE["nlp"]


def get_result(image_path, question):
    """Answer ``question`` about the document image using LayoutLM.

    Args:
        image_path: Image location passed straight to the pipeline.
        question: Question text passed straight to the pipeline.

    Returns:
        The ``'answer'`` field of the first result, or ``""`` when the
        pipeline returns no result.
    """
    result = _load_pipeline()(image_path, question)

    # The pipeline is documented to return either a dict or a list of dicts;
    # normalise to a list so both shapes are handled.
    if isinstance(result, dict):
        result = [result]
    if not result:
        return ""

    return result[0]['answer']
packages.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ poppler-utils
2
+ tesseract-ocr
3
+ chromium
4
+ chromium-driver
pix2struct.py ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from PIL import Image
2
+ from transformers import Pix2StructForConditionalGeneration as psg
3
+ from transformers import Pix2StructProcessor as psp
4
+
5
+
6
_PIX2STRUCT_CHECKPOINT = "google/pix2struct-docvqa-large"
# Filled on first use so the checkpoint is loaded once per process instead of
# on every question.
_PIX2STRUCT_CACHE = {}


def _load_pix2struct():
    """Return the ``(model, processor)`` pair, loading it on first use only."""
    if not _PIX2STRUCT_CACHE:
        _PIX2STRUCT_CACHE["model"] = psg.from_pretrained(_PIX2STRUCT_CHECKPOINT)
        _PIX2STRUCT_CACHE["processor"] = psp.from_pretrained(_PIX2STRUCT_CHECKPOINT)
    return _PIX2STRUCT_CACHE["model"], _PIX2STRUCT_CACHE["processor"]


def get_result(image_path, question):
    """Answer ``question`` about the document image using Pix2Struct.

    Args:
        image_path: Path of the image file to open.
        question: Question text passed to the processor with the image.

    Returns:
        The decoded answer string for the single input image, or ``""`` if
        decoding yields nothing.
    """
    model, processor = _load_pix2struct()

    # Release the file handle as soon as the RGB copy has been made
    with Image.open(image_path) as source_image:
        image = source_image.convert("RGB")

    inputs = processor(images=image, text=question, return_tensors="pt")
    predictions = model.generate(**inputs, max_new_tokens=256)
    decoded = processor.batch_decode(predictions, skip_special_tokens=True)

    # batch_decode returns one string per sequence. Only one image is
    # submitted, so return that string rather than the enclosing list, to
    # match the other get_result functions, which return the answer itself.
    return decoded[0] if decoded else ""
requirements.txt ADDED
Binary file (2.71 kB). View file