Spaces:
Runtime error
Runtime error
Raymond Weitekamp
commited on
Commit
·
32a0510
1
Parent(s):
b840d3e
progress - need to test live now
Browse files- app.py +122 -74
- requirements.txt +2 -1
- test_app.py +5 -5
app.py
CHANGED
@@ -1,4 +1,6 @@
|
|
1 |
import gradio as gr
|
|
|
|
|
2 |
# Import statements that should only run once
|
3 |
if gr.NO_RELOAD:
|
4 |
import random
|
@@ -8,6 +10,7 @@ if gr.NO_RELOAD:
|
|
8 |
from typing import Optional
|
9 |
from PIL import Image # Needed for working with PIL images
|
10 |
import datasets
|
|
|
11 |
|
12 |
# The list of sentences from our previous conversation.
|
13 |
sentences = [
|
@@ -63,6 +66,17 @@ sentences = [
|
|
63 |
"This additional section outlines today's most influential datasets and benchmarks, highlighting how they continue to shape the development of handwriting OCR systems."
|
64 |
]
|
65 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
66 |
class OCRDataCollector:
|
67 |
def __init__(self):
|
68 |
self.collected_pairs = []
|
@@ -98,17 +112,30 @@ class OCRDataCollector:
|
|
98 |
|
99 |
def strip_metadata(image: Image.Image) -> Image.Image:
|
100 |
"""
|
101 |
-
Helper function to strip all metadata from the provided
|
102 |
-
This creates a new image with the same pixel data but no additional info.
|
103 |
"""
|
|
|
|
|
|
|
|
|
104 |
data = list(image.getdata())
|
105 |
stripped_image = Image.new(image.mode, image.size)
|
106 |
stripped_image.putdata(data)
|
107 |
return stripped_image
|
108 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
109 |
|
110 |
def create_gradio_interface():
|
111 |
collector = OCRDataCollector()
|
|
|
112 |
|
113 |
with gr.Blocks() as demo:
|
114 |
gr.Markdown("# Handwriting OCR Dataset Creator")
|
@@ -120,15 +147,31 @@ def create_gradio_interface():
|
|
120 |
pass
|
121 |
with gr.Column(scale=2, min_width=200):
|
122 |
login_btn = gr.LoginButton(elem_id="login_btn")
|
|
|
|
|
123 |
user_info = gr.Markdown(
|
124 |
value="<center>Please log in with your Hugging Face account to contribute to the dataset.</center>",
|
125 |
elem_id="user_info"
|
126 |
)
|
127 |
-
|
|
|
128 |
with gr.Column(scale=1):
|
129 |
pass
|
130 |
|
131 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
132 |
gr.Markdown(
|
133 |
"### Step 2: Read the text. "
|
134 |
"You will be shown between 1 and 5 consecutive sentences. Please handwrite them on paper and upload an image of your handwriting. "
|
@@ -136,7 +179,6 @@ def create_gradio_interface():
|
|
136 |
"If you wish to skip the current text, click 'Skip'."
|
137 |
)
|
138 |
|
139 |
-
# Main interface elements (initially visible)
|
140 |
text_box = gr.Textbox(
|
141 |
value=collector.current_text_block,
|
142 |
label="Text to Handwrite",
|
@@ -161,16 +203,13 @@ def create_gradio_interface():
|
|
161 |
elem_id="regenerate_btn"
|
162 |
)
|
163 |
|
164 |
-
# Step 3 section
|
165 |
gr.Markdown("### Step 3: Upload an image of your handwritten version of the text")
|
166 |
|
167 |
-
# Message that changes based on login state
|
168 |
upload_info = gr.Markdown(
|
169 |
value="You must be logged in to do this, to help us prevent spam submissions",
|
170 |
elem_id="upload_info"
|
171 |
)
|
172 |
|
173 |
-
# Image upload and related components
|
174 |
image_input = gr.Image(
|
175 |
type="pil",
|
176 |
label="Upload Handwritten Image",
|
@@ -205,9 +244,10 @@ def create_gradio_interface():
|
|
205 |
with gr.Row(visible=False) as button_row:
|
206 |
submit_btn = gr.Button("Submit", elem_id="submit_btn")
|
207 |
|
208 |
-
|
209 |
-
|
210 |
-
|
|
|
211 |
message = "Please upload your handwritten image of the text below." if is_logged_in else "You must be logged in to do this, to help us prevent spam submissions"
|
212 |
|
213 |
return {
|
@@ -217,101 +257,109 @@ def create_gradio_interface():
|
|
217 |
button_row: gr.update(visible=is_logged_in)
|
218 |
}
|
219 |
|
220 |
-
|
221 |
-
|
222 |
-
return "<center>Please log in with your Hugging Face account to contribute to the dataset.</center>", {}
|
223 |
-
return f"<center>Logged in as: {profile['username']}</center>", {"username": profile["username"]}
|
224 |
|
225 |
-
def handle_submit(
|
226 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
227 |
raise gr.Error("Please log in to use this application")
|
228 |
-
|
|
|
|
|
229 |
|
230 |
-
#
|
231 |
stripped_image = strip_metadata(image)
|
232 |
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
233 |
-
|
234 |
-
|
235 |
-
'image': datasets.Image(),
|
236 |
-
'timestamp': datasets.Value('string')
|
237 |
-
})
|
238 |
-
temp_dir = "temp"
|
239 |
-
os.makedirs(temp_dir, exist_ok=True)
|
240 |
-
|
241 |
-
# Define targets based on checkboxes: each entry is (dataset_type, repo_id, suffix, privacy_flag)
|
242 |
targets = []
|
243 |
if public_checkbox:
|
244 |
targets.append(("public", "rawwerks/handwriting-ocr-all", "_public", False))
|
245 |
-
|
246 |
-
targets.append(("private", f"{username}/handwriting-ocr-private", "_private", True))
|
247 |
|
248 |
-
|
|
|
|
|
249 |
for ds_type, repo_id, suffix, is_private in targets:
|
250 |
try:
|
251 |
collector.hf_api.dataset_info(repo_id)
|
252 |
-
except Exception
|
253 |
collector.hf_api.create_repo(repo_id, repo_type="dataset", private=is_private)
|
254 |
|
255 |
filename = f"{timestamp}{suffix}.png"
|
256 |
temp_path = os.path.join(temp_dir, filename)
|
257 |
stripped_image.save(temp_path)
|
258 |
|
259 |
-
|
260 |
-
|
261 |
-
'
|
262 |
-
'
|
263 |
-
|
264 |
-
|
265 |
-
|
266 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
267 |
|
268 |
-
|
269 |
-
|
270 |
-
|
271 |
-
|
272 |
-
|
273 |
-
"dataset": ds_type
|
274 |
})
|
275 |
|
|
|
|
|
|
|
|
|
276 |
new_text = collector.get_random_text_block(max_words)
|
277 |
return None, new_text
|
278 |
|
279 |
-
|
280 |
-
# Remove the login check - allow anyone to regenerate text
|
281 |
-
return collector.get_random_text_block(max_words)
|
282 |
-
|
283 |
-
# On load, update both the display message and the hidden profile state.
|
284 |
-
demo.load(
|
285 |
-
fn=update_user_info,
|
286 |
-
inputs=None,
|
287 |
-
outputs=[user_info, profile_state]
|
288 |
-
)
|
289 |
-
|
290 |
-
# Update UI when login state changes
|
291 |
-
demo.load(
|
292 |
-
fn=update_ui_visibility,
|
293 |
-
inputs=None,
|
294 |
-
outputs=[
|
295 |
-
upload_info,
|
296 |
-
image_input,
|
297 |
-
dataset_options,
|
298 |
-
button_row
|
299 |
-
]
|
300 |
-
)
|
301 |
-
|
302 |
-
# Bind the submit and skip actions
|
303 |
submit_btn.click(
|
304 |
fn=handle_submit,
|
305 |
inputs=[
|
306 |
-
|
307 |
-
image_input,
|
|
|
|
|
|
|
308 |
],
|
309 |
outputs=[image_input, text_box]
|
310 |
)
|
311 |
-
|
|
|
|
|
|
|
|
|
312 |
regenerate_btn.click(
|
313 |
fn=handle_regenerate,
|
314 |
-
inputs=[
|
315 |
outputs=text_box
|
316 |
)
|
317 |
|
|
|
1 |
import gradio as gr
|
2 |
+
from pydantic import BaseModel, Field
|
3 |
+
from typing import Optional, Any
|
4 |
# Import statements that should only run once
|
5 |
if gr.NO_RELOAD:
|
6 |
import random
|
|
|
10 |
from typing import Optional
|
11 |
from PIL import Image # Needed for working with PIL images
|
12 |
import datasets
|
13 |
+
import numpy as np # Added to help handle numpy array images
|
14 |
|
15 |
# The list of sentences from our previous conversation.
|
16 |
sentences = [
|
|
|
66 |
"This additional section outlines today's most influential datasets and benchmarks, highlighting how they continue to shape the development of handwriting OCR systems."
|
67 |
]
|
68 |
|
69 |
+
class SubmissionData(BaseModel):
|
70 |
+
text: str = Field(..., description="Text to be handwritten")
|
71 |
+
profile: Any = Field(..., description="Gradio OAuth profile")
|
72 |
+
image: Optional[Image.Image] = Field(None, description="Uploaded handwritten image")
|
73 |
+
max_words: int = Field(..., ge=1, le=201, description="Maximum number of words")
|
74 |
+
public_checkbox: bool = Field(..., description="Submit to public dataset")
|
75 |
+
|
76 |
+
model_config = {
|
77 |
+
"arbitrary_types_allowed": True # Allow PIL.Image.Image type
|
78 |
+
}
|
79 |
+
|
80 |
class OCRDataCollector:
|
81 |
def __init__(self):
|
82 |
self.collected_pairs = []
|
|
|
112 |
|
113 |
def strip_metadata(image: Image.Image) -> Image.Image:
|
114 |
"""
|
115 |
+
Helper function to strip all metadata from the provided image data.
|
|
|
116 |
"""
|
117 |
+
if image is None:
|
118 |
+
raise gr.Error("No valid image provided")
|
119 |
+
|
120 |
+
# Create a new image with the same pixel data but no metadata
|
121 |
data = list(image.getdata())
|
122 |
stripped_image = Image.new(image.mode, image.size)
|
123 |
stripped_image.putdata(data)
|
124 |
return stripped_image
|
125 |
|
126 |
+
class UserState:
|
127 |
+
def __init__(self):
|
128 |
+
self.username = None
|
129 |
+
self.is_logged_in = False
|
130 |
+
|
131 |
+
def update_from_profile(self, profile: gr.OAuthProfile | None) -> None:
|
132 |
+
"""Update user state from Gradio OAuth profile"""
|
133 |
+
self.is_logged_in = profile is not None and getattr(profile, "username", None) is not None
|
134 |
+
self.username = profile.username if self.is_logged_in else None
|
135 |
|
136 |
def create_gradio_interface():
|
137 |
collector = OCRDataCollector()
|
138 |
+
user_state = UserState()
|
139 |
|
140 |
with gr.Blocks() as demo:
|
141 |
gr.Markdown("# Handwriting OCR Dataset Creator")
|
|
|
147 |
pass
|
148 |
with gr.Column(scale=2, min_width=200):
|
149 |
login_btn = gr.LoginButton(elem_id="login_btn")
|
150 |
+
# Activate the login button so OAuth is correctly initialized.
|
151 |
+
login_btn.activate()
|
152 |
user_info = gr.Markdown(
|
153 |
value="<center>Please log in with your Hugging Face account to contribute to the dataset.</center>",
|
154 |
elem_id="user_info"
|
155 |
)
|
156 |
+
# Create a hidden state component to store the OAuth profile.
|
157 |
+
profile_state = gr.State()
|
158 |
with gr.Column(scale=1):
|
159 |
pass
|
160 |
|
161 |
+
# Update user info based on the OAuth profile.
|
162 |
+
def update_user_info(profile: gr.OAuthProfile | None) -> str:
|
163 |
+
if profile and getattr(profile, "username", None):
|
164 |
+
return f"<center>Logged in as: {profile.username}</center>"
|
165 |
+
else:
|
166 |
+
return "<center>Please log in with your Hugging Face account to contribute to the dataset.</center>"
|
167 |
+
|
168 |
+
demo.load(update_user_info, inputs=None, outputs=user_info)
|
169 |
+
|
170 |
+
# Store the OAuth profile in the hidden state.
|
171 |
+
def store_profile(profile: gr.OAuthProfile | None) -> gr.OAuthProfile | None:
|
172 |
+
return profile
|
173 |
+
demo.load(store_profile, inputs=None, outputs=profile_state)
|
174 |
+
|
175 |
gr.Markdown(
|
176 |
"### Step 2: Read the text. "
|
177 |
"You will be shown between 1 and 5 consecutive sentences. Please handwrite them on paper and upload an image of your handwriting. "
|
|
|
179 |
"If you wish to skip the current text, click 'Skip'."
|
180 |
)
|
181 |
|
|
|
182 |
text_box = gr.Textbox(
|
183 |
value=collector.current_text_block,
|
184 |
label="Text to Handwrite",
|
|
|
203 |
elem_id="regenerate_btn"
|
204 |
)
|
205 |
|
|
|
206 |
gr.Markdown("### Step 3: Upload an image of your handwritten version of the text")
|
207 |
|
|
|
208 |
upload_info = gr.Markdown(
|
209 |
value="You must be logged in to do this, to help us prevent spam submissions",
|
210 |
elem_id="upload_info"
|
211 |
)
|
212 |
|
|
|
213 |
image_input = gr.Image(
|
214 |
type="pil",
|
215 |
label="Upload Handwritten Image",
|
|
|
244 |
with gr.Row(visible=False) as button_row:
|
245 |
submit_btn = gr.Button("Submit", elem_id="submit_btn")
|
246 |
|
247 |
+
# Update user state when profile changes
|
248 |
+
def update_user_state(profile: gr.OAuthProfile | None, oauth_token: gr.OAuthToken | None = None, *args):
|
249 |
+
user_state.update_from_profile(profile)
|
250 |
+
is_logged_in = user_state.is_logged_in
|
251 |
message = "Please upload your handwritten image of the text below." if is_logged_in else "You must be logged in to do this, to help us prevent spam submissions"
|
252 |
|
253 |
return {
|
|
|
257 |
button_row: gr.update(visible=is_logged_in)
|
258 |
}
|
259 |
|
260 |
+
# Load initial state and update UI visibility
|
261 |
+
demo.load(update_user_state, inputs=profile_state, outputs=[upload_info, image_input, dataset_options, button_row])
|
|
|
|
|
262 |
|
263 |
+
def handle_submit(
|
264 |
+
text: str,
|
265 |
+
image: Image.Image,
|
266 |
+
max_words: int,
|
267 |
+
public_checkbox: bool,
|
268 |
+
collector: OCRDataCollector | None = None,
|
269 |
+
*args
|
270 |
+
):
|
271 |
+
"""Handle submission with clean parameter order"""
|
272 |
+
print(f"Debug - Initial params:")
|
273 |
+
print(f"Text: {text[:50]}")
|
274 |
+
print(f"Image type: {type(image)}")
|
275 |
+
print(f"Max words: {max_words}")
|
276 |
+
print(f"Public checkbox: {public_checkbox}")
|
277 |
+
print(f"Collector type: {type(collector)}")
|
278 |
+
|
279 |
+
if collector is None:
|
280 |
+
raise gr.Error("Internal error: OCR collector not initialized")
|
281 |
+
|
282 |
+
if not user_state.is_logged_in:
|
283 |
raise gr.Error("Please log in to use this application")
|
284 |
+
|
285 |
+
if not isinstance(image, Image.Image):
|
286 |
+
raise gr.Error("Please upload a valid image before submitting")
|
287 |
|
288 |
+
# Strip metadata from validated image
|
289 |
stripped_image = strip_metadata(image)
|
290 |
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
291 |
+
|
292 |
+
# Define targets based on checkboxes
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
293 |
targets = []
|
294 |
if public_checkbox:
|
295 |
targets.append(("public", "rawwerks/handwriting-ocr-all", "_public", False))
|
296 |
+
targets.append(("private", f"{user_state.username}/handwriting-ocr-private", "_private", True))
|
|
|
297 |
|
298 |
+
temp_dir = "temp"
|
299 |
+
os.makedirs(temp_dir, exist_ok=True)
|
300 |
+
|
301 |
for ds_type, repo_id, suffix, is_private in targets:
|
302 |
try:
|
303 |
collector.hf_api.dataset_info(repo_id)
|
304 |
+
except Exception:
|
305 |
collector.hf_api.create_repo(repo_id, repo_type="dataset", private=is_private)
|
306 |
|
307 |
filename = f"{timestamp}{suffix}.png"
|
308 |
temp_path = os.path.join(temp_dir, filename)
|
309 |
stripped_image.save(temp_path)
|
310 |
|
311 |
+
# Define features to properly handle image files
|
312 |
+
features = datasets.Features({
|
313 |
+
'text': datasets.Value('string'),
|
314 |
+
'image': datasets.Image(),
|
315 |
+
'timestamp': datasets.Value('string')
|
316 |
+
})
|
317 |
+
|
318 |
+
try:
|
319 |
+
# Try to load existing dataset
|
320 |
+
dataset = datasets.load_dataset(repo_id, split="train")
|
321 |
+
except Exception:
|
322 |
+
# If no existing dataset, create a new empty one
|
323 |
+
dataset = datasets.Dataset.from_dict({
|
324 |
+
'text': [],
|
325 |
+
'image': [],
|
326 |
+
'timestamp': []
|
327 |
+
}, features=features)
|
328 |
|
329 |
+
# Add the new item
|
330 |
+
dataset = dataset.add_item({
|
331 |
+
'text': text,
|
332 |
+
'image': temp_path,
|
333 |
+
'timestamp': timestamp
|
|
|
334 |
})
|
335 |
|
336 |
+
# Push updates to hub
|
337 |
+
dataset.push_to_hub(repo_id, split="train")
|
338 |
+
os.remove(temp_path)
|
339 |
+
|
340 |
new_text = collector.get_random_text_block(max_words)
|
341 |
return None, new_text
|
342 |
|
343 |
+
# Submit button click handler with simplified inputs
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
344 |
submit_btn.click(
|
345 |
fn=handle_submit,
|
346 |
inputs=[
|
347 |
+
text_box, # Text to handwrite
|
348 |
+
image_input, # Uploaded image
|
349 |
+
max_words_slider, # Max words
|
350 |
+
public_checkbox, # Public dataset option
|
351 |
+
gr.State(collector) # Pass the collector instance
|
352 |
],
|
353 |
outputs=[image_input, text_box]
|
354 |
)
|
355 |
+
|
356 |
+
def handle_regenerate(text, max_words):
|
357 |
+
# Allow anyone to regenerate text regardless of login status.
|
358 |
+
return collector.get_random_text_block(max_words)
|
359 |
+
|
360 |
regenerate_btn.click(
|
361 |
fn=handle_regenerate,
|
362 |
+
inputs=[text_box, max_words_slider],
|
363 |
outputs=text_box
|
364 |
)
|
365 |
|
requirements.txt
CHANGED
@@ -5,4 +5,5 @@ pytest>=7.0.0
|
|
5 |
pytest-playwright>=0.4.0
|
6 |
pytest-asyncio>=0.23.0
|
7 |
playwright>=1.40.0
|
8 |
-
datasets>=2.16.0
|
|
|
|
5 |
pytest-playwright>=0.4.0
|
6 |
pytest-asyncio>=0.23.0
|
7 |
playwright>=1.40.0
|
8 |
+
datasets>=2.16.0
|
9 |
+
pydantic>=2.6.1
|
test_app.py
CHANGED
@@ -10,7 +10,7 @@ def collector():
|
|
10 |
|
11 |
def test_get_random_text_block(collector):
|
12 |
# Test that we get a non-empty string
|
13 |
-
text_block = collector.get_random_text_block()
|
14 |
assert isinstance(text_block, str)
|
15 |
assert len(text_block) > 0
|
16 |
|
@@ -18,13 +18,13 @@ def test_get_random_text_block(collector):
|
|
18 |
assert any(sentence in text_block for sentence in sentences)
|
19 |
|
20 |
# Test that we get different blocks (probabilistic, but very likely)
|
21 |
-
blocks = [collector.get_random_text_block() for _ in range(5)]
|
22 |
assert len(set(blocks)) > 1, "Random blocks should be different"
|
23 |
|
24 |
def test_skip_text(collector):
|
25 |
# Test that we get a different text block when skipping
|
26 |
-
current_text = collector.get_random_text_block()
|
27 |
-
new_text = collector.get_random_text_block()
|
28 |
|
29 |
assert isinstance(new_text, str)
|
30 |
assert len(new_text) > 0
|
@@ -39,7 +39,7 @@ def test_submit_image(collector):
|
|
39 |
test_image = Image.fromarray(img_array)
|
40 |
|
41 |
# Test the current text block
|
42 |
-
current_text = collector.get_random_text_block()
|
43 |
|
44 |
# Test submission with valid image
|
45 |
new_text = collector.submit_image(test_image, current_text)
|
|
|
10 |
|
11 |
def test_get_random_text_block(collector):
|
12 |
# Test that we get a non-empty string
|
13 |
+
text_block = collector.get_random_text_block(max_words=50)
|
14 |
assert isinstance(text_block, str)
|
15 |
assert len(text_block) > 0
|
16 |
|
|
|
18 |
assert any(sentence in text_block for sentence in sentences)
|
19 |
|
20 |
# Test that we get different blocks (probabilistic, but very likely)
|
21 |
+
blocks = [collector.get_random_text_block(max_words=50) for _ in range(5)]
|
22 |
assert len(set(blocks)) > 1, "Random blocks should be different"
|
23 |
|
24 |
def test_skip_text(collector):
|
25 |
# Test that we get a different text block when skipping
|
26 |
+
current_text = collector.get_random_text_block(max_words=50)
|
27 |
+
new_text = collector.get_random_text_block(max_words=50)
|
28 |
|
29 |
assert isinstance(new_text, str)
|
30 |
assert len(new_text) > 0
|
|
|
39 |
test_image = Image.fromarray(img_array)
|
40 |
|
41 |
# Test the current text block
|
42 |
+
current_text = collector.get_random_text_block(max_words=50)
|
43 |
|
44 |
# Test submission with valid image
|
45 |
new_text = collector.submit_image(test_image, current_text)
|