HycJack committed on
Commit
be14d95
·
1 Parent(s): 1cdb0e0
Files changed (1) hide show
  1. app.py +166 -2
app.py CHANGED
@@ -1,4 +1,168 @@
 
 
 
 
 
 
1
  import streamlit as st
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
 
3
- x = st.slider('Select a value')
4
- st.write(x, 'squared is', x * x)
 
 
 
 
 
 
 
 
1
+ import os
2
+ os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1" # For some reason, transformers decided to use .isin for a simple op, which is not supported on MPS
3
+
4
+ import io
5
+
6
+ import pandas as pd
7
  import streamlit as st
8
+ from streamlit_drawable_canvas import st_canvas
9
+ import hashlib
10
+ import pypdfium2
11
+
12
+ from texify.inference import batch_inference
13
+ from texify.model.model import load_model
14
+ from texify.model.processor import load_processor
15
+ from texify.output import replace_katex_invalid
16
+ from PIL import Image
17
+
18
+ MAX_WIDTH = 800
19
+ MAX_HEIGHT = 1000
20
+
21
+
22
@st.cache_resource()
def load_model_cached():
    """Return the model produced by ``load_model()``.

    The call is wrapped in ``st.cache_resource`` so the result is cached by
    Streamlit instead of being rebuilt on every call.
    """
    texify_model = load_model()
    return texify_model
25
+
26
+
27
@st.cache_resource()
def load_processor_cached():
    """Return the processor produced by ``load_processor()``.

    The call is wrapped in ``st.cache_resource`` so the result is cached by
    Streamlit instead of being rebuilt on every call.
    """
    texify_processor = load_processor()
    return texify_processor
30
+
31
+
32
@st.cache_data()
def infer_image(pil_image, bbox, temperature):
    """Run inference on one cropped region of an image.

    Args:
        pil_image: Image to crop; must provide ``crop``.
        bbox: Crop box, passed unchanged to ``pil_image.crop``.
        temperature: Forwarded to ``batch_inference`` as ``temperature``.

    Returns:
        The first element of the ``batch_inference`` result, i.e. the output
        for the single cropped region.

    Note:
        Reads the module-level ``model`` and ``processor`` globals.
    """
    region = pil_image.crop(bbox)
    predictions = batch_inference(
        [region],
        model,
        processor,
        temperature=temperature,
    )
    return predictions[0]
37
+
38
+
39
def open_pdf(pdf_file):
    """Open an uploaded PDF as a ``pypdfium2.PdfDocument``.

    Args:
        pdf_file: Object exposing ``getvalue()`` that returns the PDF bytes.

    Returns:
        A ``pypdfium2.PdfDocument`` built from an in-memory ``io.BytesIO``
        copy of those bytes.
    """
    pdf_bytes = pdf_file.getvalue()
    return pypdfium2.PdfDocument(io.BytesIO(pdf_bytes))
42
+
43
+
44
@st.cache_data()
def get_page_image(pdf_file, page_num, dpi=96):
    """Render a single PDF page and return it as an RGB image.

    Args:
        pdf_file: Uploaded PDF, opened through ``open_pdf``.
        page_num: Page number counted from 1; converted to a 0-based index
            for the renderer.
        dpi: Target resolution; the render scale is ``dpi / 72``.

    Returns:
        The rendered page converted to mode ``"RGB"``.
    """
    document = open_pdf(pdf_file)
    zero_based_index = page_num - 1
    rendered_pages = list(
        document.render(
            pypdfium2.PdfBitmap.to_pil,
            page_indices=[zero_based_index],
            scale=dpi / 72,
        )
    )
    # Only one index was requested, so the first result is the page we want.
    return rendered_pages[0].convert("RGB")
55
+
56
+
57
@st.cache_data()
def get_uploaded_image(in_file):
    """Open an uploaded image file and return it converted to RGB."""
    uploaded = Image.open(in_file)
    return uploaded.convert("RGB")
60
+
61
+
62
def resize_image(pil_image):
    """Fit ``pil_image`` within ``MAX_WIDTH`` x ``MAX_HEIGHT`` via ``thumbnail``.

    The image object itself is modified; nothing is returned. A ``None``
    image is silently ignored.
    """
    if pil_image is not None:
        max_bounds = (MAX_WIDTH, MAX_HEIGHT)
        pil_image.thumbnail(max_bounds, Image.Resampling.LANCZOS)
66
+
67
+
68
+ @st.cache_data()
69
@st.cache_data()
def page_count(pdf_file):
    """Return the number of pages in the uploaded PDF."""
    return len(open_pdf(pdf_file))
