import json
import os
from typing import Tuple

import requests
import torch
from PIL import Image
from torchvision import transforms as T
from tqdm import tqdm

from IndicPhotoOCR.utils.strhub.models.utils import load_from_checkpoint

# Checkpoint locations for each supported language; files are fetched on
# demand from the STocr release assets.
model_info = {
    "assamese": {
        "path": "models/assamese.ckpt",
        "url": "https://github.com/anikde/STocr/releases/download/V2.0.0/assamese.ckpt",
    },
    "bengali": {
        "path": "models/bengali.ckpt",
        "url": "https://github.com/anikde/STocr/releases/download/V2.0.0/bengali.ckpt",
    },
    "hindi": {
        "path": "models/hindi.ckpt",
        "url": "https://github.com/anikde/STocr/releases/download/V2.0.0/hindi.ckpt",
    },
    "gujarati": {
        "path": "models/gujarati.ckpt",
        "url": "https://github.com/anikde/STocr/releases/download/V2.0.0/gujarati.ckpt",
    },
    "marathi": {
        "path": "models/marathi.ckpt",
        "url": "https://github.com/anikde/STocr/releases/download/V2.0.0/marathi.ckpt",
    },
    "odia": {
        "path": "models/odia.ckpt",
        "url": "https://github.com/anikde/STocr/releases/download/V2.0.0/odia.ckpt",
    },
    "punjabi": {
        "path": "models/punjabi.ckpt",
        "url": "https://github.com/anikde/STocr/releases/download/V2.0.0/punjabi.ckpt",
    },
    "tamil": {
        "path": "models/tamil.ckpt",
        "url": "https://github.com/anikde/STocr/releases/download/V2.0.0/tamil.ckpt",
    },
    "telugu": {
        "path": "models/telugu.ckpt",
        "url": "https://github.com/anikde/STocr/releases/download/V2.0.0/telugu.ckpt",
    },
}


class PARseqrecogniser:
    def __init__(self):
        pass

    def get_transform(self, img_size: Tuple[int], augment: bool = False, rotation: int = 0):
        """Build the preprocessing pipeline expected by the PARseq models."""
        transforms = []
        if augment:
            from .augment import rand_augment_transform
            transforms.append(rand_augment_transform())
        if rotation:
            transforms.append(lambda img: img.rotate(rotation, expand=True))
        transforms.extend([
            T.Resize(img_size, T.InterpolationMode.BICUBIC),
            T.ToTensor(),
            T.Normalize(0.5, 0.5),
        ])
        return T.Compose(transforms)

    def load_model(self, device, checkpoint):
        model = load_from_checkpoint(checkpoint).eval().to(device)
        return model

    def get_model_output(self, device, model, image_path):
        """Run a single image through the model and decode the prediction."""
        hp = model.hparams
        transform = self.get_transform(hp.img_size, rotation=0)
        img = Image.open(image_path).convert("RGB")
        img = transform(img)
        logits = model(img.unsqueeze(0).to(device))
        probs = logits.softmax(-1)
        preds, probs = model.tokenizer.decode(probs)
        text = model.charset_adapter(preds[0])
        return text

    # Ensure the checkpoint for `model_name` exists locally; download it if not.
    def ensure_model(self, model_name):
        model_path = model_info[model_name]["path"]
        url = model_info[model_name]["url"]
        root_model_dir = "IndicPhotoOCR/recognition/"
        model_path = os.path.join(root_model_dir, model_path)

        if not os.path.exists(model_path):
            print(f"Model not found locally. Downloading {model_name} from {url}...")

            # Stream the download with a progress bar.
            response = requests.get(url, stream=True)
            total_size = int(response.headers.get("content-length", 0))
            os.makedirs(f"{root_model_dir}/models", exist_ok=True)
            with open(model_path, "wb") as f, tqdm(
                desc=model_name,
                total=total_size,
                unit="B",
                unit_scale=True,
                unit_divisor=1024,
            ) as bar:
                for data in response.iter_content(chunk_size=1024):
                    f.write(data)
                    bar.update(len(data))

            print(f"Downloaded model for {model_name}.")

        return model_path

    def bstr(self, checkpoint, language, image_dir, save_dir):
        """
        Runs the OCR model over a directory of images and saves the output as a JSON file.

        Args:
            checkpoint (str): Path to the model checkpoint file.
            language (str): Language code (e.g., 'hindi', 'english').
            image_dir (str): Directory containing the images to process.
            save_dir (str): Directory where the output JSON file will be saved.

        Example usage:
            python your_script.py --checkpoint /path/to/checkpoint.ckpt --language hindi --image_dir /path/to/images --save_dir /path/to/save
        """
        device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")
        if language != "english":
            model = self.load_model(device, checkpoint)
        else:
            model = torch.hub.load('baudm/parseq', 'parseq', pretrained=True).eval().to(device)

        parseq_dict = {}
        for image_name in tqdm(os.listdir(image_dir)):
            image_path = os.path.join(image_dir, image_name)
            assert os.path.exists(image_path), f"{image_path}"
            text = self.get_model_output(device, model, image_path)
            parseq_dict[image_name] = text

        os.makedirs(save_dir, exist_ok=True)
        with open(f"{save_dir}/{language}_test.json", "w") as json_file:
            json.dump(parseq_dict, json_file, indent=4, ensure_ascii=False)

    def bstr_onImage(self, checkpoint, language, image_path):
        """
        Runs the OCR model on a single image and returns the recognized text.

        Args:
            checkpoint (str): Path to the model checkpoint file.
            language (str): Language code (e.g., 'hindi', 'english').
            image_path (str): Path to the image to process.

        Returns:
            str: The recognized text from the image.
        """
        device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")
        if language != "english":
            model = self.load_model(device, checkpoint)
        else:
            model = torch.hub.load('baudm/parseq', 'parseq', pretrained=True).eval().to(device)

        text = self.get_model_output(device, model, image_path)
        return text

    def recognise(self, checkpoint: str, image_path: str, language: str, verbose: bool, device: str) -> str:
        """
        Loads the desired model and returns the recognized word from the specified image.

        Args:
            checkpoint (str): Model name key in `model_info` (e.g., 'hindi'); ignored for English.
            image_path (str): Path to the image for which text recognition is needed.
            language (str): Language code (e.g., 'hindi', 'english').
            verbose (bool): Verbosity flag passed to torch.hub when loading the English model.
            device (str): Device to run inference on (e.g., 'cpu', 'cuda:0').

        Returns:
            str: The recognized text from the image.
""" # device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") if language != "english": model_path = self.ensure_model(checkpoint) model = self.load_model(device, model_path) else: model = torch.hub.load('baudm/parseq', 'parseq', pretrained=True, verbose=verbose).eval().to(device) recognized_text = self.get_model_output(device, model, image_path) return recognized_text # if __name__ == '__main__': # fire.Fire(main)