import gradio as gr
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image, ImageFont, ImageDraw
import numpy as np
import os
import string
import cv2
from torchvision.transforms.functional import to_pil_image
import matplotlib.pyplot as plt
from datetime import datetime
import re
from pyctcdecode import BeamSearchDecoderCTC, Alphabet
from difflib import SequenceMatcher
# --------- Globals --------- #
CHARS = string.ascii_letters + string.digits + string.punctuation
CHAR2IDX = {c: i + 1 for i, c in enumerate(CHARS)}  # start from 1; 0 is reserved for the CTC blank
CHAR2IDX["<BLANK>"] = 0  # CTC blank
IDX2CHAR = {v: k for k, v in CHAR2IDX.items()}
BLANK_IDX = 0

IMAGE_HEIGHT = 32
IMAGE_WIDTH = 128

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
font_path = None
ocr_model = None

# Build the vocabulary list for the decoder (order must match the model's output indices!).
# pyctcdecode represents the CTC blank as the empty string, so map index 0 to ""
# rather than the literal "<BLANK>" token.
labels = [IDX2CHAR.get(i, "") for i in range(len(IDX2CHAR))]
labels[BLANK_IDX] = ""
alphabet = Alphabet.build_alphabet(labels)
decoder = BeamSearchDecoderCTC(alphabet)
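# A note on decoding (hedged reading of pyctcdecode's convention): greedy decoding
# (see greedy_decode below) takes the per-step argmax and collapses it, while the
# beam search decoder keeps several candidate prefixes per step, so it can usually
# recover from a single noisy timestep that greedy decoding would commit to.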
# Ensure required directories exist at startup
os.makedirs("./fonts", exist_ok=True)
os.makedirs("./models", exist_ok=True)
os.makedirs("./labels", exist_ok=True)
# --------- Dataset --------- #
class OCRDataset(Dataset):
    def __init__(self, font_path, size=1000, label_length_range=(4, 7)):
        self.font = ImageFont.truetype(font_path, 32)
        self.label_length_range = label_length_range
        self.samples = [
            "".join(np.random.choice(list(CHARS), np.random.randint(*self.label_length_range)))
            for _ in range(size)
        ]
        self.transform = transforms.Compose([
            transforms.ToTensor(),  # must run first: the ops below expect a tensor
            transforms.Normalize((0.5,), (0.5,)),
            transforms.Resize((IMAGE_HEIGHT, IMAGE_WIDTH)),
            transforms.RandomApply([transforms.GaussianBlur(kernel_size=3)], p=0.3),
            transforms.RandomApply([transforms.RandomAffine(degrees=10, translate=(0.1, 0.1))], p=0.3),
        ])

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        label = self.samples[idx]
        # Render the label onto a padded white canvas
        pad = 8
        w = self.font.getlength(label)
        h = self.font.size
        img_w, img_h = int(w + 2 * pad), int(h + 2 * pad)
        img = Image.new("L", (img_w, img_h), 255)
        draw = ImageDraw.Draw(img)
        draw.text((pad, pad), label, font=self.font, fill=0)
        img = self.transform(img)
        label_encoded = torch.tensor([CHAR2IDX[c] for c in label], dtype=torch.long)
        label_length = torch.tensor(len(label_encoded), dtype=torch.long)
        return img, label_encoded, label_length

    def render_text(self, text):
        img = Image.new("L", (IMAGE_WIDTH, IMAGE_HEIGHT), color=255)
        draw = ImageDraw.Draw(img)
        bbox = self.font.getbbox(text)
        w, h = bbox[2] - bbox[0], bbox[3] - bbox[1]
        draw.text(((IMAGE_WIDTH - w) // 2, (IMAGE_HEIGHT - h) // 2), text, font=self.font, fill=0)
        return img
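# Quick usage sketch (not executed; "SomeFont.ttf" is a placeholder path):
#   ds = OCRDataset("./fonts/SomeFont.ttf", size=4)
#   img, encoded, length = ds[0]
#   img.shape   -> torch.Size([1, 32, 128]), values normalized to roughly [-1, 1]
#   encoded     -> 1-D tensor of CHAR2IDX indices, 4-6 long for the default range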
# --------- Model --------- #
class OCRModel(nn.Module):
    def __init__(self, num_classes):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(1, 32, 3, padding=1), nn.ReLU(),
            nn.MaxPool2d((2, 2), (2, 1)),  # halve height; width shrinks by only 1 (stride 1)
            nn.Conv2d(32, 64, 3, padding=1), nn.ReLU(),
            nn.MaxPool2d((2, 2), (2, 1)),  # halve height again: 32 -> 16 -> 8
        )
        self.rnn = nn.LSTM(64 * 8, 128, bidirectional=True, num_layers=2, batch_first=True)
        self.fc = nn.Linear(256, num_classes)
        with torch.no_grad():
            self.fc.bias[0] = -5.0  # discourage the blank class early in training

    def forward(self, x):
        b, c, h, w = x.size()
        x = self.conv(x)                 # (B, 64, 8, W')
        x = x.permute(0, 3, 1, 2)        # (B, W', 64, 8): width becomes the time axis
        x = x.reshape(b, x.size(1), -1)  # (B, W', 512)
        x, _ = self.rnn(x)               # (B, W', 256)
        x = self.fc(x)                   # (B, W', num_classes)
        return x
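# Shape walk-through for one (1, 1, 32, 128) input, assuming the pooling above:
#   conv    -> (1, 64, 8, 126)   # (2,2)/(2,1) pools: height 32->16->8, width 128->127->126
#   reshape -> (1, 126, 512)     # width serves as the CTC time axis
#   rnn     -> (1, 126, 256)     # bidirectional: 2 * 128 hidden units
#   fc      -> (1, 126, num_classes)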
def sanitize_filename(name):
    return re.sub(r'[^a-zA-Z0-9_-]', '_', name)
def greedy_decode(log_probs):
    # log_probs shape: (T, B, C); B is usually 1 at inference time
    pred = log_probs.argmax(2).squeeze(1).tolist()  # one class index per timestep
    print(f"Decoded indices: {pred}")  # debug print
    decoded = []
    prev = BLANK_IDX
    for p in pred:
        if p != prev and p != BLANK_IDX:
            decoded.append(IDX2CHAR.get(p, ""))
        prev = p
    return ''.join(decoded)
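# CTC collapse example: with BLANK_IDX = 0 and CHAR2IDX mapping 'a'->1, 'b'->2,
# the per-step argmax sequence [1, 1, 0, 1, 2] decodes to "aab": the repeated
# leading 1s merge into one 'a', the blank resets `prev`, so the third 1 counts again.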
# --------- Custom Collate --------- #
def custom_collate_fn(batch):
    images, labels, _ = zip(*batch)
    images = torch.stack(images, 0)
    flat_labels = []
    label_lengths = []
    for label in labels:
        flat_labels.append(label)
        label_lengths.append(len(label))
    targets = torch.cat(flat_labels)
    return images, targets, torch.tensor(label_lengths, dtype=torch.long)
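# Layout expected by nn.CTCLoss: targets for labels "ab" and "xyz" are concatenated
# into one flat tensor of 5 indices, with label_lengths = [2, 3] marking the split.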
# --------- Model Save/Load --------- #
def list_saved_models():
    model_dir = "./models"
    if not os.path.exists(model_dir):
        return []
    return [f for f in os.listdir(model_dir) if f.endswith(".pth")]

def save_model(model, path):
    torch.save(model.state_dict(), path)

def load_model(filename):
    global ocr_model
    model_dir = "./models"
    path = os.path.join(model_dir, filename)
    if not os.path.exists(path):
        return f"Model file '{path}' does not exist."
    model = OCRModel(num_classes=len(CHAR2IDX))
    model.load_state_dict(torch.load(path, map_location=device))
    model.to(device)
    model.eval()
    ocr_model = model
    return f"Model '{path}' loaded."
# --------- Gradio Functions --------- #
def train_model(font_file, epochs=100, learning_rate=0.001):
    import time
    global font_path, ocr_model

    # Ensure directories exist
    os.makedirs("./fonts", exist_ok=True)
    os.makedirs("./models", exist_ok=True)

    # Save the uploaded font to ./fonts, preserving its original extension
    font_name = os.path.splitext(os.path.basename(font_file.name))[0]
    font_ext = os.path.splitext(font_file.name)[1] or ".ttf"
    font_path = f"./fonts/{font_name}{font_ext}"
    with open(font_file.name, "rb") as uploaded:
        with open(font_path, "wb") as f:
            f.write(uploaded.read())

    # Curriculum learning: label length grows over time
    def get_dataset_for_epoch(epoch):
        if epoch < epochs // 3:
            label_len = (3, 4)
        elif epoch < 2 * epochs // 3:
            label_len = (4, 6)
        else:
            label_len = (5, 7)
        return OCRDataset(font_path, label_length_range=label_len)

    # Visualize one sample
    dataset = get_dataset_for_epoch(0)
    img, label, _ = dataset[0]
    print("Label:", ''.join([IDX2CHAR[i.item()] for i in label]))
    plt.imshow(img.permute(1, 2, 0).squeeze(), cmap='gray')
    plt.show()

    # Model setup
    model = OCRModel(num_classes=len(CHAR2IDX)).to(device)
    criterion = nn.CTCLoss(blank=BLANK_IDX)
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=5)

    for epoch in range(epochs):
        dataset = get_dataset_for_epoch(epoch)
        dataloader = DataLoader(dataset, batch_size=16, shuffle=True, collate_fn=custom_collate_fn)
        model.train()
        running_loss = 0.0

        # Warmup: hold a reduced LR for the first 5 epochs, then restore the base LR
        # once. (Resetting it on every later epoch would silently undo any reductions
        # made by the plateau scheduler.)
        if epoch < 5:
            for param_group in optimizer.param_groups:
                param_group['lr'] = learning_rate * 0.2
        elif epoch == 5:
            for param_group in optimizer.param_groups:
                param_group['lr'] = learning_rate

        for img, targets, target_lengths in dataloader:
            img = img.to(device)
            targets = targets.to(device)
            target_lengths = target_lengths.to(device)
            output = model(img)
            seq_len = output.size(1)
            batch_size = img.size(0)
            input_lengths = torch.full((batch_size,), seq_len, dtype=torch.long).to(device)
            log_probs = output.log_softmax(2).transpose(0, 1)  # CTCLoss wants (T, B, C)
            loss = criterion(log_probs, targets, input_lengths, target_lengths)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            running_loss += loss.item()

        avg_loss = running_loss / len(dataloader)
        scheduler.step(avg_loss)
        print(f"[{epoch + 1}/{epochs}] Loss: {avg_loss:.4f}")

    # Save the model to ./models
    timestamp = time.strftime("%Y%m%d%H%M%S")
    model_name = f"{font_name}_{epochs}ep_lr{learning_rate:.0e}_{timestamp}.pth"
    model_path = os.path.join("./models", model_name)
    save_model(model, model_path)
    ocr_model = model
    return f"✅ Training complete! Model saved as '{model_path}'"
def preprocess_image(image: Image.Image):
    img_cv = np.array(image.convert("L"))
    img_bin = cv2.adaptiveThreshold(img_cv, 255, cv2.ADAPTIVE_THRESH_GAUSSIAN_C,
                                    cv2.THRESH_BINARY_INV, 25, 15)
    # After INV thresholding the text is white on black; flip when black dominates
    # so the model always sees dark text on a white background, as in training
    white_px = (img_bin == 255).sum()
    black_px = (img_bin == 0).sum()
    if black_px > white_px:
        img_bin = 255 - img_bin
    # Scale to the training height, then pad or crop the width to (IMAGE_HEIGHT, IMAGE_WIDTH)
    h, w = img_bin.shape
    scale = IMAGE_HEIGHT / h
    new_w = int(w * scale)
    resized = cv2.resize(img_bin, (new_w, IMAGE_HEIGHT), interpolation=cv2.INTER_AREA)
    if new_w < IMAGE_WIDTH:
        pad_width = IMAGE_WIDTH - new_w
        padded = np.pad(resized, ((0, 0), (0, pad_width)), constant_values=255)
    else:
        padded = resized[:, :IMAGE_WIDTH]
    return to_pil_image(padded)
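# Note: training resizes the rendered strip to exactly 32x128 (stretching it), while
# this path preserves aspect ratio and pads/crops the width. A hedged observation:
# making the two geometries identical would likely help if predictions look skewed.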
# ROYGBIV color ramp (low → high confidence)
CONFIDENCE_COLORS = [
    "#FF0000",  # Red
    "#FF7F00",  # Orange
    "#FFFF00",  # Yellow
    "#00FF00",  # Green
    "#00BFFF",  # Sky Blue
    "#0000FF",  # Blue
    "#8B00FF",  # Violet
]

def confidence_to_color(conf):
    """Map confidence (0.0–1.0) to a ROYGBIV-style hex color."""
    index = min(int(conf * (len(CONFIDENCE_COLORS) - 1)), len(CONFIDENCE_COLORS) - 1)
    return CONFIDENCE_COLORS[index]

def color_char(c, conf):
    """Wrap character `c` in a span tag with color mapped from `conf`."""
    color = confidence_to_color(conf)
    return f'<span style="color:{color}; font-size:12pt; font-weight:bold;">{c}</span>'
def predict_text(image: Image.Image, ground_truth: str = None, debug: bool = False):
    if ocr_model is None:
        return "Please load or train a model first."
    processed = preprocess_image(image)
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,))
    ])
    img_tensor = transform(processed).unsqueeze(0).to(device)  # (1, C, H, W)
    ocr_model.eval()
    with torch.no_grad():
        output = ocr_model(img_tensor)        # (1, T, C)
        log_probs = output.log_softmax(2)[0]  # (T, C)
    # Beam-search decode; strip any literal blank tokens the decoder might emit
    pred_text = decoder.decode(log_probs.cpu().numpy()).replace("<BLANK>", "")
    # Confidence: mean of the max probability per timestep
    probs = log_probs.exp()
    max_probs = probs.max(dim=1)[0]
    avg_conf = max_probs.mean().item()
    # Color each character (uniform confidence for now)
    colorized_chars = [color_char(c, avg_conf) for c in pred_text]
    pretty_output = ''.join(colorized_chars)
    sim_score = ""
    if ground_truth:
        similarity = SequenceMatcher(None, ground_truth, pred_text).ratio()
        sim_score = f"<br><strong>Similarity (SequenceMatcher):</strong> {similarity:.2%}"
    if debug:
        print("Decoded Text:", pred_text)
        print("Average Confidence:", avg_conf)
        if ground_truth:
            print("Ground Truth:", ground_truth)
    return (f"<strong>Prediction:</strong> <strong>{pretty_output}</strong>"
            f"<br><strong>Confidence:</strong> {avg_conf:.2%}{sim_score}")
# --------- Label Generation (grid of label images) --------- #
FONT_SIZE = 32
PADDING = 8
LABEL_DIR = "./labels"
def generate_labels(font_file=None, num_labels: int = 25):
    global font_path
    try:
        if font_file and font_file != "None":
            font_path = os.path.abspath(font_file)
        else:
            font_path = None
        if font_path is None or not os.path.exists(font_path):
            font = ImageFont.load_default()
        else:
            font = ImageFont.truetype(font_path, FONT_SIZE)
        os.makedirs(LABEL_DIR, exist_ok=True)
        labels = ["".join(np.random.choice(list(CHARS), np.random.randint(4, 7))) for _ in range(num_labels)]
        images = []
        for label in labels:
            bbox = font.getbbox(label)
            text_w = bbox[2] - bbox[0]
            text_h = bbox[3] - bbox[1]
            img_w = text_w + PADDING * 2
            img_h = text_h + PADDING * 2
            img = Image.new("L", (img_w, img_h), color=255)
            draw = ImageDraw.Draw(img)
            draw.text((PADDING, PADDING), label, font=font, fill=0)
            # Save each sample under ./labels/<sanitized label>/<timestamp>.png
            safe_label = sanitize_filename(label)
            timestamp = datetime.now().strftime("%Y%m%d%H%M%S%f")
            label_dir = os.path.join(LABEL_DIR, safe_label)
            os.makedirs(label_dir, exist_ok=True)
            filepath = os.path.join(label_dir, f"{timestamp}.png")
            img.save(filepath)
            images.append(img)
        return images
    except Exception as e:
        # Surface errors in the gallery instead of failing silently
        print("Error in generate_labels:", e)
        error_img = Image.new("RGB", (512, 128), color=(255, 255, 255))
        draw = ImageDraw.Draw(error_img)
        draw.text((10, 50), f"Error: {e}", fill=(255, 0, 0))
        return [error_img]
def list_fonts():
    font_dir = "./fonts"
    if not os.path.exists(font_dir):
        return ["None"]
    fonts = [
        (f, os.path.join(font_dir, f)) for f in os.listdir(font_dir)
        if f.lower().endswith((".ttf", ".otf"))
    ]
    return [("None", "None")] + fonts
custom_css = """ | |
#label-gallery .gallery-item img { | |
height: 43px; /* 32pt ≈ 43px */ | |
width: auto; | |
object-fit: contain; | |
padding: 4px; | |
} | |
#label-gallery { | |
flex-grow: 1; | |
overflow-y: auto; | |
height: 100%; | |
} | |
#output-text { | |
font-size: 12pt; | |
} | |
""" | |
# --------- Updated Gradio UI with new tab --------- #
with gr.Blocks(css=custom_css) as demo:
    with gr.Tab("【Train OCR Model】"):
        font_file = gr.File(label="Upload .ttf or .otf font", file_types=[".ttf", ".otf"])
        epochs_input = gr.Slider(minimum=1, maximum=4096, value=256, step=1, label="Epochs")
        lr_input = gr.Slider(minimum=0.001, maximum=0.1, value=0.05, step=0.001, label="Learning Rate")
        train_button = gr.Button("Train OCR Model")
        train_status = gr.Textbox(label="Status")
        train_button.click(fn=train_model, inputs=[font_file, epochs_input, lr_input], outputs=train_status)

    with gr.Tab("【Generate Labels】"):
        font_file_labels = gr.Dropdown(
            choices=list_fonts(),
            label="Optional font for label image",
            interactive=True,
        )
        num_labels = gr.Number(value=20, label="Number of labels to generate", precision=0, interactive=True)
        gen_button = gr.Button("Generate Label Grid")
        gen_button.click(
            fn=generate_labels,
            inputs=[font_file_labels, num_labels],
            outputs=gr.Gallery(
                label="Generated Labels",
                columns=16,               # 16 tiles per row
                object_fit="contain",     # maintain aspect ratio
                height="100%",            # allow full app height
                elem_id="label-gallery",  # for CSS targeting
            ),
        )

    with gr.Tab("【Recognize Text】"):
        model_list = gr.Dropdown(choices=list_saved_models(), label="Select OCR Model")
        refresh_btn = gr.Button("🔄 Refresh Models")
        load_model_btn = gr.Button("Load Model")
        image_input = gr.Image(type="pil", label="Upload word strip")
        predict_btn = gr.Button("Predict")
        output_text = gr.HTML(label="Recognized Text", elem_id="output-text")
        model_status = gr.Textbox(label="Model Load Status")

        # Refresh dropdown choices
        refresh_btn.click(fn=lambda: gr.update(choices=list_saved_models()), outputs=model_list)
        # Load the model on button click, not on dropdown change
        load_model_btn.click(fn=load_model, inputs=model_list, outputs=model_status)
        predict_btn.click(fn=predict_text, inputs=image_input, outputs=output_text)

if __name__ == "__main__":
    demo.launch()