taellinglin commited on
Commit
3ed4400
·
verified ·
1 Parent(s): d096f8e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +58 -41
app.py CHANGED
@@ -38,6 +38,10 @@ alphabet = Alphabet.build_alphabet(labels)
38
 
39
  # Now initialize decoder correctly
40
  decoder = BeamSearchDecoderCTC(alphabet)
 
 
 
 
41
 
42
  # --------- Dataset --------- #
43
  class OCRDataset(Dataset):
@@ -155,7 +159,11 @@ def custom_collate_fn(batch):
155
 
156
  # --------- Model Save/Load --------- #
157
  def list_saved_models():
158
- return [f for f in os.listdir() if f.endswith(".pth")]
 
 
 
 
159
 
160
 
161
  def save_model(model, path):
@@ -177,14 +185,18 @@ def train_model(font_file, epochs=100, learning_rate=0.001):
177
  import time
178
  global font_path, ocr_model
179
 
180
- # Save uploaded font
 
 
 
 
181
  font_name = os.path.splitext(os.path.basename(font_file.name))[0]
182
- font_path = f"./{font_name}.ttf"
183
  with open(font_file.name, "rb") as uploaded:
184
  with open(font_path, "wb") as f:
185
  f.write(uploaded.read())
186
 
187
- # Curriculum learning: Start with shorter labels, increase over time
188
  def get_dataset_for_epoch(epoch):
189
  if epoch < epochs // 3:
190
  label_len = (3, 4)
@@ -194,29 +206,27 @@ def train_model(font_file, epochs=100, learning_rate=0.001):
194
  label_len = (5, 7)
195
  return OCRDataset(font_path, label_length_range=label_len)
196
 
197
- # Visualize one sample from initial dataset
198
  dataset = get_dataset_for_epoch(0)
199
- img, label, _ = dataset[0] # Ignore the 3rd value (e.g., label length)
200
-
201
  print("Label:", ''.join([IDX2CHAR[i.item()] for i in label]))
202
  plt.imshow(img.permute(1, 2, 0).squeeze(), cmap='gray')
203
  plt.show()
204
 
205
- # Init model (ensure BiLSTM)
206
  model = OCRModel(num_classes=len(CHAR2IDX)).to(device)
207
  criterion = nn.CTCLoss(blank=BLANK_IDX)
208
  optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
209
  scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=5)
210
 
211
  for epoch in range(epochs):
212
- # Load new dataset for current curriculum stage
213
  dataset = get_dataset_for_epoch(epoch)
214
  dataloader = DataLoader(dataset, batch_size=16, shuffle=True, collate_fn=custom_collate_fn)
215
 
216
  model.train()
217
  running_loss = 0.0
218
 
219
- # CTC warmup: reduced LR during initial epochs
220
  if epoch < 5:
221
  warmup_lr = learning_rate * 0.2
222
  for param_group in optimizer.param_groups:
@@ -230,12 +240,12 @@ def train_model(font_file, epochs=100, learning_rate=0.001):
230
  targets = targets.to(device)
231
  target_lengths = target_lengths.to(device)
232
 
233
- output = model(img) # [B, T, C]
234
  seq_len = output.size(1)
235
  batch_size = img.size(0)
236
  input_lengths = torch.full((batch_size,), seq_len, dtype=torch.long).to(device)
237
 
238
- log_probs = output.log_softmax(2).transpose(0, 1) # [T, B, C]
239
  loss = criterion(log_probs, targets, input_lengths, target_lengths)
240
 
241
  optimizer.zero_grad()
@@ -248,13 +258,15 @@ def train_model(font_file, epochs=100, learning_rate=0.001):
248
  scheduler.step(avg_loss)
249
  print(f"[{epoch + 1}/{epochs}] Loss: {avg_loss:.4f}")
250
 
251
- # Save the model
252
  timestamp = time.strftime("%Y%m%d%H%M%S")
