# Provenance (from the header of the page this file was copied from):
# author Ling Lin (taellinglin), revision ed66c76, 17.2 kB.
import gradio as gr
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image, ImageFont, ImageDraw
import numpy as np
import os
import string
import cv2
from torchvision.transforms.functional import to_pil_image
import matplotlib.pyplot as plt
import math
from datetime import datetime
import re
from termcolor import colored
from pyctcdecode import BeamSearchDecoderCTC, Alphabet
from difflib import SequenceMatcher
# --------- Globals --------- #
# Model vocabulary: ASCII letters, digits and punctuation (no space).
CHARS = string.ascii_letters + string.digits + string.punctuation
CHAR2IDX = {c: i + 1 for i, c in enumerate(CHARS)}  # real characters start at 1
CHAR2IDX["<BLANK>"] = 0  # CTC blank
IDX2CHAR = {v: k for k, v in CHAR2IDX.items()}
BLANK_IDX = 0
# Fixed input size the model is trained and run on.
IMAGE_HEIGHT = 32
IMAGE_WIDTH = 128
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
font_path = None  # last font used; reassigned by train_model / generate_labels
ocr_model = None  # active model; set by load_model / train_model
# Decoder vocabulary, one entry per model output index (order must match
# CHAR2IDX). pyctcdecode recognises the CTC blank by the empty string "", so
# the blank slot must be "" rather than the literal "<BLANK>": otherwise
# build_alphabet keeps "<BLANK>" as a real token and appends an extra ""
# blank, leaving the vocabulary one entry longer than the model's output layer.
labels = ["" if i == BLANK_IDX else IDX2CHAR.get(i, "") for i in range(len(IDX2CHAR))]
# NOTE(review): build_alphabet may also map "|" to " " when the vocabulary has
# no space character, so a predicted "|" could be decoded as a space — confirm.
alphabet = Alphabet.build_alphabet(labels)
decoder = BeamSearchDecoderCTC(alphabet)
# Ensure required directories exist at startup
os.makedirs("./fonts", exist_ok=True)
os.makedirs("./models", exist_ok=True)
os.makedirs("./labels", exist_ok=True)
# --------- Dataset --------- #
class OCRDataset(Dataset):
    """Synthetic OCR dataset: random strings rendered in a given font.

    Each item is ``(image, label_encoded, label_length)``:
    ``image`` is a normalized 1 x IMAGE_HEIGHT x IMAGE_WIDTH float tensor,
    ``label_encoded`` holds the CHAR2IDX indices of the string and
    ``label_length`` is its length as a 0-dim long tensor.

    NOTE(review): the transform resizes to a fixed size without keeping the
    aspect ratio, while ``preprocess_image`` (inference) keeps it and pads —
    the two pipelines are not identical.
    """

    def __init__(self, font_path, size=1000, label_length_range=(4, 7)):
        """
        Args:
            font_path: font file loadable by ``ImageFont.truetype``.
            size: number of random strings to generate.
            label_length_range: ``(low, high)`` passed to ``np.random.randint``,
                so string lengths range from ``low`` to ``high - 1``.
        """
        self.font = ImageFont.truetype(font_path, 32)
        self.label_length_range = label_length_range
        self.samples = [
            "".join(np.random.choice(list(CHARS), np.random.randint(*self.label_length_range)))
            for _ in range(size)
        ]
        # Resize/blur/affine run on the un-normalized [0, 1] tensor so that
        # RandomAffine can fill exposed borders with white (1.0), matching the
        # rendered background; normalization to [-1, 1] is applied last.
        # (Resize and blur are linear, so their order relative to Normalize
        # does not matter; only the affine fill colour changes.)
        self.transform = transforms.Compose([
            transforms.ToTensor(),  # PIL "L" image -> (1, H, W) float in [0, 1]
            transforms.Resize((IMAGE_HEIGHT, IMAGE_WIDTH)),
            transforms.RandomApply([transforms.GaussianBlur(kernel_size=3)], p=0.3),
            transforms.RandomApply(
                [transforms.RandomAffine(degrees=10, translate=(0.1, 0.1), fill=1.0)], p=0.3
            ),
            transforms.Normalize((0.5,), (0.5,)),
        ])

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        """Render sample ``idx`` and return ``(image, label_encoded, label_length)``."""
        label = self.samples[idx]
        # Render on a white canvas sized to the text plus padding on every side.
        pad = 8
        w = self.font.getlength(label)
        h = self.font.size
        img_w, img_h = int(w + 2 * pad), int(h + 2 * pad)
        img = Image.new("L", (img_w, img_h), 255)
        draw = ImageDraw.Draw(img)
        draw.text((pad, pad), label, font=self.font, fill=0)
        img = self.transform(img)
        label_encoded = torch.tensor([CHAR2IDX[c] for c in label], dtype=torch.long)
        label_length = torch.tensor(len(label_encoded), dtype=torch.long)
        return img, label_encoded, label_length

    def render_text(self, text):
        """Render ``text`` centred on a white IMAGE_WIDTH x IMAGE_HEIGHT "L" image."""
        img = Image.new("L", (IMAGE_WIDTH, IMAGE_HEIGHT), color=255)
        draw = ImageDraw.Draw(img)
        left, top, right, bottom = self.font.getbbox(text)
        w, h = right - left, bottom - top
        # getbbox is measured from the draw origin; subtract its offset so the
        # visible ink (not the origin) is centred.
        draw.text(
            ((IMAGE_WIDTH - w) // 2 - left, (IMAGE_HEIGHT - h) // 2 - top),
            text,
            font=self.font,
            fill=0,
        )
        return img
# --------- Model --------- #
class OCRModel(nn.Module):
    """CRNN line recognizer: 2 conv stages -> 2-layer BiLSTM -> linear classifier.

    Input:  (batch, 1, input_height, width) images.
    Output: (batch, T, num_classes) raw logits with T = width - 2, because each
    max-pool halves the height but (kernel 2, stride 1) shrinks the width by 1.
    """

    def __init__(self, num_classes, input_height=32):
        """
        Args:
            num_classes: number of output classes (index 0 is used as CTC blank).
            input_height: image height the model is built for; the LSTM input
                size is 64 * (input_height // 4). The default 32 reproduces the
                original layer sizes, so existing checkpoints still load.
        """
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(1, 32, 3, padding=1), nn.ReLU(), nn.MaxPool2d((2, 2), (2, 1)),  # H/2, W-1
            nn.Conv2d(32, 64, 3, padding=1), nn.ReLU(), nn.MaxPool2d((2, 2), (2, 1)),  # H/2, W-1
        )
        # Each output column is flattened to 64 channels x remaining height.
        feat_height = input_height // 4
        self.rnn = nn.LSTM(64 * feat_height, 128, bidirectional=True, num_layers=2, batch_first=True)
        self.fc = nn.Linear(256, num_classes)
        with torch.no_grad():
            self.fc.bias[0] = -5.0  # discourage blank (index 0) early on

    def forward(self, x):
        """Map (B, 1, H, W) images to (B, T, num_classes) logits."""
        b, _, _, _ = x.size()  # enforce a 4-D batched input
        x = self.conv(x)  # (B, 64, H/4, T)
        x = x.permute(0, 3, 1, 2)  # (B, T, 64, H/4)
        x = x.reshape(b, x.size(1), -1)  # (B, T, 64 * H/4)
        x, _ = self.rnn(x)  # (B, T, 256)
        return self.fc(x)
