Update app.py
Browse files
app.py
CHANGED
@@ -13,6 +13,8 @@ from torchvision.transforms.functional import to_pil_image
|
|
13 |
import matplotlib.pyplot as plt
|
14 |
import math
|
15 |
from datetime import datetime
|
|
|
|
|
16 |
# --------- Globals --------- #
|
17 |
CHARS = string.ascii_letters + string.digits + string.punctuation
|
18 |
CHAR2IDX = {c: i + 1 for i, c in enumerate(CHARS)}
|
@@ -87,6 +89,9 @@ class OCRModel(nn.Module):
|
|
87 |
return x
|
88 |
|
89 |
|
|
|
|
|
|
|
90 |
def greedy_decode(log_probs):
|
91 |
# log_probs shape: (T, B, C)
|
92 |
# Usually, B=1 during inference
|
@@ -255,7 +260,6 @@ PADDING = 8
|
|
255 |
LABEL_DIR = "./labels"
|
256 |
|
257 |
def generate_labels(font_file=None, num_labels: int = 25):
|
258 |
-
import time
|
259 |
global font_path
|
260 |
|
261 |
try:
|
@@ -264,7 +268,6 @@ def generate_labels(font_file=None, num_labels: int = 25):
|
|
264 |
with open(font_file.name, "rb") as uploaded:
|
265 |
with open(font_path, "wb") as f:
|
266 |
f.write(uploaded.read())
|
267 |
-
|
268 |
if font_path is None or not os.path.exists(font_path):
|
269 |
font = ImageFont.load_default()
|
270 |
else:
|
@@ -275,27 +278,26 @@ def generate_labels(font_file=None, num_labels: int = 25):
|
|
275 |
images = []
|
276 |
|
277 |
for label in labels:
|
278 |
-
# Measure text size
|
279 |
bbox = font.getbbox(label)
|
280 |
text_w = bbox[2] - bbox[0]
|
281 |
text_h = bbox[3] - bbox[1]
|
282 |
-
|
283 |
-
# Add 8px padding
|
284 |
pad = 8
|
285 |
img_w = text_w + pad * 2
|
286 |
img_h = text_h + pad * 2
|
287 |
|
|
|
288 |
img = Image.new("L", (img_w, img_h), color=255)
|
289 |
draw = ImageDraw.Draw(img)
|
290 |
draw.text((pad, pad), label, font=font, fill=0)
|
291 |
|
292 |
-
# Save
|
293 |
-
safe_label =
|
294 |
-
|
|
|
295 |
os.makedirs(label_dir, exist_ok=True)
|
296 |
|
297 |
-
|
298 |
-
filepath = os.path.join(label_dir, f"_{timestamp}.png")
|
299 |
img.save(filepath)
|
300 |
|
301 |
images.append(img)
|
@@ -309,10 +311,23 @@ def generate_labels(font_file=None, num_labels: int = 25):
|
|
309 |
draw.text((10, 50), f"Error: {str(e)}", fill=(255, 0, 0))
|
310 |
return [error_img]
|
311 |
|
312 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
313 |
|
314 |
# --------- Updated Gradio UI with new tab --------- #
|
315 |
-
with gr.Blocks() as demo:
|
316 |
with gr.Tab("1. Upload Font & Train"):
|
317 |
font_file = gr.File(label="Upload .ttf or .otf font", file_types=[".ttf", ".otf"])
|
318 |
epochs_input = gr.Slider(minimum=1, maximum=4096, value=256, step=1, label="Epochs")
|
|
|
13 |
import matplotlib.pyplot as plt
|
14 |
import math
|
15 |
from datetime import datetime
|
16 |
+
import re
|
17 |
+
|
18 |
# --------- Globals --------- #
|
19 |
CHARS = string.ascii_letters + string.digits + string.punctuation
|
20 |
CHAR2IDX = {c: i + 1 for i, c in enumerate(CHARS)}
|
|
|
89 |
return x
|
90 |
|
91 |
|
92 |
+
def sanitize_filename(name: str) -> str:
    """Return *name* made safe for use as a single path component.

    Every character other than ASCII letters, digits, ``_`` and ``-`` is
    replaced with ``_``, so path separators and dots cannot survive.

    Args:
        name: Arbitrary text to turn into a file or directory name.

    Returns:
        The sanitized name; ``"_"`` when *name* is empty.
    """
    cleaned = re.sub(r'[^a-zA-Z0-9_-]', '_', name)
    # An empty component would make os.path.join(base, "") resolve to the
    # base directory itself, so fall back to a non-empty placeholder.
    return cleaned or '_'