253
  model_name = f"{font_name}_{epochs}ep_lr{learning_rate:.0e}_{timestamp}.pth"
254
- save_model(model, model_name)
 
 
255
  ocr_model = model
 
256
 
257
- return f"✅ Training complete! Model saved as '{model_name}'"
258
 
259
 
260
 
@@ -376,11 +388,11 @@ def generate_labels(font_file=None, num_labels: int = 25):
376
  global font_path
377
 
378
  try:
379
- if font_file:
380
- font_path = "./temp_font_labels.ttf"
381
- with open(font_file.name, "rb") as uploaded:
382
- with open(font_path, "wb") as f:
383
- f.write(uploaded.read())
384
  if font_path is None or not os.path.exists(font_path):
385
  font = ImageFont.load_default()
386
  else:
@@ -391,7 +403,6 @@ def generate_labels(font_file=None, num_labels: int = 25):
391
  images = []
392
 
393
  for label in labels:
394
- # Measure text size and calculate padded image dimensions
395
  bbox = font.getbbox(label)
396
  text_w = bbox[2] - bbox[0]
397
  text_h = bbox[3] - bbox[1]
@@ -399,12 +410,10 @@ def generate_labels(font_file=None, num_labels: int = 25):
399
  img_w = text_w + pad * 2
400
  img_h = text_h + pad * 2
401
 
402
- # Create image and draw text
403
  img = Image.new("L", (img_w, img_h), color=255)
404
  draw = ImageDraw.Draw(img)
405
  draw.text((pad, pad), label, font=font, fill=0)
406
 
407
- # Save to ./labels/sanitized_label/timestamp.png
408
  safe_label = sanitize_filename(label)
409
  timestamp = datetime.now().strftime("%Y%m%d%H%M%S%f")
410
  label_dir = os.path.join("./labels", safe_label)
@@ -424,6 +433,10 @@ def generate_labels(font_file=None, num_labels: int = 25):
424
  draw.text((10, 50), f"Error: {str(e)}", fill=(255, 0, 0))
425
  return [error_img]
426
 
 
 
 
 
427
  custom_css = """
428
  #label-gallery .gallery-item img {
429
  height: 43px; /* 32pt ≈ 43px */
@@ -444,7 +457,7 @@ custom_css = """
444
 
445
  # --------- Updated Gradio UI with new tab --------- #
446
  with gr.Blocks(css=custom_css) as demo:
447
- with gr.Tab("1. Upload Font & Train"):
448
  font_file = gr.File(label="Upload .ttf or .otf font", file_types=[".ttf", ".otf"])
449
  epochs_input = gr.Slider(minimum=1, maximum=4096, value=256, step=1, label="Epochs")
450
  lr_input = gr.Slider(minimum=0.001, maximum=0.1, value=0.05, step=0.001, label="Learning Rate")
@@ -453,8 +466,28 @@ with gr.Blocks(css=custom_css) as demo:
453
 
454
  train_button.click(fn=train_model, inputs=[font_file, epochs_input, lr_input], outputs=train_status)
455
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
456
 
457
- with gr.Tab("2. Use Trained Model"):
 
458
  model_list = gr.Dropdown(choices=list_saved_models(), label="Select OCR Model")
459
  refresh_btn = gr.Button("🔄 Refresh Models")
460
  load_model_btn = gr.Button("Load Model") # <-- new button
@@ -472,23 +505,7 @@ with gr.Blocks(css=custom_css) as demo:
472
 
473
  predict_btn.click(fn=predict_text, inputs=image_input, outputs=output_text)
474
 