def color_char(c, conf):
    """Wrap ``c`` in an ANSI terminal colour picked from confidence ``conf``.

    ``conf`` is scaled onto a 7-step palette (red ... magenta, then plain
    reset); values past the top clamp to the last step. The result always ends
    with an ANSI reset.

    NOTE: a later ``color_char`` (HTML version) in this module rebinds the name.
    """
    palette = ('\033[31m', '\033[33m', '\033[32m', '\033[36m', '\033[34m', '\033[35m', '\033[0m')
    top = len(palette) - 1
    level = int(conf * top)
    if level > top:
        level = top
    return f"{palette[level]}{c}\033[0m"
def sanitize_filename(name):
    """Return ``name`` with every char other than ASCII letters, digits, ``_`` or ``-`` turned into ``_``."""
    allowed = set(string.ascii_letters + string.digits + "_-")
    return "".join(ch if ch in allowed else "_" for ch in name)
def greedy_decode(log_probs, *, debug=False, blank_idx=None, idx2char=None):
    """Best-path (greedy) CTC decoding of a single sequence.

    Args:
        log_probs: tensor of shape (T, 1, C) — time, batch of one, classes —
            or (T, C).
        debug: when True, print the raw per-step arg-max indices.
        blank_idx: CTC blank index; defaults to the module-level BLANK_IDX.
        idx2char: index -> character map; defaults to the module-level IDX2CHAR.
            Unknown indices decode to "".

    Returns:
        The decoded string: consecutive repeats are collapsed, then blanks
        dropped (a blank between two equal indices keeps both).

    Raises:
        ValueError: if a 3-D input has a batch size other than 1.
    """
    blank = BLANK_IDX if blank_idx is None else blank_idx
    vocab = IDX2CHAR if idx2char is None else idx2char
    if log_probs.dim() == 3:
        if log_probs.size(1) != 1:
            raise ValueError(f"greedy_decode expects batch size 1, got {log_probs.size(1)}")
        log_probs = log_probs[:, 0, :]
    pred = log_probs.argmax(-1).tolist()
    if debug:
        print(f"Decoded indices: {pred}")
    decoded = []
    prev = blank
    for p in pred:
        if p != prev and p != blank:
            decoded.append(vocab.get(p, ""))
        prev = p
    return "".join(decoded)
