File size: 1,832 Bytes
54a7220
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
import logging

from omegaconf import OmegaConf
from lavis.models import registry
from lavis.models import load_preprocess

from ldm.util import instantiate_from_config


def load_blip2_model(cfg, is_eval=False, device="cpu"):
    """Build a LAVIS BLIP-2 model and its preprocessors from ``cfg``.

    The model class is looked up via ``registry.get_model_class(cfg.model_name)``.
    Its default config (for ``cfg.model_type``) is loaded with OmegaConf. The
    ``pretrained`` and ``image_size`` entries are overridden from ``cfg``, and
    the model is instantiated from the resulting ``model`` section.

    Args:
        cfg: Config exposing ``model_name``, ``model_type``, ``pretrained`` and
            ``params.img_size``.
        is_eval: If True, call ``model.eval()`` before returning.
        device: Target device, as a string or ``torch.device``. When it refers
            to the CPU, the model is cast with ``model.float()``.

    Returns:
        ``(model, vis_processors, txt_processors)``, with the model moved to
        ``device``. Both processors are ``None`` when the default config has
        no ``preprocess`` section.
    """
    # torch is not imported at module level; it is only needed here to
    # normalize ``device`` (strings and torch.device objects alike).
    import torch

    model_cls = registry.get_model_class(cfg.model_name)

    default_cfg = OmegaConf.load(model_cls.default_config_path(cfg.model_type))
    default_cfg.model.pretrained = cfg.pretrained
    # Always use the caller's input resolution instead of the model default.
    default_cfg.model.image_size = cfg.params.img_size

    model = model_cls.from_config(default_cfg.model)
    model.cfg = default_cfg.model

    if is_eval:
        model.eval()

    # ``default_cfg`` itself can never be None here (it was dereferenced above);
    # what can be absent is its ``preprocess`` section.
    preprocess_cfg = default_cfg.get("preprocess")
    if preprocess_cfg is not None:
        vis_processors, txt_processors = load_preprocess(preprocess_cfg)
    else:
        vis_processors, txt_processors = None, None
        logging.info(
            "No default preprocess for model %s (%s). This can happen if the "
            "model is not finetuned on downstream datasets, or it is not "
            "intended for direct use without finetuning.",
            cfg.model_name,
            cfg.model_type,
        )

    # Cast to float32 on CPU. Presumably this undoes reduced-precision weights
    # that CPU kernels handle poorly (confirm against the checkpoint dtype).
    if torch.device(device).type == "cpu":
        model = model.float()

    return model.to(device), vis_processors, txt_processors


def load_qformer_model(cfg):
    """Instantiate a Q-Former from ``cfg`` and warm-start it from BLIP-2 weights.

    A BLIP-2 model (plus its preprocessors) is built with ``load_blip2_model``,
    and the Q-Former is built with ``instantiate_from_config``. When
    ``cfg.params.model_name`` is ``'bert-base-uncased'`` (the default), the
    BLIP-2 state dict is loaded into the Q-Former non-strictly.

    Args:
        cfg: Config accepted by both ``load_blip2_model`` and
            ``instantiate_from_config``.

    Returns:
        ``(q_former, (vis_processor, txt_processor))``.
    """
    blip2_model, vis_processor, txt_processor = load_blip2_model(cfg)
    q_former = instantiate_from_config(cfg)

    if blip2_model.query_tokens.shape != q_former.query_tokens.shape:
        # load_state_dict raises on size mismatches even with strict=False.
        # Swapping in the Q-Former's own query tokens makes that entry a no-op
        # copy, so the Q-Former keeps its freshly initialized tokens.
        blip2_model.query_tokens = q_former.query_tokens

    model_name = cfg.params.get("model_name", "bert-base-uncased")
    if model_name == "bert-base-uncased":
        result = q_former.load_state_dict(blip2_model.state_dict(), strict=False)
        # strict=False skips keys the two models don't share. Report them
        # instead of discarding the result, so a bad weight mapping is visible.
        # Assumes q_former is a torch nn.Module (its load_state_dict returns
        # missing_keys / unexpected_keys) -- TODO confirm.
        logging.info(
            "Loaded BLIP-2 weights into Q-Former: %d missing, %d unexpected keys",
            len(result.missing_keys),
            len(result.unexpected_keys),
        )
        logging.debug("Q-Former missing keys: %s", result.missing_keys)
        logging.debug("Q-Former unexpected keys: %s", result.unexpected_keys)
    return q_former, (vis_processor, txt_processor)