475
- with gr.Tab("3. Generate Labels"):
476
- font_file_labels = gr.File(label="Optional font for label image", file_types=[".ttf", ".otf"])
477
- num_labels = gr.Number(value=20, label="Number of labels to generate", precision=0, interactive=True)
478
- gen_button = gr.Button("Generate Label Grid")
479
-
480
- gen_button.click(
481
- fn=generate_labels,
482
- inputs=[font_file_labels, num_labels],
483
- outputs=gr.Gallery(
484
- label="Generated Labels",
485
- columns=16, # 16 tiles per row
486
- object_fit="contain", # Maintain aspect ratio
487
- height="100%", # Allow full app height
488
- elem_id="label-gallery" # For CSS targeting
489
- )
490
 
491
- )
492
 
493
 
494
 
 
38
 
39
  # Now initialize decoder correctly
40
  decoder = BeamSearchDecoderCTC(alphabet)
41
+ # Ensure required directories exist at startup
42
+ os.makedirs("./fonts", exist_ok=True)
43
+ os.makedirs("./models", exist_ok=True)
44
+ os.makedirs("./labels", exist_ok=True)
45
 
46
  # --------- Dataset --------- #
47
  class OCRDataset(Dataset):
 
159
 
160
  # --------- Model Save/Load --------- #
161
  def list_saved_models():
162
+ model_dir = "./models"
163
+ if not os.path.exists(model_dir):
164
+ return []
165
+ return [f for f in os.listdir(model_dir) if f.endswith(".pth")]
166
+
167
 
168
 
169
  def save_model(model, path):
 
185
  import time
186
  global font_path, ocr_model
187
 
188
+ # Ensure directories exist
189
+ os.makedirs("./fonts", exist_ok=True)
190
+ os.makedirs("./models", exist_ok=True)
191
+
192
+ # Save uploaded font to ./fonts
193
  font_name = os.path.splitext(os.path.basename(font_file.name))[0]
194
+ font_path = f"./fonts/{font_name}.ttf"
195
  with open(font_file.name, "rb") as uploaded:
196
  with open(font_path, "wb") as f:
197
  f.write(uploaded.read())
198
 
199
+ # Curriculum learning: label length grows over time
200
  def get_dataset_for_epoch(epoch):
201
  if epoch < epochs // 3:
202
  label_len = (3, 4)
 
206
  label_len = (5, 7)
207
  return OCRDataset(font_path, label_length_range=label_len)
208
 
209
+ # Visualize one sample
210
  dataset = get_dataset_for_epoch(0)
211
+ img, label, _ = dataset[0]
 
212
  print("Label:", ''.join([IDX2CHAR[i.item()] for i in label]))
213
  plt.imshow(img.permute(1, 2, 0).squeeze(), cmap='gray')
214
  plt.show()
215
 
216
+ # Model setup
217
  model = OCRModel(num_classes=len(CHAR2IDX)).to(device)
218
  criterion = nn.CTCLoss(blank=BLANK_IDX)
219
  optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
220
  scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=5)
221
 
222
  for epoch in range(epochs):
 
223
  dataset = get_dataset_for_epoch(epoch)
224
  dataloader = DataLoader(dataset, batch_size=16, shuffle=True, collate_fn=custom_collate_fn)
225
 
226
  model.train()
227
  running_loss = 0.0
228
 
229
+ # Warmup learning rate
230
  if epoch < 5:
231
  warmup_lr = learning_rate * 0.2
232
  for param_group in optimizer.param_groups:
 
240
  targets = targets.to(device)
241
  target_lengths = target_lengths.to(device)
242
 
243
+ output = model(img)
244
  seq_len = output.size(1)
245
  batch_size = img.size(0)
246
  input_lengths = torch.full((batch_size,), seq_len, dtype=torch.long).to(device)
247
 
248
+ log_probs = output.log_softmax(2).transpose(0, 1)
249
  loss = criterion(log_probs, targets, input_lengths, target_lengths)
250
 
251
  optimizer.zero_grad()
 
258
  scheduler.step(avg_loss)
259
  print(f"[{epoch + 1}/{epochs}] Loss: {avg_loss:.4f}")
260
 
261
+ # Save the model to ./models
262
  timestamp = time.strftime("%Y%m%d%H%M%S")
