import hashlib

import torch
import torch.nn as nn
from transformers import (
    PreTrainedModel,
    PretrainedConfig,
    AutoModelForCausalLM,
    AutoTokenizer
)


def _get_obfuscated_key(key: str) -> str:
    """Map a model id to a short, deterministic, ModuleDict-safe key."""
    return hashlib.sha256(key.encode()).hexdigest()[:16]


class CohesiveMoEConfig(PretrainedConfig):
    """Configuration: expert model ids plus the router MLP's hidden size."""

    model_type = "cohesive-moe"

    def __init__(self, **kwargs):
        vision_ids = kwargs.pop("vision_ids", [])
        text_id = kwargs.pop("text_id", "")
        gate_hidden = kwargs.pop("gate_hidden", 256)
        super().__init__(**kwargs)
        self.vision_ids = vision_ids
        self.text_id = text_id
        self.gate_hidden = gate_hidden


class CohesiveMoE(PreTrainedModel):
    config_class = CohesiveMoEConfig

    def __init__(self, config: CohesiveMoEConfig):
        super().__init__(config)
        # Load every expert (vision experts first, text expert last) under a
        # hashed key so arbitrary Hub ids are valid ModuleDict names.
        keys = config.vision_ids + [config.text_id]
        self._experts = nn.ModuleDict({
            _get_obfuscated_key(m): AutoModelForCausalLM.from_pretrained(m).eval()
            for m in keys
        })
        self.tokenizer = AutoTokenizer.from_pretrained(config.text_id)
        # Reuse the text expert's input embeddings to featurize prompts for routing.
        text_key = _get_obfuscated_key(config.text_id)
        self.embed = self._experts[text_key].get_input_embeddings()
        d = self.embed.embedding_dim
        num_experts = len(self._experts)
        # Small MLP router; note it is randomly initialized and never trained here.
        self.gate = nn.Sequential(
            nn.Linear(d, config.gate_hidden),
            nn.ReLU(),
            nn.Linear(config.gate_hidden, num_experts)
        )

    @torch.no_grad()
    def _route_idx(self, input_ids: torch.LongTensor) -> int:
        # Mean-pool the prompt's token embeddings and pick the highest-scoring expert.
        pooled = self.embed(input_ids).mean(dim=1)
        logits = self.gate(pooled)
        return int(logits.argmax(dim=-1).item())

    def generate(
        self,
        text: str,
        pixel_values: torch.Tensor | None = None,
        **gen_kwargs
    ) -> list[str]:
        batch = self.tokenizer(text, return_tensors="pt")
        idx = self._route_idx(batch.input_ids)
        key = list(self._experts.keys())[idx]
        expert = self._experts[key]
        # The text expert was appended last, so every lower index is a vision
        # expert and should also receive the image tensor.
        if idx < len(self._experts) - 1:
            out = expert.generate(**batch, pixel_values=pixel_values, **gen_kwargs)
        else:
            out = expert.generate(**batch, **gen_kwargs)
        return self.tokenizer.batch_decode(out, skip_special_tokens=True)


def bundle_and_save(
    vision_ids: list[str],
    text_id: str,
    save_dir: str = "output",
    gate_hidden: int = 256
):
    cfg = CohesiveMoEConfig(
        vision_ids=vision_ids,
        text_id=text_id,
        gate_hidden=gate_hidden
    )
    model = CohesiveMoE(cfg)
    # Smoke-test routing and generation before writing anything to disk.
    model.generate("Hello!", max_new_tokens=1)
    # model.embed currently aliases the text expert's embedding table; swap in a
    # detached copy so the safetensors writer does not see two module paths
    # pointing at the same tensor storage.
    text_key = _get_obfuscated_key(text_id)
    orig_weight = model._experts[text_key].get_input_embeddings().weight.data
    model.embed = nn.Embedding.from_pretrained(orig_weight.clone())
    model.save_pretrained(save_dir, safe_serialization=True)
    model.tokenizer.save_pretrained(save_dir)
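

# --- Usage sketch -------------------------------------------------------------
# Illustrative only: the Hub ids below are placeholders, not tested checkpoints;
# substitute your own vision experts and text expert. Every expert checkpoint is
# loaded into memory at once, so plan RAM/VRAM accordingly.
if __name__ == "__main__":
    bundle_and_save(
        vision_ids=["org/vision-expert-a", "org/vision-expert-b"],  # placeholders
        text_id="org/text-expert",                                  # placeholder
        save_dir="cohesive-moe-bundle",
    )
    # The save directory then holds config.json, safetensors weights for the
    # gate and all bundled experts, and the text expert's tokenizer files.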