taellinglin committed on
Commit
445a45a
·
verified ·
1 Parent(s): 69c9fd9

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +26 -12
app.py CHANGED
@@ -146,6 +146,7 @@ def load_model(path):
146
 
147
  # --------- Gradio Functions --------- #
148
  def train_model(font_file, epochs=100, learning_rate=0.001):
 
149
  global font_path, ocr_model
150
 
151
  # Save the uploaded font file
@@ -159,41 +160,54 @@ def train_model(font_file, epochs=100, learning_rate=0.001):
159
  dataset = OCRDataset(font_path)
160
  dataloader = DataLoader(dataset, batch_size=16, shuffle=True, collate_fn=custom_collate_fn)
161
 
162
- # Visualize one sample for sanity check
163
  img, label = dataset[0]
164
  print("Label:", ''.join([IDX2CHAR[i.item()] for i in label]))
165
  plt.imshow(img.permute(1, 2, 0).squeeze(), cmap='gray')
166
  plt.show()
167
 
168
- # Initialize model
169
  model = OCRModel(num_classes=len(CHAR2IDX)).to(device)
170
- criterion = nn.CTCLoss(blank=0)
171
- optimizer = optim.Adam(model.parameters(), lr=learning_rate)
 
172
 
173
  # Training loop
174
  for epoch in range(epochs):
 
 
 
175
  for img, targets, target_lengths in dataloader:
176
  img = img.to(device)
177
  targets = targets.to(device)
178
  target_lengths = target_lengths.to(device)
179
 
180
- output = model(img)
181
- batch_size = img.size(0)
182
  seq_len = output.size(1)
183
- input_lengths = torch.full(size=(batch_size,), fill_value=seq_len, dtype=torch.long).to(device)
 
 
 
 
184
 
185
- loss = criterion(output.log_softmax(2).transpose(0, 1), targets, input_lengths, target_lengths)
186
  optimizer.zero_grad()
187
  loss.backward()
188
  optimizer.step()
189
 
190
- print(f"Epoch {epoch + 1}, Loss: {loss.item():.4f}")
191
 
192
- # Save model with structured name
193
- model_name = f"{font_name}_{epochs}epochs_lr{learning_rate:.0e}.pth"
 
 
 
 
 
194
  save_model(model, model_name)
195
  ocr_model = model
196
- return f"Training complete! Model saved as '{model_name}'."
 
 
197
 
198
 
199
 
 
146
 
147
  # --------- Gradio Functions --------- #
148
  def train_model(font_file, epochs=100, learning_rate=0.001):
149
+ import time
150
  global font_path, ocr_model
151
 
152
  # Save the uploaded font file
 
160
  dataset = OCRDataset(font_path)
161
  dataloader = DataLoader(dataset, batch_size=16, shuffle=True, collate_fn=custom_collate_fn)
162
 
163
+ # Visualize one sample
164
  img, label = dataset[0]
165
  print("Label:", ''.join([IDX2CHAR[i.item()] for i in label]))
166
  plt.imshow(img.permute(1, 2, 0).squeeze(), cmap='gray')
167
  plt.show()
168
 
169
+ # Initialize model, loss, optimizer, scheduler
170
  model = OCRModel(num_classes=len(CHAR2IDX)).to(device)
171
+ criterion = nn.CTCLoss(blank=BLANK_IDX)
172
+ optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
173
+ scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=5)
174
 
175
  # Training loop
176
  for epoch in range(epochs):
177
+ model.train()
178
+ running_loss = 0.0
179
+
180
  for img, targets, target_lengths in dataloader:
181
  img = img.to(device)
182
  targets = targets.to(device)
183
  target_lengths = target_lengths.to(device)
184
 
185
+ output = model(img) # [B, T, C]
 
186
  seq_len = output.size(1)
187
+ batch_size = img.size(0)
188
+ input_lengths = torch.full((batch_size,), seq_len, dtype=torch.long).to(device)
189
+
190
+ log_probs = output.log_softmax(2).transpose(0, 1) # [T, B, C]
191
+ loss = criterion(log_probs, targets, input_lengths, target_lengths)
192
 
 
193
  optimizer.zero_grad()
194
  loss.backward()
195
  optimizer.step()
196
 
197
+ running_loss += loss.item()
198
 
199
+ avg_loss = running_loss / len(dataloader)
200
+ scheduler.step(avg_loss)
201
+ print(f"[{epoch + 1}/{epochs}] Loss: {avg_loss:.4f}")
202
+
203
+ # Save the trained model
204
+ timestamp = time.strftime("%Y%m%d%H%M%S")
205
+ model_name = f"{font_name}_{epochs}ep_lr{learning_rate:.0e}_{timestamp}.pth"
206
  save_model(model, model_name)
207
  ocr_model = model
208
+
209
+ return f"✅ Training complete! Model saved as '{model_name}'"
210
+
211
 
212
 
213