72
+
73
+
74
def get_canvas_hash(pil_image):
    """Return the hexadecimal MD5 digest of ``pil_image.tobytes()``.

    The digest serves as an identifier for the image content, not as a
    security measure.
    """
    digest = hashlib.md5()
    digest.update(pil_image.tobytes())
    return digest.hexdigest()
76
+
77
+
78
+ @st.cache_data()
79
+ def get_image_size(pil_image):
80
+ if pil_image is None:
81
+ return MAX_HEIGHT, MAX_WIDTH
82
+ height, width = pil_image.height, pil_image.width
83
+ return height, width
84
+
85
+
86
st.set_page_config(layout="wide")

top_message = """### Texify

After the model loads, upload an image or a pdf, then draw a box around the equation or text you want to OCR by clicking and dragging. Texify will convert it to Markdown with LaTeX math on the right.

If you have already cropped your image, select "OCR image" in the sidebar instead.
"""

st.markdown(top_message)
# Left column hosts the drawing canvas, right column the OCR output.
col1, col2 = st.columns([.7, .3])

model = load_model_cached()
processor = load_processor_cached()

in_file = st.sidebar.file_uploader("PDF file or image:", type=["pdf", "png", "jpg", "jpeg", "gif", "webp"])
if in_file is None:
    # Nothing to show until the user uploads a file.
    st.stop()

filetype = in_file.type
whole_image = False
if "pdf" in filetype:
    # Store the count under its own name; assigning it to ``page_count``
    # would rebind the helper function of that name to an int.
    num_pages = page_count(in_file)
    page_number = st.sidebar.number_input(f"Page number out of {num_pages}:", min_value=1, value=1, max_value=num_pages)

    pil_image = get_page_image(in_file, page_number)
else:
    pil_image = get_uploaded_image(in_file)
    # Only offered for plain images: OCR the whole upload without drawing a box.
    whole_image = st.sidebar.button("OCR image")

# Resize to max bounds
resize_image(pil_image)

temperature = st.sidebar.slider("Generation temperature:", min_value=0.0, max_value=1.0, value=0.0, step=0.05)

# Key the canvas on the image content so a new image gets a fresh canvas.
canvas_hash = get_canvas_hash(pil_image) if pil_image is not None else "canvas"
122
+
123
with col1:
    # Size the canvas to match the image being annotated.
    canvas_height, canvas_width = get_image_size(pil_image)

    # Create a canvas component on which the user draws rectangles.
    canvas_result = st_canvas(
        fill_color="rgba(255, 165, 0, 0.1)",  # Fixed fill color with some opacity
        stroke_width=1,
        stroke_color="#FFAA00",
        background_color="#FFF",
        background_image=pil_image,
        update_streamlit=True,
        height=canvas_height,
        width=canvas_width,
        drawing_mode="rect",
        point_display_radius=0,
        key=canvas_hash,
    )
138
+
139
# Decide which regions to OCR. The "OCR image" button takes precedence and
# selects the full image; otherwise use every rectangle drawn on the canvas.
bbox_list = []
if whole_image:
    bbox_list = [(0, 0, pil_image.width, pil_image.height)]
elif canvas_result.json_data is not None:
    # json_data may be None (e.g. before the canvas has reported any state),
    # so it is only subscripted inside this branch.
    objects = pd.json_normalize(canvas_result.json_data["objects"])
    if not objects.empty:
        # Take an explicit copy so adding columns below does not write to a
        # view of ``objects`` (avoids pandas' SettingWithCopyWarning).
        rects = objects.loc[objects["type"] == "rect", ["left", "top", "width", "height"]].copy()
        # Convert (left, top, width, height) into (left, top, right, bottom).
        rects["right"] = rects["left"] + rects["width"]
        rects["bottom"] = rects["top"] + rects["height"]
        bbox_list = rects[["left", "top", "right", "bottom"]].values.tolist()

if bbox_list:
    with col2:
        inferences = [infer_image(pil_image, bbox, temperature) for bbox in bbox_list]
        # Show the most recently drawn box first, numbered by drawing order.
        for idx, inference in enumerate(reversed(inferences)):
            st.markdown(f"### {len(inferences) - idx}")
            katex_markdown = replace_katex_invalid(inference)
            st.markdown(katex_markdown)
            st.code(inference)
            st.divider()

with col2:
    tips = """
    ### Usage tips
    - Don't make your boxes too small or too large. See the examples and the video in the [README](https://github.com/vikParuchuri/texify) for more info.
    - Texify is sensitive to how you draw the box around the text you want to OCR. If you get bad results, try selecting a slightly different box, or splitting the box into multiple.
    - You can try changing the temperature value on the left if you don't get good results. This controls how "creative" the model is.
    - Sometimes KaTeX won't be able to render an equation (red error text), but it will still be valid LaTeX. You can copy the LaTeX and render it elsewhere.
    """
    st.markdown(tips)