# --------- Custom Collate --------- #
def custom_collate_fn(batch):
    """Collate ``(image, encoded_label, length)`` samples for CTC training.

    Returns ``(images, targets, target_lengths)``: the images stacked on a new
    leading batch dimension, every label concatenated into one flat tensor, and
    the per-sample label lengths as a long tensor. The dataset's own length
    field is ignored; lengths are recomputed from the labels.
    """
    imgs, label_seqs, _ = zip(*batch)
    stacked = torch.stack(imgs, 0)
    lengths = [len(seq) for seq in label_seqs]
    targets = torch.cat(list(label_seqs))
    return stacked, targets, torch.tensor(lengths, dtype=torch.long)
# --------- Model Save/Load --------- #
def list_saved_models():
    """Return the ``*.pth`` entry names in ./models (in directory order), or [] if it is missing."""
    folder = "./models"
    found = []
    if os.path.exists(folder):
        for entry in os.listdir(folder):
            if entry.endswith(".pth"):
                found.append(entry)
    return found
def save_model(model, path):
    """Write ``model.state_dict()`` to ``path`` with ``torch.save``."""
    weights = model.state_dict()
    torch.save(weights, path)
def load_model(filename):
    """Load a saved OCRModel state_dict from ./models and make it the active model.

    Args:
        filename: file name inside ./models (as listed by list_saved_models);
            None or "" means nothing was selected.

    Returns:
        A status string for the UI. On success the module-level ``ocr_model``
        is replaced by the loaded model (on ``device``, in eval mode).
    """
    global ocr_model
    if not filename:
        return "No model selected."
    model_dir = "./models"
    path = os.path.join(model_dir, filename)
    if not os.path.exists(path):
        return f"Model file '{path}' does not exist."
    model = OCRModel(num_classes=len(CHAR2IDX))
    # NOTE(review): torch.load unpickles the file — only load checkpoints you
    # trust (consider weights_only=True where the installed torch supports it).
    try:
        model.load_state_dict(torch.load(path, map_location=device))
    except RuntimeError as err:
        # e.g. checkpoint from a different architecture or vocabulary size
        return f"Failed to load model '{path}': {err}"
    model.to(device)
    model.eval()
    ocr_model = model
    return f"Model '{path}' loaded."
# --------- Gradio Functions --------- #
def _set_lr(optimizer, lr):
    """Set the learning rate of every parameter group in ``optimizer``."""
    for param_group in optimizer.param_groups:
        param_group["lr"] = lr


