taellinglin commited on
Commit
635d31b
·
verified ·
1 Parent(s): f9b53d5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +64 -24
app.py CHANGED
@@ -41,28 +41,42 @@ decoder = BeamSearchDecoderCTC(alphabet)
41
 
42
  # --------- Dataset --------- #
43
  class OCRDataset(Dataset):
44
- def __init__(self, font_path, size=1000):
45
  self.font = ImageFont.truetype(font_path, 32)
46
- self.samples = ["".join(np.random.choice(list(CHARS), np.random.randint(4, 7)))
47
- for _ in range(size)]
 
 
 
48
 
49
  self.transform = transforms.Compose([
50
- transforms.Grayscale(),
 
51
  transforms.Resize((IMAGE_HEIGHT, IMAGE_WIDTH)),
52
- transforms.ToTensor(),
53
- transforms.Normalize((0.5,), (0.5,))
54
  ])
55
-
56
  def __len__(self):
57
  return len(self.samples)
58
 
59
  def __getitem__(self, idx):
60
- text = self.samples[idx]
61
- img = self.render_text(text)
62
- img = self.transform(img) # convert PIL to tensor with normalization
 
 
 
 
 
 
 
 
 
 
 
 
 
63
 
64
- label = torch.tensor([CHAR2IDX[c] for c in text], dtype=torch.long)
65
- return img, label
66
 
67
 
68
  def render_text(self, text):
@@ -125,7 +139,7 @@ def greedy_decode(log_probs):
125
 
126
  # --------- Custom Collate --------- #
127
  def custom_collate_fn(batch):
128
- images, labels = zip(*batch)
129
  images = torch.stack(images, 0)
130
 
131
  flat_labels = []
@@ -163,34 +177,54 @@ def train_model(font_file, epochs=100, learning_rate=0.001):
163
  import time
164
  global font_path, ocr_model
165
 
166
- # Save the uploaded font file
167
  font_name = os.path.splitext(os.path.basename(font_file.name))[0]
168
  font_path = f"./{font_name}.ttf"
169
  with open(font_file.name, "rb") as uploaded:
170
  with open(font_path, "wb") as f:
171
  f.write(uploaded.read())
172
 
173
- # Load dataset
174
- dataset = OCRDataset(font_path)
175
- dataloader = DataLoader(dataset, batch_size=16, shuffle=True, collate_fn=custom_collate_fn)
 
 
 
 
 
 
 
 
 
 
176
 
177
- # Visualize one sample
178
- img, label = dataset[0]
179
  print("Label:", ''.join([IDX2CHAR[i.item()] for i in label]))
180
  plt.imshow(img.permute(1, 2, 0).squeeze(), cmap='gray')
181
  plt.show()
182
 
183
- # Initialize model, loss, optimizer, scheduler
184
  model = OCRModel(num_classes=len(CHAR2IDX)).to(device)
185
  criterion = nn.CTCLoss(blank=BLANK_IDX)
186
  optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
187
  scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=5)
188
 
189
- # Training loop
190
  for epoch in range(epochs):
 
 
 
 
191
  model.train()
192
  running_loss = 0.0
193
 
 
 
 
 
 
 
 
 
 
194
  for img, targets, target_lengths in dataloader:
195
  img = img.to(device)
196
  targets = targets.to(device)
@@ -214,7 +248,7 @@ def train_model(font_file, epochs=100, learning_rate=0.001):
214
  scheduler.step(avg_loss)
215
  print(f"[{epoch + 1}/{epochs}] Loss: {avg_loss:.4f}")
216
 
217
- # Save the trained model
218
  timestamp = time.strftime("%Y%m%d%H%M%S")
219
  model_name = f"{font_name}_{epochs}ep_lr{learning_rate:.0e}_{timestamp}.pth"
220
  save_model(model, model_name)
@@ -226,6 +260,7 @@ def train_model(font_file, epochs=100, learning_rate=0.001):
226
 
227
 
228
 
 
229
  def preprocess_image(image: Image.Image):
230
  img_cv = np.array(image.convert("L"))
231
 
