import gradio as gr
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image, ImageFont, ImageDraw
import numpy as np
import os
import string
import cv2
from torchvision.transforms.functional import to_pil_image
import matplotlib.pyplot as plt
import math
from datetime import datetime
import re
from termcolor import colored
from pyctcdecode import BeamSearchDecoderCTC, Alphabet
from difflib import SequenceMatcher

# --------- Globals --------- #
CHARS = string.ascii_letters + string.digits + string.punctuation
CHAR2IDX = {c: i + 1 for i, c in enumerate(CHARS)}  # Start from 1
CHAR2IDX[""] = 0  # CTC blank
IDX2CHAR = {v: k for k, v in CHAR2IDX.items()}
BLANK_IDX = 0

IMAGE_HEIGHT = 32
IMAGE_WIDTH = 128

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
font_path = None
ocr_model = None

# Create vocabulary list (ensure order matches your model's output indices!)
labels = [IDX2CHAR.get(i, "") for i in range(len(IDX2CHAR))]

# Wrap in Alphabet
alphabet = Alphabet.build_alphabet(labels)

# Now initialize decoder correctly
decoder = BeamSearchDecoderCTC(alphabet)

# Ensure required directories exist at startup
os.makedirs("./fonts", exist_ok=True)
os.makedirs("./models", exist_ok=True)
os.makedirs("./labels", exist_ok=True)


# --------- Dataset --------- #
class OCRDataset(Dataset):
    def __init__(self, font_path, size=1000, label_length_range=(4, 7)):
        self.font = ImageFont.truetype(font_path, 32)
        self.label_length_range = label_length_range
        self.samples = [
            "".join(np.random.choice(list(CHARS), np.random.randint(*self.label_length_range)))
            for _ in range(size)
        ]
        self.transform = transforms.Compose([
            transforms.ToTensor(),  # must be first
            transforms.Normalize((0.5,), (0.5,)),
            transforms.Resize((IMAGE_HEIGHT, IMAGE_WIDTH)),
            transforms.RandomApply([transforms.GaussianBlur(kernel_size=3)], p=0.3),
            transforms.RandomApply([transforms.RandomAffine(degrees=10, translate=(0.1, 0.1))], p=0.3),
        ])

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        label = self.samples[idx]
        # Create an image with padding
        pad = 8
        w = self.font.getlength(label)
        h = self.font.size
        img_w, img_h = int(w + 2 * pad), int(h + 2 * pad)
        img = Image.new("L", (img_w, img_h), 255)
        draw = ImageDraw.Draw(img)
        draw.text((pad, pad), label, font=self.font, fill=0)
        img = self.transform(img)
        label_encoded = torch.tensor([CHAR2IDX[c] for c in label], dtype=torch.long)
        label_length = torch.tensor(len(label_encoded), dtype=torch.long)
        return img, label_encoded, label_length

    def render_text(self, text):
        img = Image.new("L", (IMAGE_WIDTH, IMAGE_HEIGHT), color=255)
        draw = ImageDraw.Draw(img)
        bbox = self.font.getbbox(text)
        w, h = bbox[2] - bbox[0], bbox[3] - bbox[1]
        draw.text(((IMAGE_WIDTH - w) // 2, (IMAGE_HEIGHT - h) // 2), text, font=self.font, fill=0)
        return img


# --------- Model --------- #
class OCRModel(nn.Module):
    def __init__(self, num_classes):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(1, 32, 3, padding=1), nn.ReLU(),
            nn.MaxPool2d((2, 2), (2, 1)),  # height ÷2, width −1
            nn.Conv2d(32, 64, 3, padding=1), nn.ReLU(),
            nn.MaxPool2d((2, 2), (2, 1))   # height ÷2 again, width −1
        )
        self.rnn = nn.LSTM(64 * 8, 128, bidirectional=True, num_layers=2, batch_first=True)
        self.fc = nn.Linear(256, num_classes)
        with torch.no_grad():
            self.fc.bias[0] = -5.0  # discourage blank early on

    def forward(self, x):
        b, c, h, w = x.size()
        x = self.conv(x)                 # (B, 64, H/4, W')
        x = x.permute(0, 3, 1, 2)        # treat width as the time axis
        x = x.reshape(b, x.size(1), -1)  # (B, W', 64 * H/4)
        x, _ = self.rnn(x)
        x = self.fc(x)
        return x
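
# Hedged sanity-check sketch (added for illustration; never called by the app,
# and the helper name is hypothetical). With IMAGE_HEIGHT=32 and two
# MaxPool2d((2, 2), stride=(2, 1)) layers, a (B, 1, 32, 128) input leaves the
# conv stack as (B, 64, 8, 126), so the LSTM sees 126 timesteps of
# 64 * 8 = 512 features — matching the 64 * 8 input size declared above.
def _check_model_shapes():
    dummy = torch.randn(2, 1, IMAGE_HEIGHT, IMAGE_WIDTH)
    out = OCRModel(num_classes=len(CHAR2IDX))(dummy)
    assert out.shape == (2, 126, len(CHAR2IDX))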

# Terminal (ANSI) colorizer, kept for console debugging. Renamed from
# color_char to avoid being silently shadowed by the HTML variant defined
# further below, which the Gradio UI actually uses.
def color_char_ansi(c, conf):
    color_levels = ['\033[31m', '\033[33m', '\033[32m', '\033[36m', '\033[34m', '\033[35m', '\033[0m']
    idx = min(int(conf * (len(color_levels) - 1)), len(color_levels) - 1)
    return f"{color_levels[idx]}{c}\033[0m"


def sanitize_filename(name):
    return re.sub(r'[^a-zA-Z0-9_-]', '_', name)


def greedy_decode(log_probs):
    # log_probs shape: (T, B, C); usually B = 1 during inference
    pred = log_probs.argmax(2).squeeze(1).tolist()  # list of class indices, one per timestep
    print(f"Decoded indices: {pred}")  # debug print
    decoded = []
    prev = BLANK_IDX
    for p in pred:
        # Standard CTC collapse: drop blanks and repeated indices
        if p != prev and p != BLANK_IDX:
            decoded.append(IDX2CHAR.get(p, ""))
        prev = p
    return ''.join(decoded)


# --------- Custom Collate --------- #
def custom_collate_fn(batch):
    images, labels, _ = zip(*batch)
    images = torch.stack(images, 0)
    flat_labels = []
    label_lengths = []
    for label in labels:
        flat_labels.append(label)
        label_lengths.append(len(label))
    targets = torch.cat(flat_labels)  # nn.CTCLoss takes all targets concatenated
    return images, targets, torch.tensor(label_lengths, dtype=torch.long)


# --------- Model Save/Load --------- #
def list_saved_models():
    model_dir = "./models"
    if not os.path.exists(model_dir):
        return []
    return [f for f in os.listdir(model_dir) if f.endswith(".pth")]


def save_model(model, path):
    torch.save(model.state_dict(), path)


def load_model(filename):
    global ocr_model
    model_dir = "./models"
    path = os.path.join(model_dir, filename)
    if not os.path.exists(path):
        return f"Model file '{path}' does not exist."
    model = OCRModel(num_classes=len(CHAR2IDX))
    model.load_state_dict(torch.load(path, map_location=device))
    model.to(device)
    model.eval()
    ocr_model = model
    return f"Model '{path}' loaded."
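
# Hedged illustration (added; never called by the app, helper name is
# hypothetical): nn.CTCLoss expects the batch's targets flattened into one
# 1-D tensor plus per-sample lengths, which is exactly the format
# custom_collate_fn produces — e.g. labels of length 2 and 3 become a
# 5-element target tensor with lengths [2, 3].
def _check_collate_format():
    fake_batch = [
        (torch.zeros(1, IMAGE_HEIGHT, IMAGE_WIDTH), torch.tensor([1, 2]), torch.tensor(2)),
        (torch.zeros(1, IMAGE_HEIGHT, IMAGE_WIDTH), torch.tensor([3, 4, 5]), torch.tensor(3)),
    ]
    images, targets, lengths = custom_collate_fn(fake_batch)
    assert images.shape == (2, 1, IMAGE_HEIGHT, IMAGE_WIDTH)
    assert targets.tolist() == [1, 2, 3, 4, 5]
    assert lengths.tolist() == [2, 3]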

# --------- Gradio Functions --------- #
def train_model(font_file, epochs=100, learning_rate=0.001):
    import time
    global font_path, ocr_model

    # Ensure directories exist
    os.makedirs("./fonts", exist_ok=True)
    os.makedirs("./models", exist_ok=True)

    # Save uploaded font to ./fonts, preserving its original extension
    # (the original hardcoded ".ttf", which mislabeled .otf uploads)
    font_name = os.path.splitext(os.path.basename(font_file.name))[0]
    font_ext = os.path.splitext(font_file.name)[1] or ".ttf"
    font_path = f"./fonts/{font_name}{font_ext}"
    with open(font_file.name, "rb") as uploaded:
        with open(font_path, "wb") as f:
            f.write(uploaded.read())

    # Curriculum learning: label length grows over time
    def get_dataset_for_epoch(epoch):
        if epoch < epochs // 3:
            label_len = (3, 4)
        elif epoch < 2 * epochs // 3:
            label_len = (4, 6)
        else:
            label_len = (5, 7)
        return OCRDataset(font_path, label_length_range=label_len)

    # Visualize one sample (note: plt.show() may block on some backends)
    dataset = get_dataset_for_epoch(0)
    img, label, _ = dataset[0]
    print("Label:", ''.join([IDX2CHAR[i.item()] for i in label]))
    plt.imshow(img.permute(1, 2, 0).squeeze(), cmap='gray')
    plt.show()

    # Model setup
    model = OCRModel(num_classes=len(CHAR2IDX)).to(device)
    criterion = nn.CTCLoss(blank=BLANK_IDX)
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=5)

    for epoch in range(epochs):
        dataset = get_dataset_for_epoch(epoch)
        dataloader = DataLoader(dataset, batch_size=16, shuffle=True, collate_fn=custom_collate_fn)
        model.train()
        running_loss = 0.0

        # Warmup learning rate for the first five epochs
        if epoch < 5:
            warmup_lr = learning_rate * 0.2
            for param_group in optimizer.param_groups:
                param_group['lr'] = warmup_lr
        else:
            for param_group in optimizer.param_groups:
                param_group['lr'] = learning_rate

        for img, targets, target_lengths in dataloader:
            img = img.to(device)
            targets = targets.to(device)
            target_lengths = target_lengths.to(device)

            output = model(img)
            seq_len = output.size(1)
            batch_size = img.size(0)
            input_lengths = torch.full((batch_size,), seq_len, dtype=torch.long).to(device)

            # CTC expects (T, B, C) log-probabilities
            log_probs = output.log_softmax(2).transpose(0, 1)
            loss = criterion(log_probs, targets, input_lengths, target_lengths)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            running_loss += loss.item()

        avg_loss = running_loss / len(dataloader)
        scheduler.step(avg_loss)
        print(f"[{epoch + 1}/{epochs}] Loss: {avg_loss:.4f}")

    # Save the model to ./models
    timestamp = time.strftime("%Y%m%d%H%M%S")
    model_name = f"{font_name}_{epochs}ep_lr{learning_rate:.0e}_{timestamp}.pth"
    model_path = os.path.join("./models", model_name)
    save_model(model, model_path)
    ocr_model = model

    return f"✅ Training complete! Model saved as '{model_path}'"


def preprocess_image(image: Image.Image):
    img_cv = np.array(image.convert("L"))
    img_bin = cv2.adaptiveThreshold(img_cv, 255, cv2.ADAPTIVE_THRESH_GAUSSIAN_C,
                                    cv2.THRESH_BINARY_INV, 25, 15)

    # Invert if background is dark
    white_px = (img_bin == 255).sum()
    black_px = (img_bin == 0).sum()
    if black_px > white_px:
        img_bin = 255 - img_bin

    # Resize to IMAGE_HEIGHT, then pad or crop to IMAGE_WIDTH
    h, w = img_bin.shape
    scale = IMAGE_HEIGHT / h
    new_w = int(w * scale)
    resized = cv2.resize(img_bin, (new_w, IMAGE_HEIGHT), interpolation=cv2.INTER_AREA)
    if new_w < IMAGE_WIDTH:
        pad_width = IMAGE_WIDTH - new_w
        padded = np.pad(resized, ((0, 0), (0, pad_width)), constant_values=255)
    else:
        padded = resized[:, :IMAGE_WIDTH]
    return to_pil_image(padded)


# ROYGBIV color ramp (low → high confidence)
CONFIDENCE_COLORS = [
    "#FF0000",  # Red
    "#FF7F00",  # Orange
    "#FFFF00",  # Yellow
    "#00FF00",  # Green
    "#00BFFF",  # Sky Blue
    "#0000FF",  # Blue
    "#8B00FF",  # Violet
]


def confidence_to_color(conf):
    """Map confidence (0.0–1.0) to a ROYGBIV-style hex color."""
    index = min(int(conf * (len(CONFIDENCE_COLORS) - 1)), len(CONFIDENCE_COLORS) - 1)
    return CONFIDENCE_COLORS[index]


def color_char(c, conf):
    """Wrap character `c` in a span tag with color mapped from `conf`."""
    color = confidence_to_color(conf)
    return f'<span style="color:{color}">{c}</span>'
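
# Hedged spot check of the ramp above (added; never called by the app, helper
# name is hypothetical): the 7 buckets map low confidence to red, the middle
# to green, and 1.0 to violet.
def _check_confidence_ramp():
    assert confidence_to_color(0.0) == "#FF0000"  # bucket 0: red
    assert confidence_to_color(0.5) == "#00FF00"  # bucket 3: green
    assert confidence_to_color(1.0) == "#8B00FF"  # bucket 6: violet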

def predict_text(image: Image.Image, ground_truth: str = None, debug: bool = False):
    if ocr_model is None:
        return "Please load or train a model first."

    processed = preprocess_image(image)
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,))
    ])
    img_tensor = transform(processed).unsqueeze(0).to(device)  # (1, C, H, W)

    ocr_model.eval()
    with torch.no_grad():
        output = ocr_model(img_tensor)        # (1, T, C)
        log_probs = output.log_softmax(2)[0]  # (T, C)

    # Decode best beam path (string)
    pred_text_raw = decoder.decode(log_probs.cpu().numpy())
    # Remove <unk> tokens if present (assuming <unk> is in vocab)
    pred_text = pred_text_raw.replace("<unk>", "")

    # Confidence: mean max prob per timestep
    probs = log_probs.exp()
    max_probs = probs.max(dim=1)[0]
    avg_conf = max_probs.mean().item()

    # Color each character (uniform confidence for now)
    colorized_chars = [color_char(c, avg_conf) for c in pred_text]
    pretty_output = ''.join(colorized_chars)

    sim_score = ""
    if ground_truth:
        # SequenceMatcher ratio — an edit-based similarity, not true Levenshtein
        similarity = SequenceMatcher(None, ground_truth, pred_text).ratio()
        sim_score = f"<br>Similarity: {similarity:.2%}"

    if debug:
        print("Decoded Text:", pred_text)
        print("Average Confidence:", avg_conf)
        if ground_truth:
            print("Ground Truth:", ground_truth)

    return f"Prediction: {pretty_output}<br>Confidence: {avg_conf:.2%}{sim_score}"
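
# Hedged usage sketch (added; the checkpoint and image filenames below are
# placeholders, not files this repo ships): the same pipeline can be driven
# outside the UI once a checkpoint exists.
#   load_model("myfont_256ep_lr5e-02_20240101000000.pth")
#   print(predict_text(Image.open("strip.png"), ground_truth="hello", debug=True))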

# New helper function: generate label images grid
# (CHARS is already defined at the top of the file; the duplicate
# redefinition was removed, and the constants below are now actually used)
FONT_SIZE = 32
PADDING = 8
LABEL_DIR = "./labels"


def generate_labels(font_file=None, num_labels: int = 25):
    global font_path
    try:
        if font_file and font_file != "None":
            font_path = os.path.abspath(font_file)
        else:
            font_path = None

        if font_path is None or not os.path.exists(font_path):
            font = ImageFont.load_default()
        else:
            font = ImageFont.truetype(font_path, FONT_SIZE)

        os.makedirs(LABEL_DIR, exist_ok=True)
        labels = ["".join(np.random.choice(list(CHARS), np.random.randint(4, 7)))
                  for _ in range(num_labels)]
        images = []
        for label in labels:
            bbox = font.getbbox(label)
            text_w = bbox[2] - bbox[0]
            text_h = bbox[3] - bbox[1]
            img_w = text_w + PADDING * 2
            img_h = text_h + PADDING * 2

            img = Image.new("L", (img_w, img_h), color=255)
            draw = ImageDraw.Draw(img)
            draw.text((PADDING, PADDING), label, font=font, fill=0)

            safe_label = sanitize_filename(label)
            timestamp = datetime.now().strftime("%Y%m%d%H%M%S%f")
            label_dir = os.path.join(LABEL_DIR, safe_label)
            os.makedirs(label_dir, exist_ok=True)
            filepath = os.path.join(label_dir, f"{timestamp}.png")
            img.save(filepath)
            images.append(img)
        return images
    except Exception as e:
        print("Error in generate_labels:", e)
        error_img = Image.new("RGB", (512, 128), color=(255, 255, 255))
        draw = ImageDraw.Draw(error_img)
        draw.text((10, 50), f"Error: {str(e)}", fill=(255, 0, 0))
        return [error_img]


def list_fonts():
    font_dir = "./fonts"
    if not os.path.exists(font_dir):
        return ["None"]
    fonts = [
        (f, os.path.join(font_dir, f))
        for f in os.listdir(font_dir)
        if f.lower().endswith((".ttf", ".otf"))
    ]
    return [("None", "None")] + fonts
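
# Note (added): Gradio Dropdown choices accept (display_label, value) tuples,
# so list_fonts() shows the bare filename while handing generate_labels the
# full font path; the leading "None" entry keeps the default-font option.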

custom_css = """
#label-gallery .gallery-item img {
    height: 43px;  /* 32pt ≈ 43px */
    width: auto;
    object-fit: contain;
    padding: 4px;
}
#label-gallery {
    flex-grow: 1;
    overflow-y: auto;
    height: 100%;
}
#output-text {
    font-size: 12pt;
}
"""

# --------- Updated Gradio UI with new tab --------- #
with gr.Blocks(css=custom_css) as demo:
    with gr.Tab("【Train OCR Model】"):
        font_file = gr.File(label="Upload .ttf or .otf font", file_types=[".ttf", ".otf"])
        epochs_input = gr.Slider(minimum=1, maximum=4096, value=256, step=1, label="Epochs")
        lr_input = gr.Slider(minimum=0.001, maximum=0.1, value=0.05, step=0.001, label="Learning Rate")
        train_button = gr.Button("Train OCR Model")
        train_status = gr.Textbox(label="Status")
        train_button.click(fn=train_model, inputs=[font_file, epochs_input, lr_input], outputs=train_status)

    with gr.Tab("【Generate Labels】"):
        font_file_labels = gr.Dropdown(
            choices=list_fonts(),
            label="Optional font for label image",
            interactive=True,
        )
        num_labels = gr.Number(value=20, label="Number of labels to generate", precision=0, interactive=True)
        gen_button = gr.Button("Generate Label Grid")
        gen_button.click(
            fn=generate_labels,
            inputs=[font_file_labels, num_labels],
            outputs=gr.Gallery(
                label="Generated Labels",
                columns=16,              # 16 tiles per row
                object_fit="contain",    # Maintain aspect ratio
                height="100%",           # Allow full app height
                elem_id="label-gallery"  # For CSS targeting
            )
        )

    with gr.Tab("【Recognize Text】"):
        model_list = gr.Dropdown(choices=list_saved_models(), label="Select OCR Model")
        refresh_btn = gr.Button("🔄 Refresh Models")
        load_model_btn = gr.Button("Load Model")  # <-- new button
        image_input = gr.Image(type="pil", label="Upload word strip")
        predict_btn = gr.Button("Predict")
        output_text = gr.HTML(label="Recognized Text", elem_id="output-text")
        model_status = gr.Textbox(label="Model Load Status")

        # Refresh dropdown choices
        refresh_btn.click(fn=lambda: gr.update(choices=list_saved_models()), outputs=model_list)

        # Load model on button click, NOT dropdown change
        load_model_btn.click(fn=load_model, inputs=model_list, outputs=model_status)

        predict_btn.click(fn=predict_text, inputs=image_input, outputs=output_text)


if __name__ == "__main__":
    demo.launch()
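
# Assumed run instructions (added; the filename is illustrative): saved as
# e.g. app.py, `python app.py` serves the UI at http://127.0.0.1:7860 by
# default; pass share=True to demo.launch() for a temporary public link.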