def _copy_uploaded_font(font_file):
    """Copy an uploaded font into ./fonts and return ``(font_name, saved_path)``.

    ``font_file`` may be a path (str / os.PathLike) or an object whose ``.name``
    is the uploaded temp-file path (which one arrives depends on the Gradio
    version). A .ttf/.otf extension is preserved; anything else is saved as .ttf.
    """
    import shutil

    if isinstance(font_file, (str, os.PathLike)):
        src = os.fspath(font_file)
    else:
        src = font_file.name
    font_name, ext = os.path.splitext(os.path.basename(src))
    ext = ext.lower() if ext.lower() in (".ttf", ".otf") else ".ttf"
    saved_path = f"./fonts/{font_name}{ext}"
    # Copying a file onto itself is an error (and opening the target for
    # writing first would truncate the source), so skip that case.
    if os.path.abspath(src) != os.path.abspath(saved_path):
        shutil.copyfile(src, saved_path)
    return font_name, saved_path


def train_model(font_file, epochs=100, learning_rate=0.001):
    """Train a fresh OCRModel on synthetic text rendered in the uploaded font.

    The font is copied into ./fonts, one rendered training sample is written to
    ./models/<font>_sample.png for inspection, and the trained weights are saved
    as ./models/<font>_<epochs>ep_lr<lr>_<timestamp>.pth. The trained model (in
    eval mode) also becomes the active ``ocr_model``.

    Args:
        font_file: uploaded font (path, or object with a ``.name`` path).
        epochs: number of training epochs (coerced to int).
        learning_rate: base Adam learning rate.

    Returns:
        A status string for the UI.
    """
    import time

    global font_path, ocr_model
    if font_file is None:
        return "Please upload a .ttf or .otf font first."
    epochs = int(epochs)  # slider values may arrive as floats

    # Ensure directories exist
    os.makedirs("./fonts", exist_ok=True)
    os.makedirs("./models", exist_ok=True)
    font_name, font_path = _copy_uploaded_font(font_file)

    def get_dataset_for_epoch(epoch):
        """Curriculum: longer strings in later thirds of training.

        np.random.randint excludes the upper bound, so string lengths are 3,
        then 4-5, then 5-6.
        """
        if epoch < epochs // 3:
            label_len = (3, 4)
        elif epoch < 2 * epochs // 3:
            label_len = (4, 6)
        else:
            label_len = (5, 7)
        return OCRDataset(font_path, label_length_range=label_len)

    # Save one sample for visual inspection. plt.show() inside a server
    # callback can block (interactive backends) or merely warn (headless) and
    # leaves the figure open across calls; imsave just writes a file.
    dataset = get_dataset_for_epoch(0)
    img, label, _ = dataset[0]
    print("Label:", "".join(IDX2CHAR[i.item()] for i in label))
    plt.imsave(os.path.join("./models", f"{font_name}_sample.png"), img.squeeze(0).numpy(), cmap="gray")

    # Model setup
    model = OCRModel(num_classes=len(CHAR2IDX)).to(device)
    criterion = nn.CTCLoss(blank=BLANK_IDX)
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="min", factor=0.5, patience=5)
    warmup_epochs = 5
    for epoch in range(epochs):
        # Warm up at 20% of the base LR, then restore the base LR exactly once.
        # After that ReduceLROnPlateau owns the LR: re-assigning the base LR
        # every epoch would undo every reduction the scheduler makes.
        if epoch < warmup_epochs:
            _set_lr(optimizer, learning_rate * 0.2)
        elif epoch == warmup_epochs:
            _set_lr(optimizer, learning_rate)

        dataset = get_dataset_for_epoch(epoch)
        dataloader = DataLoader(dataset, batch_size=16, shuffle=True, collate_fn=custom_collate_fn)
        model.train()
        running_loss = 0.0
        for img, targets, target_lengths in dataloader:
            img = img.to(device)
            targets = targets.to(device)
            target_lengths = target_lengths.to(device)
            output = model(img)  # (B, T, C)
            batch_size, seq_len = output.size(0), output.size(1)
            # Every sample uses all T output steps.
            input_lengths = torch.full((batch_size,), seq_len, dtype=torch.long, device=device)
            log_probs = output.log_softmax(2).transpose(0, 1)  # CTCLoss expects (T, B, C)
            loss = criterion(log_probs, targets, input_lengths, target_lengths)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
        avg_loss = running_loss / len(dataloader)
        # Track plateaus only after warm-up so warm-up losses do not become
        # the scheduler's baseline.
        if epoch >= warmup_epochs:
            scheduler.step(avg_loss)
        print(f"[{epoch + 1}/{epochs}] Loss: {avg_loss:.4f}")

    # Save the model to ./models
    timestamp = time.strftime("%Y%m%d%H%M%S")
    model_name = f"{font_name}_{epochs}ep_lr{learning_rate:.0e}_{timestamp}.pth"
    model_path = os.path.join("./models", model_name)
    save_model(model, model_path)
    model.eval()  # hand over in inference mode, as load_model does
    ocr_model = model
    return f"✅ Training complete! Model saved as '{model_path}'"
