import hashlib

import torch
import torch.nn as nn
from transformers import (
    PreTrainedModel,
    PretrainedConfig,
    AutoModelForCausalLM,
    AutoTokenizer
)

def _get_obfuscated_key(key: str) -> str:
    """Map a model ID to a short, stable key for the expert ModuleDict."""
    return hashlib.sha256(key.encode()).hexdigest()[:16]

class CohesiveMoEConfig(PretrainedConfig):
    model_type = "cohesive-moe"

    def __init__(self, vision_ids=None, text_id="", gate_hidden=256, **kwargs):
        super().__init__(**kwargs)
        self.vision_ids  = vision_ids if vision_ids is not None else []
        self.text_id     = text_id
        self.gate_hidden = gate_hidden

class CohesiveMoE(PreTrainedModel):
    config_class = CohesiveMoEConfig

    def __init__(self, config: CohesiveMoEConfig):
        super().__init__(config)
        # Load every expert. The text expert is inserted last on purpose, so
        # its index is always len(self._experts) - 1 (assumes all IDs are distinct).
        keys = config.vision_ids + [config.text_id]
        self._experts = nn.ModuleDict({
            _get_obfuscated_key(m): AutoModelForCausalLM.from_pretrained(m).eval()
            for m in keys
        })
        self.tokenizer = AutoTokenizer.from_pretrained(config.text_id)
        # The router reuses the text expert's input embeddings, so `embed`
        # shares its weight tensor with that expert until bundle_and_save()
        # swaps in a detached copy before serialisation.
        text_key = _get_obfuscated_key(config.text_id)
        self.embed   = self._experts[text_key].get_input_embeddings()
        d           = self.embed.embedding_dim
        num_experts = len(self._experts)
        # Gating network: mean-pooled prompt embedding -> expert logits.
        # It is randomly initialised and never trained in this file.
        self.gate = nn.Sequential(
            nn.Linear(d, config.gate_hidden),
            nn.ReLU(),
            nn.Linear(config.gate_hidden, num_experts)
        )

    @torch.no_grad()
    def _route_idx(self, input_ids: torch.LongTensor) -> int:
        # Mean-pool the prompt's token embeddings and pick the highest-scoring
        # expert. Expert order matches the insertion order of self._experts.
        pooled = self.embed(input_ids).mean(dim=1)
        logits = self.gate(pooled)
        return int(logits.argmax(dim=-1).item())

    def generate(
        self,
        text: str,
        pixel_values: torch.Tensor | None = None,
        **gen_kwargs
    ) -> list[str]:
        batch  = self.tokenizer(text, return_tensors="pt")
        idx    = self._route_idx(batch.input_ids)
        key    = list(self._experts.keys())[idx]
        expert = self._experts[key]
        # Vision experts occupy every index except the last (the text expert).
        # Only forward pixel_values when an image was actually supplied, so a
        # text-only prompt routed to a vision expert does not break generation.
        is_vision = idx < len(self._experts) - 1
        if is_vision and pixel_values is not None:
            out = expert.generate(**batch, pixel_values=pixel_values, **gen_kwargs)
        else:
            out = expert.generate(**batch, **gen_kwargs)
        return self.tokenizer.batch_decode(out, skip_special_tokens=True)

def bundle_and_save(
    vision_ids: list[str],
    text_id: str,
    save_dir: str = "output",
    gate_hidden: int = 256
):
    cfg = CohesiveMoEConfig(
        vision_ids=vision_ids,
        text_id=text_id,
        gate_hidden=gate_hidden
    )
    model = CohesiveMoE(cfg)
    # Smoke test: run one routed generation step before serialising anything.
    model.generate("Hello!", max_new_tokens=1)
    # Replace the router embedding (currently the same tensor object as the
    # text expert's input embedding) with a standalone copy, so the checkpoint
    # contains no cross-module shared storage, which safe_serialization=True
    # cannot represent.
    text_key    = _get_obfuscated_key(text_id)
    orig_weight = model._experts[text_key].get_input_embeddings().weight.data
    model.embed = nn.Embedding.from_pretrained(orig_weight.clone())
    model.save_pretrained(save_dir, safe_serialization=True)
    model.tokenizer.save_pretrained(save_dir)
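
# Usage sketch (illustrative only; the repository IDs below are hypothetical
# placeholders, not models referenced anywhere in this file):
#
#     bundle_and_save(
#         vision_ids=["example-org/vision-expert"],
#         text_id="example-org/text-expert",
#         save_dir="cohesive-moe-bundle",
#     )
#
# Reloading later requires this module to be imported first so that
# CohesiveMoE / CohesiveMoEConfig are defined; note that __init__ re-downloads
# the expert weights before the saved state dict is loaded over them:
#
#     model = CohesiveMoE.from_pretrained("cohesive-moe-bundle")
#     print(model.generate("Describe the weather.", max_new_tokens=32))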