# model1/llava/model/utils.py
import torch
import torch.nn as nn
from transformers import AutoConfig


def auto_upgrade(config):
    """Upgrade an old LLaVA v0 checkpoint's config in place so the newer code base can load it."""
    cfg = AutoConfig.from_pretrained(config)
    if 'llava' in config and 'llava' not in cfg.model_type:
        assert cfg.model_type == 'llama'
        print("You are using newer LLaVA code base, while the checkpoint of v0 is from older code base.")
        print("You must upgrade the checkpoint to the new code base (this can be done automatically).")
        confirm = input("Please confirm that you want to upgrade the checkpoint. [Y/N]")
        if confirm.lower() in ["y", "yes"]:
            print("Upgrading checkpoint...")
            assert len(cfg.architectures) == 1
            setattr(cfg.__class__, "model_type", "llava")
            cfg.architectures[0] = 'LlavaLlamaForCausalLM'
            cfg.save_pretrained(config)
            print("Checkpoint upgraded.")
        else:
            print("Checkpoint upgrade aborted.")
            exit(1)
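

# Usage sketch (illustration only, not part of the original file): `auto_upgrade`
# expects the path of a local checkpoint directory whose name contains "llava",
# and, after interactive confirmation, rewrites its config.json in place.
# The path below is hypothetical.
#
#     auto_upgrade("./checkpoints/llava-7b-v0")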


class LayerNorm(nn.LayerNorm):
    """Subclass torch's LayerNorm to handle fp16."""

    def forward(self, x: torch.Tensor):
        orig_type = x.dtype
        with torch.cuda.amp.autocast(dtype=torch.float32):
            # Compute the normalization in float32, then cast back to the input dtype.
            ret = super().forward(x.type(torch.float32))
        return ret.type(orig_type)
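

# Minimal usage sketch (assumption: a CUDA device is available; not part of the
# original file). The module's parameters stay in float32 while fp16 activations
# pass through it: the normalization is computed in float32 for numerical
# stability, and the output is cast back to the input's half precision.
#
#     ln = LayerNorm(4096).cuda()                                   # fp32 parameters
#     x = torch.randn(2, 16, 4096, device="cuda", dtype=torch.float16)
#     y = ln(x)                                                     # computed in fp32 internally
#     assert y.dtype == torch.float16                               # output matches the input dtype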