romain130492 committed on
Commit
c673e9e
·
verified ·
1 Parent(s): 10889be

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +37 -37
app.py CHANGED
@@ -14,6 +14,7 @@ import torch
14
  from PIL import Image
15
  from transformers import DetrImageProcessor, TableTransformerForObjectDetection
16
  from paddleocr import PaddleOCR
 
17
 
18
  ocr = PaddleOCR(use_angle_cls=True, lang="en", use_gpu=False, ocr_version='PP-OCRv4')
19
 
@@ -498,6 +499,30 @@ class TableExtractionPipeline():
498
  c3.markdown(href, unsafe_allow_html=True)
499
 
500
  return df
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
501
 
502
  async def start_process(self, image_path: str, TD_THRESHOLD, TSR_THRESHOLD,
503
  OCR_THRESHOLD, padd_top, padd_left, padd_bottom,
@@ -575,42 +600,17 @@ class TableExtractionPipeline():
575
  # st.write('Either incorrectly identified table or no table, to debug remove try/except')
576
  # break
577
  # break
578
-
579
-
580
  if __name__ == "__main__":
 
 
 
 
 
 
 
 
 
 
 
 
581
 
582
- st_up, st_lang = st.columns((1, 1))
583
- img_name = st_up.file_uploader("Upload an image with table(s)")
584
- lang = st_lang.selectbox('Language', ('en', 'japan'))
585
- reload_ocr(lang)
586
-
587
- st1, st2, st3 = st.columns((1, 1, 1))
588
- TD_th = st1.slider('Table detection threshold', 0.0, 1.0, 0.8)
589
- TSR_th = st2.slider('Table structure recognition threshold', 0.0, 1.0, 0.7)
590
- OCR_th = st3.slider("Text Probs Threshold", 0.0, 1.0, 0.5)
591
-
592
- st1, st2, st3, st4 = st.columns((1, 1, 1, 1))
593
-
594
- padd_top = st1.slider('Padding top', 0, 200, 90)
595
- padd_left = st2.slider('Padding left', 0, 200, 40)
596
- padd_right = st3.slider('Padding right', 0, 200, 40)
597
- padd_bottom = st4.slider('Padding bottom', 0, 200, 90)
598
-
599
- te = TableExtractionPipeline()
600
- # for img in image_list:
601
- if img_name is not None:
602
- asyncio.run(
603
- te.start_process(img_name,
604
- TD_THRESHOLD=TD_th,
605
- TSR_THRESHOLD=TSR_th,
606
- OCR_THRESHOLD=OCR_th,
607
- padd_top=padd_top,
608
- padd_left=padd_left,
609
- padd_bottom=padd_bottom,
610
- padd_right=padd_right,
611
- delta_xmin=10, # add offset to the left of the table
612
- delta_ymin=3, # add offset to the bottom of the table
613
- delta_xmax=10, # add offset to the right of the table
614
- delta_ymax=3, # add offset to the top of the table
615
- expand_rowcol_bbox_top=0,
616
- expand_rowcol_bbox_bottom=0))
 
14
  from PIL import Image
15
  from transformers import DetrImageProcessor, TableTransformerForObjectDetection
16
  from paddleocr import PaddleOCR
17
+ import gradio as gr
18
 
19
  ocr = PaddleOCR(use_angle_cls=True, lang="en", use_gpu=False, ocr_version='PP-OCRv4')
20
 
 
499
  c3.markdown(href, unsafe_allow_html=True)
500
 
501
  return df
502
+ def extract_table_json(img, td_th=0.8, tsr_th=0.7, ocr_th=0.5,
503
+ pad_top=90, pad_left=40, pad_bottom=90, pad_right=40):
504
+ # Convert the uploaded PIL Image to a temp file path
505
+ img.save("/tmp/input.png")
506
+ # Call your async pipeline and grab the DataFrame or JSON
507
+ result = asyncio.run(
508
+ te.start_process(
509
+ "/tmp/input.png",
510
+ TD_THRESHOLD=td_th,
511
+ TSR_THRESHOLD=tsr_th,
512
+ OCR_THRESHOLD=ocr_th,
513
+ padd_top=pad_top,
514
+ padd_left=pad_left,
515
+ padd_bottom=pad_bottom,
516
+ padd_right=pad_right,
517
+ delta_xmin=10,
518
+ delta_ymin=3,
519
+ delta_xmax=10,
520
+ delta_ymax=3,
521
+ expand_rowcol_bbox_top=0,
522
+ expand_rowcol_bbox_bottom=0
523
+ )
524
+ )
525
+ return result # make sure your start_process returns JSON/dict
526
 
527
  async def start_process(self, image_path: str, TD_THRESHOLD, TSR_THRESHOLD,
528
  OCR_THRESHOLD, padd_top, padd_left, padd_bottom,
 
600
  # st.write('Either incorrectly identified table or no table, to debug remove try/except')
601
  # break
602
  # break
 
 
603
  if __name__ == "__main__":
604
+ iface = gr.Interface(
605
+ fn=extract_table_json,
606
+ inputs=[
607
+ gr.Image(type="pil", label="Page Image"),
608
+ gr.Slider(0,1,0.8, label="Table-Detection Threshold"),
609
+ gr.Slider(0,1,0.7, label="Structure Threshold"),
610
+ gr.Slider(0,1,0.5, label="OCR Threshold"),
611
+ ],
612
+ outputs=gr.JSON(label="Table JSON"),
613
+ title="Table→CSV JSON API"
614
+ )
615
+ iface.launch(server_name="0.0.0.0", server_port=7860)
616