File size: 1,500 Bytes
8133f69
 
 
 
 
 
 
 
 
 
 
 
 
14eba99
8133f69
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14eba99
8133f69
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46

import torch
import os
import transformers

from transformers import Idefics2ForConditionalGeneration
from peft import LoraConfig, get_peft_model
from joint_inference import IdeficsJointInferenceModel

def _lora_target_modules(peft_model, wrapper_prefix="base_model.model."):
    """Return the sorted module paths that currently carry LoRA parameters.

    PEFT parameter names look like
    ``base_model.model.<module path>.lora_A.<adapter>.weight``.  This strips
    the wrapper prefix (only when present, instead of blindly dropping a fixed
    17 characters) and everything from ``.lora_`` onward, leaving
    ``<module path>`` in the form accepted by ``LoraConfig(target_modules=...)``.
    The result is sorted so the generated list is stable across runs rather
    than following ``set`` iteration order.
    """
    targets = set()
    for name, _ in peft_model.named_parameters():
        cut = name.find('.lora_')
        if cut == -1:
            # Not a LoRA parameter (a frozen base-model weight).
            continue
        module_path = name[:cut]
        if module_path.startswith(wrapper_prefix):
            module_path = module_path[len(wrapper_prefix):]
        targets.add(module_path)
    return sorted(targets)


def get_model():
    """Build the Idefics2 model with the two CoGen LoRA adapters loaded.

    Steps:
      1. Load ``HuggingFaceM4/idefics2-8b`` in bfloat16.
      2. Attach a LoRA adapter named ``"initial"`` and load its weights from
         revision ``r0_full`` of ``lil-lab/cogen``.
      3. Attach a second adapter ``"final"`` on exactly the modules that
         received LoRA layers in step 2, and load its weights from revision
         ``r3_full``.
      4. Wrap the result in ``IdeficsJointInferenceModel`` and switch it to
         eval mode.

    Returns:
        The ``IdeficsJointInferenceModel`` wrapper, after ``.eval()``.
    """
    repo = 'lil-lab/cogen'
    checkpoint = "HuggingFaceM4/idefics2-8b"
    model = Idefics2ForConditionalGeneration.from_pretrained(
        checkpoint, torch_dtype=torch.bfloat16
    )

    # A string target_modules is used by PEFT as a regex over full module
    # names. Here it selects the attention/MLP projections inside
    # vision_model, modality_projection and perceiver_resampler, plus every
    # k/q/v projection anywhere in the model.
    target_modules = r'(.*(vision_model|modality_projection|perceiver_resampler).*(out_proj|fc1|fc2|down_proj|gate_proj|up_proj|k_proj|q_proj|v_proj|o_proj).*$)|(.*(k_proj|q_proj|v_proj).*$)'
    lora_config = LoraConfig(
        r=16, lora_alpha=8,
        lora_dropout=0.1,
        target_modules=target_modules,
        init_lora_weights="gaussian"
    )
    model = get_peft_model(model, lora_config, adapter_name="initial")
    # Load trained weights for "initial" from the r0_full revision.
    model.load_adapter(repo, "initial", revision="r0_full")

    # Give "final" an explicit list of the modules that now hold LoRA layers
    # (presumably so both adapters are guaranteed to target the same modules).
    lora_config = LoraConfig(
        r=16, lora_alpha=8,
        lora_dropout=0.1,
        target_modules=_lora_target_modules(model),
        init_lora_weights="gaussian"
    )
    model.add_adapter('final', lora_config)
    model.load_adapter(repo, "final", revision="r3_full")

    # NOTE(review): the meaning of the positional arguments (0.5, 0) is
    # defined in joint_inference.IdeficsJointInferenceModel; confirm there.
    model = IdeficsJointInferenceModel(0.5, 0, model=model)
    model.eval()

    return model