DeepDiveDev commited on
Commit
ddf7acc
·
verified ·
1 Parent(s): 512f514

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +27 -42
app.py CHANGED
@@ -1,33 +1,33 @@
1
  import torch
2
- import pytesseract
3
  import cv2
4
  import json
5
  import xml.etree.ElementTree as ET
6
- from transformers import TrOCRProcessor, VisionEncoderDecoderModel
7
- from layoutparser import Detectron2LayoutModel
8
- from transformers import GPT2LMHeadModel, GPT2Tokenizer
9
- from stable_baselines3 import PPO
10
 
11
- # Load OCR model
12
  processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-handwritten")
13
  model = VisionEncoderDecoderModel.from_pretrained("microsoft/trocr-base-handwritten")
14
 
 
 
 
 
 
15
  def preprocess_image(image_path):
16
  image = cv2.imread(image_path)
17
  gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
18
  return gray
19
 
 
20
  def extract_text(image_path):
21
  image = preprocess_image(image_path)
22
- text = pytesseract.image_to_string(image)
 
 
23
  return text
24
 
25
- def analyze_layout(image_path):
26
- model = Detectron2LayoutModel("lp://PubLayNet/mask_rcnn_X_101_32x8d_FPN_3x/config")
27
- image = cv2.imread(image_path)
28
- layout = model.detect(image)
29
- return layout
30
-
31
  def generate_machine_readable_format(text, format_type='json'):
32
  if format_type == 'json':
33
  return json.dumps({"content": text})
@@ -38,41 +38,26 @@ def generate_machine_readable_format(text, format_type='json'):
38
  return ET.tostring(root, encoding='unicode')
39
  return text
40
 
41
- # Generative AI Model
42
- GPT2_model = GPT2LMHeadModel.from_pretrained("gpt2")
43
- GPT2_tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
44
-
45
  def generate_structured_output(text):
46
  inputs = GPT2_tokenizer.encode(text, return_tensors="pt")
47
  outputs = GPT2_model.generate(inputs, max_length=500)
48
  return GPT2_tokenizer.decode(outputs[0])
49
 
50
- # Reinforcement Learning for Optimization
51
- class DocumentConversionEnv:
52
- def __init__(self):
53
- self.state = None
54
-
55
- def reset(self):
56
- self.state = "start"
57
- return self.state
58
-
59
- def step(self, action):
60
- reward = 1 if action == "optimize" else -1
61
- self.state = "optimized" if action == "optimize" else "start"
62
- return self.state, reward, False, {}
63
-
64
- env = DocumentConversionEnv()
65
- rl_model = PPO("MlpPolicy", env, verbose=1)
66
- rl_model.learn(total_timesteps=1000)
67
-
68
- def convert_document(image_path, output_format='json'):
69
- text = extract_text(image_path)
70
- layout = analyze_layout(image_path)
71
  structured_output = generate_structured_output(text)
72
  machine_readable_output = generate_machine_readable_format(structured_output, format_type=output_format)
73
  return machine_readable_output
74
 
75
- # Example usage
76
- document_path = "sample_document.png"
77
- converted_document = convert_document(document_path, output_format='json')
78
- print(converted_document)
 
 
 
 
 
 
 
1
  import torch
 
2
  import cv2
3
  import json
4
  import xml.etree.ElementTree as ET
5
+ import gradio as gr
6
+ from transformers import TrOCRProcessor, VisionEncoderDecoderModel, GPT2LMHeadModel, GPT2Tokenizer
 
 
7
 
8
+ # Load OCR model (TrOCR)
9
  processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-handwritten")
10
  model = VisionEncoderDecoderModel.from_pretrained("microsoft/trocr-base-handwritten")
11
 
12
+ # Load GPT-2 model
13
+ GPT2_model = GPT2LMHeadModel.from_pretrained("gpt2")
14
+ GPT2_tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
15
+
16
+ # Image preprocessing
17
  def preprocess_image(image_path):
18
  image = cv2.imread(image_path)
19
  gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
20
  return gray
21
 
22
+ # Extract text using TrOCR (instead of Tesseract)
23
  def extract_text(image_path):
24
  image = preprocess_image(image_path)
25
+ pixel_values = processor(image, return_tensors="pt").pixel_values
26
+ generated_ids = model.generate(pixel_values)
27
+ text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
28
  return text
29
 
30
+ # Generate structured format (JSON/XML)
 
 
 
 
 
31
  def generate_machine_readable_format(text, format_type='json'):
32
  if format_type == 'json':
33
  return json.dumps({"content": text})
 
38
  return ET.tostring(root, encoding='unicode')
39
  return text
40
 
41
+ # GPT-2 for structured output
 
 
 
42
  def generate_structured_output(text):
43
  inputs = GPT2_tokenizer.encode(text, return_tensors="pt")
44
  outputs = GPT2_model.generate(inputs, max_length=500)
45
  return GPT2_tokenizer.decode(outputs[0])
46
 
47
+ # Convert document
48
+ def convert_document(image, output_format='json'):
49
+ text = extract_text(image)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
50
  structured_output = generate_structured_output(text)
51
  machine_readable_output = generate_machine_readable_format(structured_output, format_type=output_format)
52
  return machine_readable_output
53
 
54
+ # Gradio UI
55
+ iface = gr.Interface(
56
+ fn=convert_document,
57
+ inputs=[gr.Image(type="filepath"), gr.Radio(["json", "xml"], label="Output Format")],
58
+ outputs="text",
59
+ title="Document OCR and Conversion",
60
+ description="Extracts text from images and converts it into structured JSON/XML format."
61
+ )
62
+
63
+ iface.launch()