File size: 1,456 Bytes
52f1bcb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
from surya.model.recognition.encoder import DonutSwinModel
from surya.model.table_rec.config import SuryaTableRecConfig, SuryaTableRecDecoderConfig, DonutSwinTableRecConfig, \
    SuryaTableRecTextEncoderConfig
from surya.model.table_rec.decoder import SuryaTableRecDecoder, SuryaTableRecTextEncoder
from surya.model.table_rec.encoderdecoder import TableRecEncoderDecoderModel
from surya.settings import settings


def load_model(checkpoint=settings.TABLE_REC_MODEL_CHECKPOINT, device=settings.TORCH_DEVICE_MODEL, dtype=settings.MODEL_DTYPE):
    """Load the table recognition model from ``checkpoint`` and prepare it for inference.

    The top-level config is read from the checkpoint, its ``decoder``,
    ``encoder`` and ``text_encoder`` entries are converted into their
    dedicated config classes, and the weights are then loaded with that
    config. The resulting model is moved to ``device`` and switched to
    eval mode before being returned.

    Args:
        checkpoint: Checkpoint identifier handed to ``from_pretrained`` for
            both the config and the model weights. Defaults to
            ``settings.TABLE_REC_MODEL_CHECKPOINT``.
        device: Target passed to ``model.to``. Defaults to
            ``settings.TORCH_DEVICE_MODEL``.
        dtype: Value passed as ``torch_dtype`` when loading the weights.
            Defaults to ``settings.MODEL_DTYPE``.

    Returns:
        The loaded ``TableRecEncoderDecoderModel``, on ``device`` and in
        eval mode.

    Raises:
        AssertionError: If the decoder, encoder or text encoder of the
            loaded model is not an instance of the expected class. These
            are ``assert`` statements, so they are skipped when Python
            runs with ``-O``.
    """
    config = SuryaTableRecConfig.from_pretrained(checkpoint)

    # Each sub-config is expanded with ** below, i.e. it arrives from the
    # checkpoint as a mapping and is replaced by a typed config object.
    # The order (decoder, encoder, text_encoder) is kept deliberately.
    sub_config_classes = (
        ("decoder", SuryaTableRecDecoderConfig),
        ("encoder", DonutSwinTableRecConfig),
        ("text_encoder", SuryaTableRecTextEncoderConfig),
    )
    for attr_name, config_cls in sub_config_classes:
        raw_sub_config = getattr(config, attr_name)
        setattr(config, attr_name, config_cls(**raw_sub_config))

    model = TableRecEncoderDecoderModel.from_pretrained(checkpoint, config=config, torch_dtype=dtype)

    # Sanity checks: the composite model must have been built from the
    # expected component classes.
    assert isinstance(model.decoder, SuryaTableRecDecoder)
    assert isinstance(model.encoder, DonutSwinModel)
    assert isinstance(model.text_encoder, SuryaTableRecTextEncoder)

    model = model.to(device)
    model = model.eval()

    # Say "table recognition" so this line cannot be confused with the
    # status message of a plain recognition model loader.
    print(f"Loaded table recognition model {checkpoint} on device {device} with dtype {dtype}")
    return model