Spaces:
Running
Running
from surya.model.recognition.encoder import DonutSwinModel | |
from surya.model.table_rec.config import SuryaTableRecConfig, SuryaTableRecDecoderConfig, DonutSwinTableRecConfig, \ | |
SuryaTableRecTextEncoderConfig | |
from surya.model.table_rec.decoder import SuryaTableRecDecoder, SuryaTableRecTextEncoder | |
from surya.model.table_rec.encoderdecoder import TableRecEncoderDecoderModel | |
from surya.settings import settings | |
def load_model(checkpoint=settings.TABLE_REC_MODEL_CHECKPOINT, device=settings.TORCH_DEVICE_MODEL, dtype=settings.MODEL_DTYPE):
    """Load the table recognition model from ``checkpoint`` for inference.

    The top-level config is read from the checkpoint, its ``decoder``,
    ``encoder`` and ``text_encoder`` entries are rebuilt as their dedicated
    config classes, and the encoder-decoder model is then instantiated with
    that config. The model is moved to ``device`` and put in eval mode
    before being returned.

    Args:
        checkpoint: Checkpoint identifier passed to ``from_pretrained`` for
            both the config and the model.
        device: Target passed to ``model.to``.
        dtype: Forwarded to ``from_pretrained`` as ``torch_dtype``.

    Returns:
        The loaded ``TableRecEncoderDecoderModel`` on ``device`` in eval mode.

    Raises:
        AssertionError: If a sub-module of the loaded model is not of the
            expected class. These checks are plain ``assert`` statements, so
            they are skipped when Python runs with ``-O``.
    """
    config = SuryaTableRecConfig.from_pretrained(checkpoint)

    # Each sub-config is unpacked with ``**`` into its typed config class, so
    # it is presumably stored as a plain mapping on the loaded config
    # (verify against SuryaTableRecConfig). Order: decoder, encoder, text encoder.
    config.decoder = SuryaTableRecDecoderConfig(**config.decoder)
    config.encoder = DonutSwinTableRecConfig(**config.encoder)
    config.text_encoder = SuryaTableRecTextEncoderConfig(**config.text_encoder)

    model = TableRecEncoderDecoderModel.from_pretrained(
        checkpoint,
        config=config,
        torch_dtype=dtype,
    )

    # Sanity-check that every sub-module was built as the expected class.
    expected_types = (
        ("decoder", SuryaTableRecDecoder),
        ("encoder", DonutSwinModel),
        ("text_encoder", SuryaTableRecTextEncoder),
    )
    for attr_name, expected_cls in expected_types:
        assert isinstance(getattr(model, attr_name), expected_cls)

    model = model.to(device).eval()
    print(f"Loaded recognition model {checkpoint} on device {device} with dtype {dtype}")
    return model