Update app.py
Browse files
app.py
CHANGED
@@ -38,6 +38,10 @@ alphabet = Alphabet.build_alphabet(labels)
|
|
38 |
|
39 |
# Now initialize decoder correctly
|
40 |
decoder = BeamSearchDecoderCTC(alphabet)
|
|
|
|
|
|
|
|
|
41 |
|
42 |
# --------- Dataset --------- #
|
43 |
class OCRDataset(Dataset):
|
@@ -155,7 +159,11 @@ def custom_collate_fn(batch):
|
|
155 |
|
156 |
# --------- Model Save/Load --------- #
|
157 |
def list_saved_models():
|
158 |
-
|
|
|
|
|
|
|
|
|
159 |
|
160 |
|
161 |
def save_model(model, path):
|
@@ -177,14 +185,18 @@ def train_model(font_file, epochs=100, learning_rate=0.001):
|
|
177 |
import time
|
178 |
global font_path, ocr_model
|
179 |
|
180 |
-
#
|
|
|
|
|
|
|
|
|
181 |
font_name = os.path.splitext(os.path.basename(font_file.name))[0]
|
182 |
-
font_path = f"./{font_name}.ttf"
|
183 |
with open(font_file.name, "rb") as uploaded:
|
184 |
with open(font_path, "wb") as f:
|
185 |
f.write(uploaded.read())
|
186 |
|
187 |
-
# Curriculum learning:
|
188 |
def get_dataset_for_epoch(epoch):
|
189 |
if epoch < epochs // 3:
|
190 |
label_len = (3, 4)
|
@@ -194,29 +206,27 @@ def train_model(font_file, epochs=100, learning_rate=0.001):
|
|
194 |
label_len = (5, 7)
|
195 |
return OCRDataset(font_path, label_length_range=label_len)
|
196 |
|
197 |
-
# Visualize one sample
|
198 |
dataset = get_dataset_for_epoch(0)
|
199 |
-
img, label, _ = dataset[0]
|
200 |
-
|
201 |
print("Label:", ''.join([IDX2CHAR[i.item()] for i in label]))
|
202 |
plt.imshow(img.permute(1, 2, 0).squeeze(), cmap='gray')
|
203 |
plt.show()
|
204 |
|
205 |
-
#
|
206 |
model = OCRModel(num_classes=len(CHAR2IDX)).to(device)
|
207 |
criterion = nn.CTCLoss(blank=BLANK_IDX)
|
208 |
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
|
209 |
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=5)
|
210 |
|
211 |
for epoch in range(epochs):
|
212 |
-
# Load new dataset for current curriculum stage
|
213 |
dataset = get_dataset_for_epoch(epoch)
|
214 |
dataloader = DataLoader(dataset, batch_size=16, shuffle=True, collate_fn=custom_collate_fn)
|
215 |
|
216 |
model.train()
|
217 |
running_loss = 0.0
|
218 |
|
219 |
-
#
|
220 |
if epoch < 5:
|
221 |
warmup_lr = learning_rate * 0.2
|
222 |
for param_group in optimizer.param_groups:
|
@@ -230,12 +240,12 @@ def train_model(font_file, epochs=100, learning_rate=0.001):
|
|
230 |
targets = targets.to(device)
|
231 |
target_lengths = target_lengths.to(device)
|
232 |
|
233 |
-
output = model(img)
|
234 |
seq_len = output.size(1)
|
235 |
batch_size = img.size(0)
|
236 |
input_lengths = torch.full((batch_size,), seq_len, dtype=torch.long).to(device)
|
237 |
|
238 |
-
log_probs = output.log_softmax(2).transpose(0, 1)
|
239 |
loss = criterion(log_probs, targets, input_lengths, target_lengths)
|
240 |
|
241 |
optimizer.zero_grad()
|
@@ -248,13 +258,15 @@ def train_model(font_file, epochs=100, learning_rate=0.001):
|
|
248 |
scheduler.step(avg_loss)
|
249 |
print(f"[{epoch + 1}/{epochs}] Loss: {avg_loss:.4f}")
|
250 |
|
251 |
-
# Save the model
|
252 |
timestamp = time.strftime("%Y%m%d%H%M%S")
|
253 |
model_name = f"{font_name}_{epochs}ep_lr{learning_rate:.0e}_{timestamp}.pth"
|
254 |
-
|
|
|
|
|
255 |
ocr_model = model
|
|
|
256 |
|
257 |
-
return f"✅ Training complete! Model saved as '{model_name}'"
|
258 |
|
259 |
|
260 |
|
@@ -376,11 +388,11 @@ def generate_labels(font_file=None, num_labels: int = 25):
|
|
376 |
global font_path
|
377 |
|
378 |
try:
|
379 |
-
if font_file:
|
380 |
-
font_path =
|
381 |
-
|
382 |
-
|
383 |
-
|
384 |
if font_path is None or not os.path.exists(font_path):
|
385 |
font = ImageFont.load_default()
|
386 |
else:
|
@@ -391,7 +403,6 @@ def generate_labels(font_file=None, num_labels: int = 25):
|
|
391 |
images = []
|
392 |
|
393 |
for label in labels:
|
394 |
-
# Measure text size and calculate padded image dimensions
|
395 |
bbox = font.getbbox(label)
|
396 |
text_w = bbox[2] - bbox[0]
|
397 |
text_h = bbox[3] - bbox[1]
|
@@ -399,12 +410,10 @@ def generate_labels(font_file=None, num_labels: int = 25):
|
|
399 |
img_w = text_w + pad * 2
|
400 |
img_h = text_h + pad * 2
|
401 |
|
402 |
-
# Create image and draw text
|
403 |
img = Image.new("L", (img_w, img_h), color=255)
|
404 |
draw = ImageDraw.Draw(img)
|
405 |
draw.text((pad, pad), label, font=font, fill=0)
|
406 |
|
407 |
-
# Save to ./labels/sanitized_label/timestamp.png
|
408 |
safe_label = sanitize_filename(label)
|
409 |
timestamp = datetime.now().strftime("%Y%m%d%H%M%S%f")
|
410 |
label_dir = os.path.join("./labels", safe_label)
|
@@ -424,6 +433,10 @@ def generate_labels(font_file=None, num_labels: int = 25):
|
|
424 |
draw.text((10, 50), f"Error: {str(e)}", fill=(255, 0, 0))
|
425 |
return [error_img]
|
426 |
|
|
|
|
|
|
|
|
|
427 |
custom_css = """
|
428 |
#label-gallery .gallery-item img {
|
429 |
height: 43px; /* 32pt ≈ 43px */
|
@@ -444,7 +457,7 @@ custom_css = """
|
|
444 |
|
445 |
# --------- Updated Gradio UI with new tab --------- #
|
446 |
with gr.Blocks(css=custom_css) as demo:
|
447 |
-
with gr.Tab("
|
448 |
font_file = gr.File(label="Upload .ttf or .otf font", file_types=[".ttf", ".otf"])
|
449 |
epochs_input = gr.Slider(minimum=1, maximum=4096, value=256, step=1, label="Epochs")
|
450 |
lr_input = gr.Slider(minimum=0.001, maximum=0.1, value=0.05, step=0.001, label="Learning Rate")
|
@@ -453,8 +466,28 @@ with gr.Blocks(css=custom_css) as demo:
|
|
453 |
|
454 |
train_button.click(fn=train_model, inputs=[font_file, epochs_input, lr_input], outputs=train_status)
|
455 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
456 |
|
457 |
-
|
|
|
458 |
model_list = gr.Dropdown(choices=list_saved_models(), label="Select OCR Model")
|
459 |
refresh_btn = gr.Button("🔄 Refresh Models")
|
460 |
load_model_btn = gr.Button("Load Model") # <-- new button
|
@@ -472,23 +505,7 @@ with gr.Blocks(css=custom_css) as demo:
|
|
472 |
|
473 |
predict_btn.click(fn=predict_text, inputs=image_input, outputs=output_text)
|
474 |
|
475 |
-
with gr.Tab("3. Generate Labels"):
|
476 |
-
font_file_labels = gr.File(label="Optional font for label image", file_types=[".ttf", ".otf"])
|
477 |
-
num_labels = gr.Number(value=20, label="Number of labels to generate", precision=0, interactive=True)
|
478 |
-
gen_button = gr.Button("Generate Label Grid")
|
479 |
-
|
480 |
-
gen_button.click(
|
481 |
-
fn=generate_labels,
|
482 |
-
inputs=[font_file_labels, num_labels],
|
483 |
-
outputs=gr.Gallery(
|
484 |
-
label="Generated Labels",
|
485 |
-
columns=16, # 16 tiles per row
|
486 |
-
object_fit="contain", # Maintain aspect ratio
|
487 |
-
height="100%", # Allow full app height
|
488 |
-
elem_id="label-gallery" # For CSS targeting
|
489 |
-
)
|
490 |
|
491 |
-
)
|
492 |
|
493 |
|
494 |
|
|
|
38 |
|
39 |
# Now initialize decoder correctly
|
40 |
decoder = BeamSearchDecoderCTC(alphabet)
|
41 |
+
# Ensure required directories exist at startup
|
42 |
+
os.makedirs("./fonts", exist_ok=True)
|
43 |
+
os.makedirs("./models", exist_ok=True)
|
44 |
+
os.makedirs("./labels", exist_ok=True)
|
45 |
|
46 |
# --------- Dataset --------- #
|
47 |
class OCRDataset(Dataset):
|
|
|
159 |
|
160 |
# --------- Model Save/Load --------- #
|
161 |
def list_saved_models():
|
162 |
+
model_dir = "./models"
|
163 |
+
if not os.path.exists(model_dir):
|
164 |
+
return []
|
165 |
+
return [f for f in os.listdir(model_dir) if f.endswith(".pth")]
|
166 |
+
|
167 |
|
168 |
|
169 |
def save_model(model, path):
|
|
|
185 |
import time
|
186 |
global font_path, ocr_model
|
187 |
|
188 |
+
# Ensure directories exist
|
189 |
+
os.makedirs("./fonts", exist_ok=True)
|
190 |
+
os.makedirs("./models", exist_ok=True)
|
191 |
+
|
192 |
+
# Save uploaded font to ./fonts
|
193 |
font_name = os.path.splitext(os.path.basename(font_file.name))[0]
|
194 |
+
font_path = f"./fonts/{font_name}.ttf"
|
195 |
with open(font_file.name, "rb") as uploaded:
|
196 |
with open(font_path, "wb") as f:
|
197 |
f.write(uploaded.read())
|
198 |
|
199 |
+
# Curriculum learning: label length grows over time
|
200 |
def get_dataset_for_epoch(epoch):
|
201 |
if epoch < epochs // 3:
|
202 |
label_len = (3, 4)
|
|
|
206 |
label_len = (5, 7)
|
207 |
return OCRDataset(font_path, label_length_range=label_len)
|
208 |
|
209 |
+
# Visualize one sample
|
210 |
dataset = get_dataset_for_epoch(0)
|
211 |
+
img, label, _ = dataset[0]
|
|
|
212 |
print("Label:", ''.join([IDX2CHAR[i.item()] for i in label]))
|
213 |
plt.imshow(img.permute(1, 2, 0).squeeze(), cmap='gray')
|
214 |
plt.show()
|
215 |
|
216 |
+
# Model setup
|
217 |
model = OCRModel(num_classes=len(CHAR2IDX)).to(device)
|
218 |
criterion = nn.CTCLoss(blank=BLANK_IDX)
|
219 |
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
|
220 |
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=5)
|
221 |
|
222 |
for epoch in range(epochs):
|
|
|
223 |
dataset = get_dataset_for_epoch(epoch)
|
224 |
dataloader = DataLoader(dataset, batch_size=16, shuffle=True, collate_fn=custom_collate_fn)
|
225 |
|
226 |
model.train()
|
227 |
running_loss = 0.0
|
228 |
|
229 |
+
# Warmup learning rate
|
230 |
if epoch < 5:
|
231 |
warmup_lr = learning_rate * 0.2
|
232 |
for param_group in optimizer.param_groups:
|
|
|
240 |
targets = targets.to(device)
|
241 |
target_lengths = target_lengths.to(device)
|
242 |
|
243 |
+
output = model(img)
|
244 |
seq_len = output.size(1)
|
245 |
batch_size = img.size(0)
|
246 |
input_lengths = torch.full((batch_size,), seq_len, dtype=torch.long).to(device)
|
247 |
|
248 |
+
log_probs = output.log_softmax(2).transpose(0, 1)
|
249 |
loss = criterion(log_probs, targets, input_lengths, target_lengths)
|
250 |
|
251 |
optimizer.zero_grad()
|
|
|
258 |
scheduler.step(avg_loss)
|
259 |
print(f"[{epoch + 1}/{epochs}] Loss: {avg_loss:.4f}")
|
260 |
|
261 |
+
# Save the model to ./models
|
262 |
timestamp = time.strftime("%Y%m%d%H%M%S")
|
263 |
model_name = f"{font_name}_{epochs}ep_lr{learning_rate:.0e}_{timestamp}.pth"
|
264 |
+
model_path = os.path.join("./models", model_name)
|
265 |
+
save_model(model, model_path)
|
266 |
+
|
267 |
ocr_model = model
|
268 |
+
return f"✅ Training complete! Model saved as '{model_path}'"
|
269 |
|
|
|
270 |
|
271 |
|
272 |
|
|
|
388 |
global font_path
|
389 |
|
390 |
try:
|
391 |
+
if font_file and font_file != "None":
|
392 |
+
font_path = os.path.abspath(font_file)
|
393 |
+
else:
|
394 |
+
font_path = None
|
395 |
+
|
396 |
if font_path is None or not os.path.exists(font_path):
|
397 |
font = ImageFont.load_default()
|
398 |
else:
|
|
|
403 |
images = []
|
404 |
|
405 |
for label in labels:
|
|
|
406 |
bbox = font.getbbox(label)
|
407 |
text_w = bbox[2] - bbox[0]
|
408 |
text_h = bbox[3] - bbox[1]
|
|
|
410 |
img_w = text_w + pad * 2
|
411 |
img_h = text_h + pad * 2
|
412 |
|
|
|
413 |
img = Image.new("L", (img_w, img_h), color=255)
|
414 |
draw = ImageDraw.Draw(img)
|
415 |
draw.text((pad, pad), label, font=font, fill=0)
|
416 |
|
|
|
417 |
safe_label = sanitize_filename(label)
|
418 |
timestamp = datetime.now().strftime("%Y%m%d%H%M%S%f")
|
419 |
label_dir = os.path.join("./labels", safe_label)
|
|
|
433 |
draw.text((10, 50), f"Error: {str(e)}", fill=(255, 0, 0))
|
434 |
return [error_img]
|
435 |
|
436 |
+
def list_fonts():
|
437 |
+
fonts = [f for f in os.listdir() if f.lower().endswith((".ttf", ".otf"))]
|
438 |
+
return ["None"] + fonts if fonts else ["None"]
|
439 |
+
|
440 |
custom_css = """
|
441 |
#label-gallery .gallery-item img {
|
442 |
height: 43px; /* 32pt ≈ 43px */
|
|
|
457 |
|
458 |
# --------- Updated Gradio UI with new tab --------- #
|
459 |
with gr.Blocks(css=custom_css) as demo:
|
460 |
+
with gr.Tab("【Train OCR Model】"):
|
461 |
font_file = gr.File(label="Upload .ttf or .otf font", file_types=[".ttf", ".otf"])
|
462 |
epochs_input = gr.Slider(minimum=1, maximum=4096, value=256, step=1, label="Epochs")
|
463 |
lr_input = gr.Slider(minimum=0.001, maximum=0.1, value=0.05, step=0.001, label="Learning Rate")
|
|
|
466 |
|
467 |
train_button.click(fn=train_model, inputs=[font_file, epochs_input, lr_input], outputs=train_status)
|
468 |
|
469 |
+
with gr.Tab("【Generate Labels】"):
|
470 |
+
font_file_labels = gr.Dropdown(
|
471 |
+
choices=list_fonts(),
|
472 |
+
label="Optional font for label image",
|
473 |
+
interactive=True,
|
474 |
+
)
|
475 |
+
num_labels = gr.Number(value=20, label="Number of labels to generate", precision=0, interactive=True)
|
476 |
+
gen_button = gr.Button("Generate Label Grid")
|
477 |
+
|
478 |
+
gen_button.click(
|
479 |
+
fn=generate_labels,
|
480 |
+
inputs=[font_file_labels, num_labels],
|
481 |
+
outputs=gr.Gallery(
|
482 |
+
label="Generated Labels",
|
483 |
+
columns=16, # 16 tiles per row
|
484 |
+
object_fit="contain", # Maintain aspect ratio
|
485 |
+
height="100%", # Allow full app height
|
486 |
+
elem_id="label-gallery" # For CSS targeting
|
487 |
+
)
|
488 |
|
489 |
+
)
|
490 |
+
with gr.Tab("【Recognize Text】"):
|
491 |
model_list = gr.Dropdown(choices=list_saved_models(), label="Select OCR Model")
|
492 |
refresh_btn = gr.Button("🔄 Refresh Models")
|
493 |
load_model_btn = gr.Button("Load Model") # <-- new button
|
|
|
505 |
|
506 |
predict_btn.click(fn=predict_text, inputs=image_input, outputs=output_text)
|
507 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
508 |
|
|
|
509 |
|
510 |
|
511 |
|