Update app.py
Browse files
app.py
CHANGED
@@ -146,6 +146,7 @@ def load_model(path):
|
|
146 |
|
147 |
# --------- Gradio Functions --------- #
|
148 |
def train_model(font_file, epochs=100, learning_rate=0.001):
|
|
|
149 |
global font_path, ocr_model
|
150 |
|
151 |
# Save the uploaded font file
|
@@ -159,41 +160,54 @@ def train_model(font_file, epochs=100, learning_rate=0.001):
|
|
159 |
dataset = OCRDataset(font_path)
|
160 |
dataloader = DataLoader(dataset, batch_size=16, shuffle=True, collate_fn=custom_collate_fn)
|
161 |
|
162 |
-
# Visualize one sample
|
163 |
img, label = dataset[0]
|
164 |
print("Label:", ''.join([IDX2CHAR[i.item()] for i in label]))
|
165 |
plt.imshow(img.permute(1, 2, 0).squeeze(), cmap='gray')
|
166 |
plt.show()
|
167 |
|
168 |
-
# Initialize model
|
169 |
model = OCRModel(num_classes=len(CHAR2IDX)).to(device)
|
170 |
-
criterion = nn.CTCLoss(blank=
|
171 |
-
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
|
|
|
172 |
|
173 |
# Training loop
|
174 |
for epoch in range(epochs):
|
|
|
|
|
|
|
175 |
for img, targets, target_lengths in dataloader:
|
176 |
img = img.to(device)
|
177 |
targets = targets.to(device)
|
178 |
target_lengths = target_lengths.to(device)
|
179 |
|
180 |
-
output = model(img)
|
181 |
-
batch_size = img.size(0)
|
182 |
seq_len = output.size(1)
|
183 |
-
|
|
|
|
|
|
|
|
|
184 |
|
185 |
-
loss = criterion(output.log_softmax(2).transpose(0, 1), targets, input_lengths, target_lengths)
|
186 |
optimizer.zero_grad()
|
187 |
loss.backward()
|
188 |
optimizer.step()
|
189 |
|
190 |
-
|
191 |
|
192 |
-
|
193 |
-
|
|
|
|
|
|
|
|
|
|
|
194 |
save_model(model, model_name)
|
195 |
ocr_model = model
|
196 |
-
|
|
|
|
|
197 |
|
198 |
|
199 |
|
|
|
146 |
|
147 |
# --------- Gradio Functions --------- #
def train_model(font_file, epochs=100, learning_rate=0.001):
    """Train an OCR model on glyphs rendered from an uploaded font.

    Args:
        font_file: uploaded font file object (Gradio upload widget).
        epochs: number of passes over the dataset (default 100).
        learning_rate: initial Adam learning rate (default 1e-3).

    Returns:
        A human-readable status string naming the saved model file.

    Side effects:
        Updates the module-level ``font_path`` and ``ocr_model`` globals,
        displays one sample image via matplotlib, and writes the trained
        weights to disk via ``save_model``.
    """
    import time  # local import, used only for the save-file timestamp

    global font_path, ocr_model

    # Save the uploaded font file.
    # NOTE(review): the diff hunk elides original lines 152-159 here; they
    # persist `font_file` to `font_path` and derive `font_name` (referenced
    # below when naming the saved model). Restore them from the full file
    # before using this reconstruction.

    dataset = OCRDataset(font_path)
    dataloader = DataLoader(dataset, batch_size=16, shuffle=True,
                            collate_fn=custom_collate_fn)

    # Visualize one sample so the user can sanity-check the glyph rendering.
    img, label = dataset[0]
    print("Label:", ''.join([IDX2CHAR[i.item()] for i in label]))
    plt.imshow(img.permute(1, 2, 0).squeeze(), cmap='gray')
    plt.show()

    # Initialize model, CTC loss, optimizer, and LR scheduler.
    model = OCRModel(num_classes=len(CHAR2IDX)).to(device)
    criterion = nn.CTCLoss(blank=BLANK_IDX)
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='min', factor=0.5, patience=5)

    # Training loop
    for epoch in range(epochs):
        model.train()
        running_loss = 0.0

        for img, targets, target_lengths in dataloader:
            img = img.to(device)
            targets = targets.to(device)
            target_lengths = target_lengths.to(device)

            output = model(img)  # [B, T, C]
            seq_len = output.size(1)
            batch_size = img.size(0)
            # Every sample in the batch is assumed to emit the full
            # model output length.
            input_lengths = torch.full((batch_size,), seq_len,
                                       dtype=torch.long).to(device)

            # nn.CTCLoss expects log-probabilities shaped [T, B, C].
            log_probs = output.log_softmax(2).transpose(0, 1)
            loss = criterion(log_probs, targets, input_lengths, target_lengths)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            running_loss += loss.item()

        avg_loss = running_loss / len(dataloader)
        scheduler.step(avg_loss)  # ReduceLROnPlateau steps on the metric value
        print(f"[{epoch + 1}/{epochs}] Loss: {avg_loss:.4f}")

    # Save the trained model under a unique, reproducible name.
    timestamp = time.strftime("%Y%m%d%H%M%S")
    model_name = f"{font_name}_{epochs}ep_lr{learning_rate:.0e}_{timestamp}.pth"
    save_model(model, model_name)
    ocr_model = model

    return f"✅ Training complete! Model saved as '{model_name}'"