263
  model_name = f"{font_name}_{epochs}ep_lr{learning_rate:.0e}_{timestamp}.pth"
264
+ model_path = os.path.join("./models", model_name)
265
+ save_model(model, model_path)
266
+
267
  ocr_model = model
268
+ return f"✅ Training complete! Model saved as '{model_path}'"
269
 
 
270
 
271
 
272
 
 
388
  global font_path
389
 
390
  try:
391
+ if font_file and font_file != "None":
392
+ font_path = os.path.abspath(font_file)
393
+ else:
394
+ font_path = None
395
+
396
  if font_path is None or not os.path.exists(font_path):
397
  font = ImageFont.load_default()
398
  else:
 
403
  images = []
404
 
405
  for label in labels:
 
406
  bbox = font.getbbox(label)
407
  text_w = bbox[2] - bbox[0]
408
  text_h = bbox[3] - bbox[1]
 
410
  img_w = text_w + pad * 2
411
  img_h = text_h + pad * 2
412
 
 
413
  img = Image.new("L", (img_w, img_h), color=255)
414
  draw = ImageDraw.Draw(img)
415
  draw.text((pad, pad), label, font=font, fill=0)
416
 
 
417
  safe_label = sanitize_filename(label)
418
  timestamp = datetime.now().strftime("%Y%m%d%H%M%S%f")
419
  label_dir = os.path.join("./labels", safe_label)
 
433
  draw.text((10, 50), f"Error: {str(e)}", fill=(255, 0, 0))
434
  return [error_img]
435
 
436
+ def list_fonts():
437
+ fonts = [f for f in os.listdir() if f.lower().endswith((".ttf", ".otf"))]
438
+ return ["None"] + fonts if fonts else ["None"]
439
+
440
  custom_css = """
441
  #label-gallery .gallery-item img {
442
  height: 43px; /* 32pt ≈ 43px */
 
457
 
458
  # --------- Updated Gradio UI with new tab --------- #
459
  with gr.Blocks(css=custom_css) as demo:
460
+ with gr.Tab("【Train OCR Model】"):
461
  font_file = gr.File(label="Upload .ttf or .otf font", file_types=[".ttf", ".otf"])
462
  epochs_input = gr.Slider(minimum=1, maximum=4096, value=256, step=1, label="Epochs")
463
  lr_input = gr.Slider(minimum=0.001, maximum=0.1, value=0.05, step=0.001, label="Learning Rate")
 
466
 
467
  train_button.click(fn=train_model, inputs=[font_file, epochs_input, lr_input], outputs=train_status)
468
 
469
+ with gr.Tab("【Generate Labels】"):
470
+ font_file_labels = gr.Dropdown(
471
+ choices=list_fonts(),
472
+ label="Optional font for label image",
473
+ interactive=True,
474
+ )
475
+ num_labels = gr.Number(value=20, label="Number of labels to generate", precision=0, interactive=True)
476
+ gen_button = gr.Button("Generate Label Grid")
477
+
478
+ gen_button.click(
479
+ fn=generate_labels,
480
+ inputs=[font_file_labels, num_labels],
481
+ outputs=gr.Gallery(
482
+ label="Generated Labels",
483
+ columns=16, # 16 tiles per row
484
+ object_fit="contain", # Maintain aspect ratio
485
+ height="100%", # Allow full app height
486
+ elem_id="label-gallery" # For CSS targeting
487
+ )
488
 
489
+ )
490
+ with gr.Tab("【Recognize Text】"):
491
  model_list = gr.Dropdown(choices=list_saved_models(), label="Select OCR Model")
492
  refresh_btn = gr.Button("🔄 Refresh Models")
493
  load_model_btn = gr.Button("Load Model") # <-- new button
 
505
 
506
  predict_btn.click(fn=predict_text, inputs=image_input, outputs=output_text)
507
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
508
 
 
509
 
510
 
511