import warnings

import torch

# Installed before the surya/transformers imports below so that the matching
# deprecation warning is already filtered when those modules are imported.
warnings.filterwarnings("ignore", message="torch.utils._pytree._register_pytree_node is deprecated")

import logging

# Only ERROR and above is emitted by the "transformers.modeling_utils" logger.
logging.getLogger("transformers.modeling_utils").setLevel(logging.ERROR)

from typing import List, Optional, Tuple

from surya.model.recognition.encoderdecoder import OCREncoderDecoderModel
from surya.model.recognition.config import DonutSwinConfig, SuryaOCRConfig, SuryaOCRDecoderConfig, SuryaOCRTextEncoderConfig
from surya.model.recognition.encoder import DonutSwinModel
from surya.model.recognition.decoder import SuryaOCRDecoder, SuryaOCRTextEncoder
from surya.settings import settings

# Import-time side effect: when efficient attention is switched off in the
# settings, turn off the memory-efficient scaled-dot-product kernel and leave
# the flash and math kernels enabled.
if not settings.ENABLE_EFFICIENT_ATTENTION:
    print("Efficient attention is disabled. This will use significantly more VRAM.")
    torch.backends.cuda.enable_mem_efficient_sdp(False)
    torch.backends.cuda.enable_flash_sdp(True)
    torch.backends.cuda.enable_math_sdp(True)


def load_model(checkpoint=settings.RECOGNITION_MODEL_CHECKPOINT, device=settings.TORCH_DEVICE_MODEL, dtype=settings.MODEL_DTYPE):
    """Load the recognition model from ``checkpoint`` and prepare it for inference.

    The top-level ``SuryaOCRConfig`` is read from the checkpoint, and its
    ``decoder``, ``encoder`` and ``text_encoder`` entries are each replaced by
    an instance of the matching config class. The model is then loaded with
    that config, moved to ``device`` and switched to eval mode.

    Args:
        checkpoint: Checkpoint identifier passed to ``from_pretrained``.
            Defaults to ``settings.RECOGNITION_MODEL_CHECKPOINT``.
        device: Target passed to ``model.to``. Defaults to
            ``settings.TORCH_DEVICE_MODEL``.
        dtype: Value passed as ``torch_dtype`` when loading the weights.
            Defaults to ``settings.MODEL_DTYPE``.

    Returns:
        The loaded ``OCREncoderDecoderModel`` on ``device`` in eval mode.

    Raises:
        AssertionError: If a loaded sub-module is not an instance of the
            expected class.
    """
    config = SuryaOCRConfig.from_pretrained(checkpoint)

    # Each sub-config is unpacked with ** into its config class, so it is
    # expected to be a mapping of keyword arguments at this point.
    sub_config_classes = (
        ("decoder", SuryaOCRDecoderConfig),
        ("encoder", DonutSwinConfig),
        ("text_encoder", SuryaOCRTextEncoderConfig),
    )
    for attr_name, config_cls in sub_config_classes:
        raw_kwargs = getattr(config, attr_name)
        setattr(config, attr_name, config_cls(**raw_kwargs))

    model = OCREncoderDecoderModel.from_pretrained(checkpoint, config=config, torch_dtype=dtype)

    # Sanity-check that the checkpoint produced the expected sub-module types.
    expected_modules = (
        ("decoder", SuryaOCRDecoder),
        ("encoder", DonutSwinModel),
        ("text_encoder", SuryaOCRTextEncoder),
    )
    for attr_name, module_cls in expected_modules:
        assert isinstance(getattr(model, attr_name), module_cls)

    model = model.to(device).eval()

    print(f"Loaded recognition model {checkpoint} on device {device} with dtype {dtype}")
    return model