# hetero_moe_bundle.py

import hashlib

import torch
import torch.nn as nn
from transformers import (
    PreTrainedModel,
    PretrainedConfig,
    AutoModelForCausalLM,
    AutoTokenizer,
)


def _get_obfuscated_key(key: str) -> str:
    """Hash a model id into a short key that is safe to use as an
    nn.ModuleDict name (hub ids contain characters like '.' that
    module names must not)."""
    return hashlib.sha256(key.encode()).hexdigest()[:16]


class CohesiveMoEConfig(PretrainedConfig):
    model_type = "cohesive-moe"

    def __init__(self, **kwargs):
        vision_ids = kwargs.pop("vision_ids", [])
        text_id = kwargs.pop("text_id", "")
        gate_hidden = kwargs.pop("gate_hidden", 256)
        super().__init__(**kwargs)
        self.vision_ids = vision_ids
        self.text_id = text_id
        self.gate_hidden = gate_hidden


class CohesiveMoE(PreTrainedModel):
    config_class = CohesiveMoEConfig

    def __init__(self, config: CohesiveMoEConfig):
        super().__init__(config)
        # Vision experts first, text expert last; the routing logic in
        # generate() relies on this insertion order.
        keys = config.vision_ids + [config.text_id]
        self._experts = nn.ModuleDict({
            _get_obfuscated_key(m): AutoModelForCausalLM.from_pretrained(m).eval()
            for m in keys
        })
        self.tokenizer = AutoTokenizer.from_pretrained(config.text_id)
        # The router reuses the text expert's input embeddings for pooling.
        text_key = _get_obfuscated_key(config.text_id)
        self.embed = self._experts[text_key].get_input_embeddings()
        d = self.embed.embedding_dim
        num_experts = len(self._experts)
        # Small MLP gate: mean-pooled prompt embedding -> one logit per
        # expert. Note the gate is randomly initialized and never trained
        # in this script, so routing is arbitrary until it is fitted.
        self.gate = nn.Sequential(
            nn.Linear(d, config.gate_hidden),
            nn.ReLU(),
            nn.Linear(config.gate_hidden, num_experts),
        )

    @torch.no_grad()
    def _route_idx(self, input_ids: torch.LongTensor) -> int:
        # Greedy top-1 routing: mean-pool the prompt's token embeddings
        # and pick the expert with the highest gate logit.
        pooled = self.embed(input_ids).mean(dim=1)
        logits = self.gate(pooled)
        return int(logits.argmax(dim=-1).item())

    def generate(
        self,
        text: str,
        pixel_values: torch.Tensor | None = None,
        **gen_kwargs
    ) -> list[str]:
        batch = self.tokenizer(text, return_tensors="pt")
        idx = self._route_idx(batch.input_ids)
        key = list(self._experts.keys())[idx]
        expert = self._experts[key]
        # The text expert was inserted last, so any lower index is a vision
        # expert; forward pixel_values only to those, and only when the
        # caller actually supplied them.
        if idx < len(self._experts) - 1 and pixel_values is not None:
            out = expert.generate(**batch, pixel_values=pixel_values, **gen_kwargs)
        else:
            out = expert.generate(**batch, **gen_kwargs)
        return self.tokenizer.batch_decode(out, skip_special_tokens=True)
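
# Example call pattern (model ids are placeholders, not tested checkpoints;
# `px` stands for a preprocessed pixel tensor from the vision expert's
# image processor). Routing is effectively random until the gate is trained:
#
#   moe = CohesiveMoE(CohesiveMoEConfig(
#       vision_ids=["some-org/vlm-expert"], text_id="some-org/text-expert"))
#   moe.generate("Describe this image.", pixel_values=px, max_new_tokens=32)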


def bundle_and_save(
    vision_ids: list[str],
    text_id: str,
    save_dir: str = "output",
    gate_hidden: int = 256
):
    cfg = CohesiveMoEConfig(
        vision_ids=vision_ids,
        text_id=text_id,
        gate_hidden=gate_hidden
    )
    model = CohesiveMoE(cfg)
    # Smoke test: run one routed generation step before saving.
    model.generate("Hello!", max_new_tokens=1)
    # Replace the router's embedding, which currently aliases the text
    # expert's embedding module, with a detached copy: safetensors refuses
    # to serialize tensors that share memory.
    text_key = _get_obfuscated_key(text_id)
    orig_weight = model._experts[text_key].get_input_embeddings().weight.data
    model.embed = nn.Embedding.from_pretrained(orig_weight.clone())
    model.save_pretrained(save_dir, safe_serialization=True)
    model.tokenizer.save_pretrained(save_dir)
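

# Minimal sketch of how this script might be invoked; the model ids below
# are hypothetical placeholders. Reloading goes through the standard
# from_pretrained path, which re-downloads the experts in __init__ and then
# restores the bundled weights on top of them.
if __name__ == "__main__":
    bundle_and_save(
        vision_ids=["some-org/vision-expert"],  # hypothetical hub id
        text_id="some-org/text-expert",         # hypothetical hub id
        save_dir="output",
    )
    # reloaded = CohesiveMoE.from_pretrained("output")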