|
94 |
+
|
95 |
def greedy_decode(log_probs):
|
96 |
# log_probs shape: (T, B, C)
|
97 |
# Usually, B=1 during inference
|
|
|
260 |
LABEL_DIR = "./labels"
|
261 |
|
262 |
def generate_labels(font_file=None, num_labels: int = 25):
|
|
|
263 |
global font_path
|
264 |
|
265 |
try:
|
|
|
268 |
with open(font_file.name, "rb") as uploaded:
|
269 |
with open(font_path, "wb") as f:
|
270 |
f.write(uploaded.read())
|
|
|
271 |
if font_path is None or not os.path.exists(font_path):
|
272 |
font = ImageFont.load_default()
|
273 |
else:
|
|
|
278 |
images = []
|
279 |
|
280 |
for label in labels:
|
281 |
+
# Measure text size and calculate padded image dimensions
|
282 |
bbox = font.getbbox(label)
|
283 |
text_w = bbox[2] - bbox[0]
|
284 |
text_h = bbox[3] - bbox[1]
|
|
|
|
|
285 |
pad = 8
|
286 |
img_w = text_w + pad * 2
|
287 |
img_h = text_h + pad * 2
|
288 |
|
289 |
+
# Create image and draw text
|
290 |
img = Image.new("L", (img_w, img_h), color=255)
|
291 |
draw = ImageDraw.Draw(img)
|
292 |
draw.text((pad, pad), label, font=font, fill=0)
|
293 |
|
294 |
+
# Save to ./labels/sanitized_label/timestamp.png
|
295 |
+
safe_label = sanitize_filename(label)
|
296 |
+
timestamp = datetime.now().strftime("%Y%m%d%H%M%S%f")
|
297 |
+
label_dir = os.path.join("./labels", safe_label)
|
298 |
os.makedirs(label_dir, exist_ok=True)
|
299 |
|
300 |
+
filepath = os.path.join(label_dir, f"{timestamp}.png")
|
|
|
301 |
img.save(filepath)
|
302 |
|
303 |
images.append(img)
|
|
|
311 |
draw.text((10, 50), f"Error: {str(e)}", fill=(255, 0, 0))
|
312 |
return [error_img]
|
313 |
|
314 |
+
custom_css = """
|
315 |
+
#label-gallery .gallery-item img {
|
316 |
+
height: 43px; /* 32pt ≈ 43px */
|
317 |
+
width: auto;
|
318 |
+
object-fit: contain;
|
319 |
+
padding: 4px;
|
320 |
+
}
|
321 |
+
|
322 |
+
#label-gallery {
|
323 |
+
flex-grow: 1;
|
324 |
+
overflow-y: auto;
|
325 |
+
height: 100%;
|
326 |
+
}
|
327 |
+
"""
|
328 |
|
329 |
# --------- Updated Gradio UI with new tab --------- #
|
330 |
+
with gr.Blocks(css=custom_css) as demo:
|
331 |
with gr.Tab("1. Upload Font & Train"):
|
332 |
font_file = gr.File(label="Upload .ttf or .otf font", file_types=[".ttf", ".otf"])
|
333 |
epochs_input = gr.Slider(minimum=1, maximum=4096, value=256, step=1, label="Epochs")
|