import torch
import torch.nn as nn

from transformers import AutoConfig


def auto_upgrade(config):
    """Upgrade an old v0 LLaVA checkpoint's config in place so it loads with the current code base."""
    # `config` is the path to a local checkpoint directory; its config is read and,
    # if the user confirms, rewritten in place.
    cfg = AutoConfig.from_pretrained(config)
    if 'llava' in config and 'llava' not in cfg.model_type:
        assert cfg.model_type == 'llama'
        print("You are using a newer LLaVA code base, while the v0 checkpoint is from the older code base.")
        print("You must upgrade the checkpoint to the new code base (this can be done automatically).")
        confirm = input("Please confirm that you want to upgrade the checkpoint. [Y/N]")
        if confirm.lower() in ["y", "yes"]:
            print("Upgrading checkpoint...")
            assert len(cfg.architectures) == 1
            # `model_type` is serialized from the config class, so it is patched on the
            # class rather than the instance.
            setattr(cfg.__class__, "model_type", "llava")
            cfg.architectures[0] = 'LlavaLlamaForCausalLM'
            cfg.save_pretrained(config)
            print("Checkpoint upgraded.")
        else:
            print("Checkpoint upgrade aborted.")
            exit(1)
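
# Example usage of auto_upgrade() (a sketch, not part of the original module): it is
# meant to be called with the path of a local v0 checkpoint directory before loading
# the model from it; the path below is hypothetical.
#
#   auto_upgrade("./checkpoints/llava-7b-v0")
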
class LayerNorm(nn.LayerNorm):
    """Subclass torch's LayerNorm to handle fp16."""

    def forward(self, x: torch.Tensor):
        # Run the normalization in float32 for numerical stability, then cast the
        # result back to the input dtype so fp16 pipelines keep their dtype end to end.
        orig_type = x.dtype
        ret = super().forward(x.type(torch.float32))
        return ret.type(orig_type)
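

# A small self-check for the fp16-safe LayerNorm (a usage sketch, not part of the
# original module; the tensor shape is arbitrary).
if __name__ == "__main__":
    ln = LayerNorm(16)
    x = torch.randn(2, 16, dtype=torch.float16)
    y = ln(x)
    # The normalization itself runs in float32, but the output keeps the input dtype.
    print(y.dtype)  # torch.float16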