def preprocess_image(image: Image.Image):
    """Binarize a word-strip image and fit it to IMAGE_HEIGHT x IMAGE_WIDTH.

    Steps: grayscale -> adaptive Gaussian threshold (inverted binary) -> invert
    again when black pixels outnumber white ones, so the result is mostly white
    -> scale to IMAGE_HEIGHT keeping the aspect ratio -> pad on the right with
    white. If the scaled strip would be wider than IMAGE_WIDTH it is squeezed
    horizontally to IMAGE_WIDTH (the training transform also resizes to a fixed
    size) instead of cropping characters off the right edge.

    Returns:
        A single-channel PIL image of size (IMAGE_WIDTH, IMAGE_HEIGHT).
    """
    img_cv = np.array(image.convert("L"))
    img_bin = cv2.adaptiveThreshold(img_cv, 255, cv2.ADAPTIVE_THRESH_GAUSSIAN_C,
                                    cv2.THRESH_BINARY_INV, 25, 15)
    # Make the majority (background) colour white.
    white_px = (img_bin == 255).sum()
    black_px = (img_bin == 0).sum()
    if black_px > white_px:
        img_bin = 255 - img_bin

    h, w = img_bin.shape
    scale = IMAGE_HEIGHT / h
    new_w = max(1, int(w * scale))  # very thin inputs must not yield width 0
    if new_w > IMAGE_WIDTH:
        squeezed = cv2.resize(img_bin, (IMAGE_WIDTH, IMAGE_HEIGHT), interpolation=cv2.INTER_AREA)
        return to_pil_image(squeezed)
    resized = cv2.resize(img_bin, (new_w, IMAGE_HEIGHT), interpolation=cv2.INTER_AREA)
    padded = np.pad(resized, ((0, 0), (0, IMAGE_WIDTH - new_w)), constant_values=255)
    return to_pil_image(padded)
# ROYGBIV color ramp (low → high confidence)
CONFIDENCE_COLORS = [
    "#FF0000",  # Red
    "#FF7F00",  # Orange
    "#FFFF00",  # Yellow
    "#00FF00",  # Green
    "#00BFFF",  # Sky Blue
    "#0000FF",  # Blue
    "#8B00FF",  # Violet
]


def confidence_to_color(conf):
    """
    Map confidence (0.0–1.0) to a ROYGBIV-style hex color.

    The bucket is ``int(conf * 6)``, so only conf == 1.0 reaches violet.
    Input is clamped to [0, 1] and NaN maps to the lowest bucket, so
    out-of-range values can never wrap around through negative list indexing.
    """
    if math.isnan(conf):
        conf = 0.0
    conf = min(max(conf, 0.0), 1.0)
    last = len(CONFIDENCE_COLORS) - 1
    return CONFIDENCE_COLORS[min(int(conf * last), last)]
def color_char(c, conf):
    """
    Wrap character `c` in a span tag with color mapped from `conf`.

    `c` is HTML-escaped first: the vocabulary contains punctuation such as
    `<`, `>` and `&`, which would otherwise be parsed as markup by the HTML
    output component.
    """
    import html

    color = confidence_to_color(conf)
    return f'<span style="color:{color}; font-size:12pt; font-weight:bold;">{html.escape(c)}</span>'
