taellinglin commited on
Commit
69c9fd9
·
verified ·
1 Parent(s): d850478

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +27 -12
app.py CHANGED
@@ -13,6 +13,8 @@ from torchvision.transforms.functional import to_pil_image
13
  import matplotlib.pyplot as plt
14
  import math
15
  from datetime import datetime
 
 
16
  # --------- Globals --------- #
17
  CHARS = string.ascii_letters + string.digits + string.punctuation
18
  CHAR2IDX = {c: i + 1 for i, c in enumerate(CHARS)}
@@ -87,6 +89,9 @@ class OCRModel(nn.Module):
87
  return x
88
 
89
 
 
 
 
90
  def greedy_decode(log_probs):
91
  # log_probs shape: (T, B, C)
92
  # Usually, B=1 during inference
@@ -255,7 +260,6 @@ PADDING = 8
255
  LABEL_DIR = "./labels"
256
 
257
  def generate_labels(font_file=None, num_labels: int = 25):
258
- import time
259
  global font_path
260
 
261
  try:
@@ -264,7 +268,6 @@ def generate_labels(font_file=None, num_labels: int = 25):
264
  with open(font_file.name, "rb") as uploaded:
265
  with open(font_path, "wb") as f:
266
  f.write(uploaded.read())
267
-
268
  if font_path is None or not os.path.exists(font_path):
269
  font = ImageFont.load_default()
270
  else:
@@ -275,27 +278,26 @@ def generate_labels(font_file=None, num_labels: int = 25):
275
  images = []
276
 
277
  for label in labels:
278
- # Measure text size
279
  bbox = font.getbbox(label)
280
  text_w = bbox[2] - bbox[0]
281
  text_h = bbox[3] - bbox[1]
282
-
283
- # Add 8px padding
284
  pad = 8
285
  img_w = text_w + pad * 2
286
  img_h = text_h + pad * 2
287
 
 
288
  img = Image.new("L", (img_w, img_h), color=255)
289
  draw = ImageDraw.Draw(img)
290
  draw.text((pad, pad), label, font=font, fill=0)
291
 
292
- # Save image to label-specific folder
293
- safe_label = ''.join(c if c.isalnum() else '_' for c in label)
294
- label_dir = f"./labels/{safe_label}"
 
295
  os.makedirs(label_dir, exist_ok=True)
296
 
297
- timestamp = time.strftime("%Y%m%d%H%M%S%f")
298
- filepath = os.path.join(label_dir, f"_{timestamp}.png")
299
  img.save(filepath)
300
 
301
  images.append(img)
@@ -309,10 +311,23 @@ def generate_labels(font_file=None, num_labels: int = 25):
309
  draw.text((10, 50), f"Error: {str(e)}", fill=(255, 0, 0))
310
  return [error_img]
311
 
312
-
 
 
 
 
 
 
 
 
 
 
 
 
 
313
 
314
  # --------- Updated Gradio UI with new tab --------- #
315
- with gr.Blocks() as demo:
316
  with gr.Tab("1. Upload Font & Train"):
317
  font_file = gr.File(label="Upload .ttf or .otf font", file_types=[".ttf", ".otf"])
318
  epochs_input = gr.Slider(minimum=1, maximum=4096, value=256, step=1, label="Epochs")
 
13
  import matplotlib.pyplot as plt
14
  import math
15
  from datetime import datetime
16
+ import re
17
+
18
  # --------- Globals --------- #
19
  CHARS = string.ascii_letters + string.digits + string.punctuation
20
  CHAR2IDX = {c: i + 1 for i, c in enumerate(CHARS)}
 
89
  return x
90
 
91
 
92
+ def sanitize_filename(name):
93
+ return re.sub(r'[^a-zA-Z0-9_-]', '_', name)
94
+
95
  def greedy_decode(log_probs):
96
  # log_probs shape: (T, B, C)
97
  # Usually, B=1 during inference
 
260
  LABEL_DIR = "./labels"
261
 
262
  def generate_labels(font_file=None, num_labels: int = 25):
 
263
  global font_path
264
 
265
  try:
 
268
  with open(font_file.name, "rb") as uploaded:
269
  with open(font_path, "wb") as f:
270
  f.write(uploaded.read())
 
271
  if font_path is None or not os.path.exists(font_path):
272
  font = ImageFont.load_default()
273
  else:
 
278
  images = []
279
 
280
  for label in labels:
281
+ # Measure text size and calculate padded image dimensions
282
  bbox = font.getbbox(label)
283
  text_w = bbox[2] - bbox[0]
284
  text_h = bbox[3] - bbox[1]
 
 
285
  pad = 8
286
  img_w = text_w + pad * 2
287
  img_h = text_h + pad * 2
288
 
289
+ # Create image and draw text
290
  img = Image.new("L", (img_w, img_h), color=255)
291
  draw = ImageDraw.Draw(img)
292
  draw.text((pad, pad), label, font=font, fill=0)
293
 
294
+ # Save to ./labels/sanitized_label/timestamp.png
295
+ safe_label = sanitize_filename(label)
296
+ timestamp = datetime.now().strftime("%Y%m%d%H%M%S%f")
297
+ label_dir = os.path.join("./labels", safe_label)
298
  os.makedirs(label_dir, exist_ok=True)
299
 
300
+ filepath = os.path.join(label_dir, f"{timestamp}.png")
 
301
  img.save(filepath)
302
 
303
  images.append(img)
 
311
  draw.text((10, 50), f"Error: {str(e)}", fill=(255, 0, 0))
312
  return [error_img]
313
 
314
+ custom_css = """
315
+ #label-gallery .gallery-item img {
316
+ height: 43px; /* 32pt ≈ 43px */
317
+ width: auto;
318
+ object-fit: contain;
319
+ padding: 4px;
320
+ }
321
+
322
+ #label-gallery {
323
+ flex-grow: 1;
324
+ overflow-y: auto;
325
+ height: 100%;
326
+ }
327
+ """
328
 
329
  # --------- Updated Gradio UI with new tab --------- #
330
+ with gr.Blocks(css=custom_css) as demo:
331
  with gr.Tab("1. Upload Font & Train"):
332
  font_file = gr.File(label="Upload .ttf or .otf font", file_types=[".ttf", ".otf"])
333
  epochs_input = gr.Slider(minimum=1, maximum=4096, value=256, step=1, label="Epochs")