Raymond Weitekamp commited on
Commit
32a0510
·
1 Parent(s): b840d3e

progress - need to test live now

Browse files
Files changed (3) hide show
  1. app.py +122 -74
  2. requirements.txt +2 -1
  3. test_app.py +5 -5
app.py CHANGED
@@ -1,4 +1,6 @@
1
  import gradio as gr
 
 
2
  # Import statements that should only run once
3
  if gr.NO_RELOAD:
4
  import random
@@ -8,6 +10,7 @@ if gr.NO_RELOAD:
8
  from typing import Optional
9
  from PIL import Image # Needed for working with PIL images
10
  import datasets
 
11
 
12
  # The list of sentences from our previous conversation.
13
  sentences = [
@@ -63,6 +66,17 @@ sentences = [
63
  "This additional section outlines today's most influential datasets and benchmarks, highlighting how they continue to shape the development of handwriting OCR systems."
64
  ]
65
 
 
 
 
 
 
 
 
 
 
 
 
66
  class OCRDataCollector:
67
  def __init__(self):
68
  self.collected_pairs = []
@@ -98,17 +112,30 @@ class OCRDataCollector:
98
 
99
  def strip_metadata(image: Image.Image) -> Image.Image:
100
  """
101
- Helper function to strip all metadata from the provided PIL Image.
102
- This creates a new image with the same pixel data but no additional info.
103
  """
 
 
 
 
104
  data = list(image.getdata())
105
  stripped_image = Image.new(image.mode, image.size)
106
  stripped_image.putdata(data)
107
  return stripped_image
108
 
 
 
 
 
 
 
 
 
 
109
 
110
  def create_gradio_interface():
111
  collector = OCRDataCollector()
 
112
 
113
  with gr.Blocks() as demo:
114
  gr.Markdown("# Handwriting OCR Dataset Creator")
@@ -120,15 +147,31 @@ def create_gradio_interface():
120
  pass
121
  with gr.Column(scale=2, min_width=200):
122
  login_btn = gr.LoginButton(elem_id="login_btn")
 
 
123
  user_info = gr.Markdown(
124
  value="<center>Please log in with your Hugging Face account to contribute to the dataset.</center>",
125
  elem_id="user_info"
126
  )
127
- profile_state = gr.JSON(visible=False, elem_id="profile_state")
 
128
  with gr.Column(scale=1):
129
  pass
130
 
131
- # Instructions (always visible)
 
 
 
 
 
 
 
 
 
 
 
 
 
132
  gr.Markdown(
133
  "### Step 2: Read the text. "
134
  "You will be shown between 1 and 5 consecutive sentences. Please handwrite them on paper and upload an image of your handwriting. "
@@ -136,7 +179,6 @@ def create_gradio_interface():
136
  "If you wish to skip the current text, click 'Skip'."
137
  )
138
 
139
- # Main interface elements (initially visible)
140
  text_box = gr.Textbox(
141
  value=collector.current_text_block,
142
  label="Text to Handwrite",
@@ -161,16 +203,13 @@ def create_gradio_interface():
161
  elem_id="regenerate_btn"
162
  )
163
 
164
- # Step 3 section
165
  gr.Markdown("### Step 3: Upload an image of your handwritten version of the text")
166
 
167
- # Message that changes based on login state
168
  upload_info = gr.Markdown(
169
  value="You must be logged in to do this, to help us prevent spam submissions",
170
  elem_id="upload_info"
171
  )
172
 
173
- # Image upload and related components
174
  image_input = gr.Image(
175
  type="pil",
176
  label="Upload Handwritten Image",
@@ -205,9 +244,10 @@ def create_gradio_interface():
205
  with gr.Row(visible=False) as button_row:
206
  submit_btn = gr.Button("Submit", elem_id="submit_btn")
207
 
208
- def update_ui_visibility(profile: gr.OAuthProfile | None) -> dict:
209
- """Update visibility of UI elements based on login state"""
210
- is_logged_in = profile is not None
 
211
  message = "Please upload your handwritten image of the text below." if is_logged_in else "You must be logged in to do this, to help us prevent spam submissions"
212
 
213
  return {
@@ -217,101 +257,109 @@ def create_gradio_interface():
217
  button_row: gr.update(visible=is_logged_in)
218
  }
219
 
220
- def update_user_info(profile: Optional[dict]) -> tuple[str, dict]:
221
- if profile is None:
222
- return "<center>Please log in with your Hugging Face account to contribute to the dataset.</center>", {}
223
- return f"<center>Logged in as: {profile['username']}</center>", {"username": profile["username"]}
224
 
225
- def handle_submit(profile, private_checkbox, public_checkbox, image, text, max_words):
226
- if not profile or "username" not in profile:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
227
  raise gr.Error("Please log in to use this application")
228
- username = profile["username"]
 
 
229
 
230
- # Common processing: strip metadata, get timestamp, create features, and setup temp directory.
231
  stripped_image = strip_metadata(image)
232
  timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
233
- features = datasets.Features({
234
- 'text': datasets.Value('string'),
235
- 'image': datasets.Image(),
236
- 'timestamp': datasets.Value('string')
237
- })
238
- temp_dir = "temp"
239
- os.makedirs(temp_dir, exist_ok=True)
240
-
241
- # Define targets based on checkboxes: each entry is (dataset_type, repo_id, suffix, privacy_flag)
242
  targets = []
243
  if public_checkbox:
244
  targets.append(("public", "rawwerks/handwriting-ocr-all", "_public", False))
245
- if private_checkbox:
246
- targets.append(("private", f"{username}/handwriting-ocr-private", "_private", True))
247
 
248
- # Loop over each target, pushing the dataset with shared logic.
 
 
249
  for ds_type, repo_id, suffix, is_private in targets:
250
  try:
251
  collector.hf_api.dataset_info(repo_id)
252
- except Exception as e:
253
  collector.hf_api.create_repo(repo_id, repo_type="dataset", private=is_private)
254
 
255
  filename = f"{timestamp}{suffix}.png"
256
  temp_path = os.path.join(temp_dir, filename)
257
  stripped_image.save(temp_path)
258
 
259
- dataset_dict = {
260
- 'text': [text],
261
- 'image': [temp_path],
262
- 'timestamp': [timestamp]
263
- }
264
- dataset = datasets.Dataset.from_dict(dataset_dict, features=features)
265
- dataset.push_to_hub(repo_id)
266
- os.remove(temp_path)
 
 
 
 
 
 
 
 
 
267
 
268
- collector.collected_pairs.append({
269
- "text": text,
270
- "image": image,
271
- "timestamp": timestamp,
272
- "username": username,
273
- "dataset": ds_type
274
  })
275
 
 
 
 
 
276
  new_text = collector.get_random_text_block(max_words)
277
  return None, new_text
278
 
279
- def handle_regenerate(profile, text, max_words):
280
- # Remove the login check - allow anyone to regenerate text
281
- return collector.get_random_text_block(max_words)
282
-
283
- # On load, update both the display message and the hidden profile state.
284
- demo.load(
285
- fn=update_user_info,
286
- inputs=None,
287
- outputs=[user_info, profile_state]
288
- )
289
-
290
- # Update UI when login state changes
291
- demo.load(
292
- fn=update_ui_visibility,
293
- inputs=None,
294
- outputs=[
295
- upload_info,
296
- image_input,
297
- dataset_options,
298
- button_row
299
- ]
300
- )
301
-
302
- # Bind the submit and skip actions
303
  submit_btn.click(
304
  fn=handle_submit,
305
  inputs=[
306
- profile_state, private_checkbox, public_checkbox,
307
- image_input, text_box, max_words_slider
 
 
 
308
  ],
309
  outputs=[image_input, text_box]
310
  )
311
-
 
 
 
 
312
  regenerate_btn.click(
313
  fn=handle_regenerate,
314
- inputs=[profile_state, text_box, max_words_slider],
315
  outputs=text_box
316
  )
317
 
 
1
  import gradio as gr
2
+ from pydantic import BaseModel, Field
3
+ from typing import Optional, Any
4
  # Import statements that should only run once
5
  if gr.NO_RELOAD:
6
  import random
 
10
  from typing import Optional
11
  from PIL import Image # Needed for working with PIL images
12
  import datasets
13
+ import numpy as np # Added to help handle numpy array images
14
 
15
  # The list of sentences from our previous conversation.
16
  sentences = [
 
66
  "This additional section outlines today's most influential datasets and benchmarks, highlighting how they continue to shape the development of handwriting OCR systems."
67
  ]
68
 
69
+ class SubmissionData(BaseModel):
70
+ text: str = Field(..., description="Text to be handwritten")
71
+ profile: Any = Field(..., description="Gradio OAuth profile")
72
+ image: Optional[Image.Image] = Field(None, description="Uploaded handwritten image")
73
+ max_words: int = Field(..., ge=1, le=201, description="Maximum number of words")
74
+ public_checkbox: bool = Field(..., description="Submit to public dataset")
75
+
76
+ model_config = {
77
+ "arbitrary_types_allowed": True # Allow PIL.Image.Image type
78
+ }
79
+
80
  class OCRDataCollector:
81
  def __init__(self):
82
  self.collected_pairs = []
 
112
 
113
  def strip_metadata(image: Image.Image) -> Image.Image:
114
  """
115
+ Helper function to strip all metadata from the provided image data.
 
116
  """
117
+ if image is None:
118
+ raise gr.Error("No valid image provided")
119
+
120
+ # Create a new image with the same pixel data but no metadata
121
  data = list(image.getdata())
122
  stripped_image = Image.new(image.mode, image.size)
123
  stripped_image.putdata(data)
124
  return stripped_image
125
 
126
+ class UserState:
127
+ def __init__(self):
128
+ self.username = None
129
+ self.is_logged_in = False
130
+
131
+ def update_from_profile(self, profile: gr.OAuthProfile | None) -> None:
132
+ """Update user state from Gradio OAuth profile"""
133
+ self.is_logged_in = profile is not None and getattr(profile, "username", None) is not None
134
+ self.username = profile.username if self.is_logged_in else None
135
 
136
  def create_gradio_interface():
137
  collector = OCRDataCollector()
138
+ user_state = UserState()
139
 
140
  with gr.Blocks() as demo:
141
  gr.Markdown("# Handwriting OCR Dataset Creator")
 
147
  pass
148
  with gr.Column(scale=2, min_width=200):
149
  login_btn = gr.LoginButton(elem_id="login_btn")
150
+ # Activate the login button so OAuth is correctly initialized.
151
+ login_btn.activate()
152
  user_info = gr.Markdown(
153
  value="<center>Please log in with your Hugging Face account to contribute to the dataset.</center>",
154
  elem_id="user_info"
155
  )
156
+ # Create a hidden state component to store the OAuth profile.
157
+ profile_state = gr.State()
158
  with gr.Column(scale=1):
159
  pass
160
 
161
+ # Update user info based on the OAuth profile.
162
+ def update_user_info(profile: gr.OAuthProfile | None) -> str:
163
+ if profile and getattr(profile, "username", None):
164
+ return f"<center>Logged in as: {profile.username}</center>"
165
+ else:
166
+ return "<center>Please log in with your Hugging Face account to contribute to the dataset.</center>"
167
+
168
+ demo.load(update_user_info, inputs=None, outputs=user_info)
169
+
170
+ # Store the OAuth profile in the hidden state.
171
+ def store_profile(profile: gr.OAuthProfile | None) -> gr.OAuthProfile | None:
172
+ return profile
173
+ demo.load(store_profile, inputs=None, outputs=profile_state)
174
+
175
  gr.Markdown(
176
  "### Step 2: Read the text. "
177
  "You will be shown between 1 and 5 consecutive sentences. Please handwrite them on paper and upload an image of your handwriting. "
 
179
  "If you wish to skip the current text, click 'Skip'."
180
  )
181
 
 
182
  text_box = gr.Textbox(
183
  value=collector.current_text_block,
184
  label="Text to Handwrite",
 
203
  elem_id="regenerate_btn"
204
  )
205
 
 
206
  gr.Markdown("### Step 3: Upload an image of your handwritten version of the text")
207
 
 
208
  upload_info = gr.Markdown(
209
  value="You must be logged in to do this, to help us prevent spam submissions",
210
  elem_id="upload_info"
211
  )
212
 
 
213
  image_input = gr.Image(
214
  type="pil",
215
  label="Upload Handwritten Image",
 
244
  with gr.Row(visible=False) as button_row:
245
  submit_btn = gr.Button("Submit", elem_id="submit_btn")
246
 
247
+ # Update user state when profile changes
248
+ def update_user_state(profile: gr.OAuthProfile | None, oauth_token: gr.OAuthToken | None = None, *args):
249
+ user_state.update_from_profile(profile)
250
+ is_logged_in = user_state.is_logged_in
251
  message = "Please upload your handwritten image of the text below." if is_logged_in else "You must be logged in to do this, to help us prevent spam submissions"
252
 
253
  return {
 
257
  button_row: gr.update(visible=is_logged_in)
258
  }
259
 
260
+ # Load initial state and update UI visibility
261
+ demo.load(update_user_state, inputs=profile_state, outputs=[upload_info, image_input, dataset_options, button_row])
 
 
262
 
263
+ def handle_submit(
264
+ text: str,
265
+ image: Image.Image,
266
+ max_words: int,
267
+ public_checkbox: bool,
268
+ collector: OCRDataCollector | None = None,
269
+ *args
270
+ ):
271
+ """Handle submission with clean parameter order"""
272
+ print(f"Debug - Initial params:")
273
+ print(f"Text: {text[:50]}")
274
+ print(f"Image type: {type(image)}")
275
+ print(f"Max words: {max_words}")
276
+ print(f"Public checkbox: {public_checkbox}")
277
+ print(f"Collector type: {type(collector)}")
278
+
279
+ if collector is None:
280
+ raise gr.Error("Internal error: OCR collector not initialized")
281
+
282
+ if not user_state.is_logged_in:
283
  raise gr.Error("Please log in to use this application")
284
+
285
+ if not isinstance(image, Image.Image):
286
+ raise gr.Error("Please upload a valid image before submitting")
287
 
288
+ # Strip metadata from validated image
289
  stripped_image = strip_metadata(image)
290
  timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
291
+
292
+ # Define targets based on checkboxes
 
 
 
 
 
 
 
293
  targets = []
294
  if public_checkbox:
295
  targets.append(("public", "rawwerks/handwriting-ocr-all", "_public", False))
296
+ targets.append(("private", f"{user_state.username}/handwriting-ocr-private", "_private", True))
 
297
 
298
+ temp_dir = "temp"
299
+ os.makedirs(temp_dir, exist_ok=True)
300
+
301
  for ds_type, repo_id, suffix, is_private in targets:
302
  try:
303
  collector.hf_api.dataset_info(repo_id)
304
+ except Exception:
305
  collector.hf_api.create_repo(repo_id, repo_type="dataset", private=is_private)
306
 
307
  filename = f"{timestamp}{suffix}.png"
308
  temp_path = os.path.join(temp_dir, filename)
309
  stripped_image.save(temp_path)
310
 
311
+ # Define features to properly handle image files
312
+ features = datasets.Features({
313
+ 'text': datasets.Value('string'),
314
+ 'image': datasets.Image(),
315
+ 'timestamp': datasets.Value('string')
316
+ })
317
+
318
+ try:
319
+ # Try to load existing dataset
320
+ dataset = datasets.load_dataset(repo_id, split="train")
321
+ except Exception:
322
+ # If no existing dataset, create a new empty one
323
+ dataset = datasets.Dataset.from_dict({
324
+ 'text': [],
325
+ 'image': [],
326
+ 'timestamp': []
327
+ }, features=features)
328
 
329
+ # Add the new item
330
+ dataset = dataset.add_item({
331
+ 'text': text,
332
+ 'image': temp_path,
333
+ 'timestamp': timestamp
 
334
  })
335
 
336
+ # Push updates to hub
337
+ dataset.push_to_hub(repo_id, split="train")
338
+ os.remove(temp_path)
339
+
340
  new_text = collector.get_random_text_block(max_words)
341
  return None, new_text
342
 
343
+ # Submit button click handler with simplified inputs
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
344
  submit_btn.click(
345
  fn=handle_submit,
346
  inputs=[
347
+ text_box, # Text to handwrite
348
+ image_input, # Uploaded image
349
+ max_words_slider, # Max words
350
+ public_checkbox, # Public dataset option
351
+ gr.State(collector) # Pass the collector instance
352
  ],
353
  outputs=[image_input, text_box]
354
  )
355
+
356
+ def handle_regenerate(text, max_words):
357
+ # Allow anyone to regenerate text regardless of login status.
358
+ return collector.get_random_text_block(max_words)
359
+
360
  regenerate_btn.click(
361
  fn=handle_regenerate,
362
+ inputs=[text_box, max_words_slider],
363
  outputs=text_box
364
  )
365
 
requirements.txt CHANGED
@@ -5,4 +5,5 @@ pytest>=7.0.0
5
  pytest-playwright>=0.4.0
6
  pytest-asyncio>=0.23.0
7
  playwright>=1.40.0
8
- datasets>=2.16.0
 
 
5
  pytest-playwright>=0.4.0
6
  pytest-asyncio>=0.23.0
7
  playwright>=1.40.0
8
+ datasets>=2.16.0
9
+ pydantic>=2.6.1
test_app.py CHANGED
@@ -10,7 +10,7 @@ def collector():
10
 
11
  def test_get_random_text_block(collector):
12
  # Test that we get a non-empty string
13
- text_block = collector.get_random_text_block()
14
  assert isinstance(text_block, str)
15
  assert len(text_block) > 0
16
 
@@ -18,13 +18,13 @@ def test_get_random_text_block(collector):
18
  assert any(sentence in text_block for sentence in sentences)
19
 
20
  # Test that we get different blocks (probabilistic, but very likely)
21
- blocks = [collector.get_random_text_block() for _ in range(5)]
22
  assert len(set(blocks)) > 1, "Random blocks should be different"
23
 
24
  def test_skip_text(collector):
25
  # Test that we get a different text block when skipping
26
- current_text = collector.get_random_text_block()
27
- new_text = collector.get_random_text_block()
28
 
29
  assert isinstance(new_text, str)
30
  assert len(new_text) > 0
@@ -39,7 +39,7 @@ def test_submit_image(collector):
39
  test_image = Image.fromarray(img_array)
40
 
41
  # Test the current text block
42
- current_text = collector.get_random_text_block()
43
 
44
  # Test submission with valid image
45
  new_text = collector.submit_image(test_image, current_text)
 
10
 
11
  def test_get_random_text_block(collector):
12
  # Test that we get a non-empty string
13
+ text_block = collector.get_random_text_block(max_words=50)
14
  assert isinstance(text_block, str)
15
  assert len(text_block) > 0
16
 
 
18
  assert any(sentence in text_block for sentence in sentences)
19
 
20
  # Test that we get different blocks (probabilistic, but very likely)
21
+ blocks = [collector.get_random_text_block(max_words=50) for _ in range(5)]
22
  assert len(set(blocks)) > 1, "Random blocks should be different"
23
 
24
  def test_skip_text(collector):
25
  # Test that we get a different text block when skipping
26
+ current_text = collector.get_random_text_block(max_words=50)
27
+ new_text = collector.get_random_text_block(max_words=50)
28
 
29
  assert isinstance(new_text, str)
30
  assert len(new_text) > 0
 
39
  test_image = Image.fromarray(img_array)
40
 
41
  # Test the current text block
42
+ current_text = collector.get_random_text_block(max_words=50)
43
 
44
  # Test submission with valid image
45
  new_text = collector.submit_image(test_image, current_text)