@@ -300,7 +335,11 @@ def predict_text(image: Image.Image, ground_truth: str = None, debug: bool = Fal
300
  output = ocr_model(img_tensor) # (1, T, C)
301
  log_probs = output.log_softmax(2)[0] # (T, C)
302
 
303
- pred_text = decoder.decode(log_probs.cpu().numpy()) # Best beam path
 
 
 
 
304
 
305
  # Confidence: mean max prob per timestep
306
  probs = log_probs.exp()
@@ -322,7 +361,8 @@ def predict_text(image: Image.Image, ground_truth: str = None, debug: bool = Fal
322
  if ground_truth:
323
  print("Ground Truth:", ground_truth)
324
 
325
- return f"<strong>Prediction:</strong> {pretty_output}<br><strong>Confidence:</strong> {avg_conf:.2%}{sim_score}"
 
326
 
327
 
328
  # New helper function: generate label images grid
 
41
 
42
  # --------- Dataset --------- #
43
  class OCRDataset(Dataset):
44
+ def __init__(self, font_path, size=1000, label_length_range=(4, 7)):
45
  self.font = ImageFont.truetype(font_path, 32)
46
+ self.label_length_range = label_length_range
47
+ self.samples = [
48
+ "".join(np.random.choice(list(CHARS), np.random.randint(*self.label_length_range)))
49
+ for _ in range(size)
50
+ ]
51
 
52
  self.transform = transforms.Compose([
53
+ transforms.ToTensor(), # must be first
54
+ transforms.Normalize((0.5,), (0.5,)),
55
  transforms.Resize((IMAGE_HEIGHT, IMAGE_WIDTH)),
56
+ transforms.RandomApply([transforms.GaussianBlur(kernel_size=3)], p=0.3),
57
+ transforms.RandomApply([transforms.RandomAffine(degrees=10, translate=(0.1, 0.1))], p=0.3),
58
  ])
 
59
  def __len__(self):
60
  return len(self.samples)
61
 
62
  def __getitem__(self, idx):
63
+ label = self.samples[idx]
64
+
65
+ # Create an image with padding
66
+ pad = 8
67
+ w = self.font.getlength(label)
68
+ h = self.font.size
69
+ img_w, img_h = int(w + 2 * pad), int(h + 2 * pad)
70
+ img = Image.new("L", (img_w, img_h), 255)
71
+ draw = ImageDraw.Draw(img)
72
+ draw.text((pad, pad), label, font=self.font, fill=0)
73
+
74
+ img = self.transform(img)
75
+ label_encoded = torch.tensor([CHAR2IDX[c] for c in label], dtype=torch.long)
76
+ label_length = torch.tensor(len(label_encoded), dtype=torch.long)
77
+
78
+ return img, label_encoded, label_length
79
 
 
 
80
 
81
 
82
  def render_text(self, text):
 
139
 
140
  # --------- Custom Collate --------- #
141
  def custom_collate_fn(batch):
142
+ images, labels, _ = zip(*batch)
143
  images = torch.stack(images, 0)
144
 
145
  flat_labels = []
 
177
  import time
178
  global font_path, ocr_model
179
 
180
+ # Save uploaded font
181
  font_name = os.path.splitext(os.path.basename(font_file.name))[0]
182
  font_path = f"./{font_name}.ttf"
183
  with open(font_file.name, "rb") as uploaded:
184
  with open(font_path, "wb") as f:
185
  f.write(uploaded.read())
186
 
187
+ # Curriculum learning: Start with shorter labels, increase over time
188
+ def get_dataset_for_epoch(epoch):
189
+ if epoch < epochs // 3:
190
+ label_len = (3, 4)
191
+ elif epoch < 2 * epochs // 3:
192
+ label_len = (4, 6)
193
+ else:
194
+ label_len = (5, 7)
195
+ return OCRDataset(font_path, label_length_range=label_len)
196
+
197
+ # Visualize one sample from initial dataset
198
+ dataset = get_dataset_for_epoch(0)
199
+ img, label, _ = dataset[0] # Ignore the 3rd value (e.g., label length)
200
 
 
 
201
  print("Label:", ''.join([IDX2CHAR[i.item()] for i in label]))
202
  plt.imshow(img.permute(1, 2, 0).squeeze(), cmap='gray')
203
  plt.show()
204
 
205
+ # Init model (ensure BiLSTM)
206
  model = OCRModel(num_classes=len(CHAR2IDX)).to(device)
207
  criterion = nn.CTCLoss(blank=BLANK_IDX)
208
  optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
209
  scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=5)
210
 
 
211
  for epoch in range(epochs):
212
+ # Load new dataset for current curriculum stage
213
+ dataset = get_dataset_for_epoch(epoch)
214
+ dataloader = DataLoader(dataset, batch_size=16, shuffle=True, collate_fn=custom_collate_fn)
215
+
216
  model.train()
217
  running_loss = 0.0
218
 
219
+ # CTC warmup: reduced LR during initial epochs
220
+ if epoch < 5:
221
+ warmup_lr = learning_rate * 0.2
222
+ for param_group in optimizer.param_groups:
223
+ param_group['lr'] = warmup_lr
224
+ else:
225
+ for param_group in optimizer.param_groups:
226
+ param_group['lr'] = learning_rate
227
+
228
  for img, targets, target_lengths in dataloader:
229
  img = img.to(device)
230
  targets = targets.to(device)
 
248
  scheduler.step(avg_loss)
249
  print(f"[{epoch + 1}/{epochs}] Loss: {avg_loss:.4f}")
250
 
251
+ # Save the model
252
  timestamp = time.strftime("%Y%m%d%H%M%S")
253
  model_name = f"{font_name}_{epochs}ep_lr{learning_rate:.0e}_{timestamp}.pth"
254
  save_model(model, model_name)
 
260
 
261
 
262
 
263
+
264
  def preprocess_image(image: Image.Image):
265
  img_cv = np.array(image.convert("L"))
266
 
 
335
  output = ocr_model(img_tensor) # (1, T, C)
336
  log_probs = output.log_softmax(2)[0] # (T, C)
337
 
338
+ # Decode best beam path (string)
339
+ pred_text_raw = decoder.decode(log_probs.cpu().numpy())
340
+ pred_chars = pred_text_raw.replace("<BLANK>", "")
341
+ # Remove <BLANK> tokens if present (assuming <BLANK> is in vocab)
342
+ pred_text = ''.join([c for c in pred_chars if c != "<BLANK>"])
343
 
344
  # Confidence: mean max prob per timestep
345
  probs = log_probs.exp()
 
361
  if ground_truth:
362
  print("Ground Truth:", ground_truth)
363
 
364
+ return f"<strong>Prediction:</strong> <strong>{pretty_output}</strong><br><strong>Confidence:</strong> {avg_conf:.2%}{sim_score}"
365
+
366
 
367
 
368
  # New helper function: generate label images grid