def predict_text(image: Image.Image, ground_truth: str = None, debug: bool = False):
    """Recognize the text in a word-strip image with the active ``ocr_model``.

    Args:
        image: input PIL image (any mode; converted to grayscale).
        ground_truth: optional expected text; adds a similarity score.
        debug: print the decoded text and confidence to stdout.

    Returns:
        An HTML snippet with the colorized prediction and confidence, or a plain
        message when no model is loaded or no image was supplied.
    """
    if ocr_model is None:
        return "Please load or train a model first."
    if image is None:
        return "Please upload an image first."
    processed = preprocess_image(image)
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,)),
    ])
    img_tensor = transform(processed).unsqueeze(0).to(device)  # (1, C, H, W)
    ocr_model.eval()
    with torch.no_grad():
        output = ocr_model(img_tensor)  # (1, T, C)
        log_probs = output.log_softmax(2)[0]  # (T, C)
    # Beam-search decode; also strip any literal "<BLANK>" text in case the
    # decoder vocabulary carries the blank as that string rather than "".
    pred_text = decoder.decode(log_probs.cpu().numpy()).replace("<BLANK>", "")
    # Confidence: mean over all timesteps (blank frames included) of the
    # highest class probability.
    avg_conf = log_probs.exp().max(dim=1)[0].mean().item()
    # Same colour for every character (no per-character confidence yet).
    pretty_output = "".join(color_char(c, avg_conf) for c in pred_text)
    sim_score = ""
    if ground_truth:
        # SequenceMatcher.ratio() is a Ratcliff/Obershelp-style match ratio,
        # not a Levenshtein distance, so it is labelled accordingly.
        similarity = SequenceMatcher(None, ground_truth, pred_text).ratio()
        sim_score = f"<br><strong>Similarity (SequenceMatcher ratio):</strong> {similarity:.2%}"
    if debug:
        print("Decoded Text:", pred_text)
        print("Average Confidence:", avg_conf)
        if ground_truth:
            print("Ground Truth:", ground_truth)
    return f"<strong>Prediction:</strong> <strong>{pretty_output}</strong><br><strong>Confidence:</strong> {avg_conf:.2%}{sim_score}"
# Rendering constants for the label-image generator. The character set is the
# module-level CHARS from the Globals section; it is intentionally not
# redefined here so there is a single source of truth.
FONT_SIZE = 32
PADDING = 8
LABEL_DIR = "./labels"
def generate_labels(font_file=None, num_labels: int = 25):
    """Render random label strings to images, save them, and return them.

    Each string has 4–6 characters drawn from CHARS. Every image is saved as
    LABEL_DIR/<sanitized label>/<timestamp>.png.

    Args:
        font_file: font path, or None / "None" for Pillow's default font (the
            default font is also used when the path does not exist).
        num_labels: number of labels to render; None is treated as 0.

    Returns:
        The rendered PIL images. On any error, a one-item list with an image
        showing the error message (this is a UI callback, so failures are
        reported visually rather than raised).

    Side effect: sets the module-level ``font_path``.
    """
    global font_path
    try:
        if font_file and font_file != "None":
            font_path = os.path.abspath(font_file)
        else:
            font_path = None
        if font_path is None or not os.path.exists(font_path):
            font = ImageFont.load_default()
        else:
            font = ImageFont.truetype(font_path, FONT_SIZE)
        os.makedirs(LABEL_DIR, exist_ok=True)
        count = 0 if num_labels is None else int(num_labels)  # gr.Number may send float/None
        texts = ["".join(np.random.choice(list(CHARS), np.random.randint(4, 7))) for _ in range(count)]
        images = []
        for text in texts:
            left, top, right, bottom = font.getbbox(text)
            img = Image.new("L", (right - left + PADDING * 2, bottom - top + PADDING * 2), color=255)
            draw = ImageDraw.Draw(img)
            # getbbox is measured from the draw origin; shifting by (left, top)
            # puts the ink exactly PADDING px from the edges instead of pushing
            # it down by `top` (which clipped the bottom whenever top > PADDING).
            draw.text((PADDING - left, PADDING - top), text, font=font, fill=0)
            safe_label = sanitize_filename(text)
            timestamp = datetime.now().strftime("%Y%m%d%H%M%S%f")
            label_dir = os.path.join(LABEL_DIR, safe_label)
            os.makedirs(label_dir, exist_ok=True)
            img.save(os.path.join(label_dir, f"{timestamp}.png"))
            images.append(img)
        return images
    except Exception as e:
        # UI boundary: report the failure as an image instead of crashing the gallery.
        print("Error in generate_labels:", e)
        error_img = Image.new("RGB", (512, 128), color=(255, 255, 255))
        draw = ImageDraw.Draw(error_img)
        draw.text((10, 50), f"Error: {str(e)}", fill=(255, 0, 0))
        return [error_img]
def list_fonts():
    """Return Dropdown choices for the fonts in ./fonts as ``(label, path)`` tuples.

    The first entry is always ``("None", "None")`` (meaning: use the default
    font), followed by .ttf/.otf files (case-insensitive) in sorted order.
    A missing ./fonts directory yields only the "None" entry.
    """
    font_dir = "./fonts"
    choices = [("None", "None")]
    if not os.path.isdir(font_dir):
        return choices
    choices.extend(
        (name, os.path.join(font_dir, name))
        for name in sorted(os.listdir(font_dir))
        if name.lower().endswith((".ttf", ".otf"))
    )
    return choices
# CSS for the Gradio app: label-gallery thumbnails (#label-gallery) are shown
# at a fixed 43 px height with their aspect ratio preserved, the gallery is a
# vertically scrollable area that grows to fill the available height, and the
# recognized-text output (#output-text) uses a 12pt font.
custom_css = """
#label-gallery .gallery-item img {
height: 43px; /* 32pt ≈ 43px */
width: auto;
object-fit: contain;
padding: 4px;
}
#label-gallery {
flex-grow: 1;
overflow-y: auto;
height: 100%;
}
#output-text {
font-size: 12pt;
}
"""
# --------- Gradio UI --------- #
with gr.Blocks(css=custom_css) as demo:
    with gr.Tab("【Train OCR Model】"):
        font_file = gr.File(label="Upload .ttf or .otf font", file_types=[".ttf", ".otf"])
        epochs_input = gr.Slider(minimum=1, maximum=4096, value=256, step=1, label="Epochs")
        # Default matches train_model's own default learning_rate (0.001);
        # NOTE(review): the former UI default 0.05 was 50x higher for Adam.
        lr_input = gr.Slider(minimum=0.0001, maximum=0.1, value=0.001, step=0.0001, label="Learning Rate")
        train_button = gr.Button("Train OCR Model")
        train_status = gr.Textbox(label="Status")
        train_button.click(fn=train_model, inputs=[font_file, epochs_input, lr_input], outputs=train_status)

    with gr.Tab("【Generate Labels】"):
        # Choices are read once at startup; fonts added later need a restart.
        font_file_labels = gr.Dropdown(
            choices=list_fonts(),
            label="Optional font for label image",
            interactive=True,
        )
        num_labels = gr.Number(value=20, label="Number of labels to generate", precision=0, interactive=True)
        gen_button = gr.Button("Generate Label Grid")
        gen_button.click(
            fn=generate_labels,
            inputs=[font_file_labels, num_labels],
            outputs=gr.Gallery(
                label="Generated Labels",
                columns=16,  # 16 tiles per row
                object_fit="contain",  # Maintain aspect ratio
                height="100%",  # Allow full app height
                elem_id="label-gallery",  # For CSS targeting
            ),
        )

    with gr.Tab("【Recognize Text】"):
        model_list = gr.Dropdown(choices=list_saved_models(), label="Select OCR Model")
        refresh_btn = gr.Button("🔄 Refresh Models")
        load_model_btn = gr.Button("Load Model")
        image_input = gr.Image(type="pil", label="Upload word strip")
        predict_btn = gr.Button("Predict")
        output_text = gr.HTML(label="Recognized Text", elem_id="output-text")
        model_status = gr.Textbox(label="Model Load Status")
        # Refresh dropdown choices
        refresh_btn.click(fn=lambda: gr.update(choices=list_saved_models()), outputs=model_list)
        # Load model on button click, NOT dropdown change
        load_model_btn.click(fn=load_model, inputs=model_list, outputs=model_status)
        predict_btn.click(fn=predict_text, inputs=image_input, outputs=output_text)

if __name__ == "__main__":
    demo.launch()