import torch
import os
from typing import Tuple, Optional, Any, Union
import json

from .utils import tokenize, transform
from .prepare import prepare
from .text_encoder import CLIPTextEncoder
from .image_encoder import ModifiedResNet, VisionTransformer
from .model import CLIP

curr_dir = os.path.dirname(os.path.abspath(__file__))

clip_model_names = [
    "clip_vit_b_16",
    "clip_vit_l_14",
]
clip_image_encoder_names = [f"clip_image_encoder_{name[5:]}" for name in clip_model_names]
clip_text_encoder_names = [f"clip_text_encoder_{name[5:]}" for name in clip_model_names]

# Download the weights and configs on first import if any of them are missing.
for name in clip_model_names + clip_image_encoder_names + clip_text_encoder_names:
    model_weights_path = os.path.join(curr_dir, "weights", f"{name}.pth")
    model_config_path = os.path.join(curr_dir, "configs", f"{name}.json")
    if not os.path.exists(model_weights_path) or not os.path.exists(model_config_path):
        prepare()
        break

for name in clip_model_names + clip_image_encoder_names + clip_text_encoder_names:
    assert os.path.exists(os.path.join(curr_dir, "weights", f"{name}.pth")), f"Missing {name}.pth in weights folder. Please run models/clip/prepare.py to download the weights."
    assert os.path.exists(os.path.join(curr_dir, "configs", f"{name}.json")), f"Missing {name}.json in configs folder. Please run models/clip/prepare.py to download the configs."


def _clip(name: str, input_size: Optional[Union[int, Tuple[int, int]]] = None) -> CLIP:
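    """Build a full CLIP model (image + text branches) from the bundled files.

    Args:
        name: Model variant, e.g. "vit_b_16"; resolved to
            configs/clip_{name}.json and weights/clip_{name}.pth.
        input_size: Optional input resolution, either an int or a
            (height, width) tuple. For ViT variants the positional embedding
            is adjusted to this resolution.

    Returns:
        A CLIP model with the pretrained weights loaded (strict matching).
    """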
    with open(os.path.join(curr_dir, "configs", f"clip_{name}.json"), "r") as f:
        config = json.load(f)
    model = CLIP(
        embed_dim=config["embed_dim"],
        # vision
        image_resolution=config["image_resolution"],
        vision_layers=config["vision_layers"],
        vision_width=config["vision_width"],
        vision_patch_size=config["vision_patch_size"],
        # text
        context_length=config["context_length"],
        vocab_size=config["vocab_size"],
        transformer_width=config["transformer_width"],
        transformer_heads=config["transformer_heads"],
        transformer_layers=config["transformer_layers"]
    )
    state_dict = torch.load(os.path.join(curr_dir, "weights", f"clip_{name}.pth"), map_location="cpu")
    model.load_state_dict(state_dict, strict=True)
    if input_size is not None:
        input_size = (input_size, input_size) if isinstance(input_size, int) else input_size
        if name.startswith("vit"):
            model.visual.adjust_pos_embed(*input_size)
    return model


def _resnet(
    name: str,
    reduction: int = 32,
    features_only: bool = False,
    out_indices: Optional[Tuple[int, ...]] = None,
    **kwargs: Any
) -> ModifiedResNet:
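    """Build the ModifiedResNet image encoder of a CLIP model.

    Args:
        name: ResNet variant, e.g. "resnet50" or "resnet50x4".
        reduction: Downsampling factor passed to ModifiedResNet (default 32).
        features_only: If True, the encoder returns intermediate feature maps
            instead of the final image embedding.
        out_indices: Stages to return when features_only is True.

    Returns:
        A ModifiedResNet with pretrained CLIP weights loaded (non-strict).
    """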
    with open(os.path.join(curr_dir, "configs", f"clip_image_encoder_{name}.json"), "r") as f:
        config = json.load(f)
    model = ModifiedResNet(
        layers=config["vision_layers"],
        output_dim=config["embed_dim"],
        input_resolution=config["image_resolution"],
        width=config["vision_width"],
        heads=config["vision_heads"],
        features_only=features_only,
        out_indices=out_indices,
        reduction=reduction
    )
    state_dict = torch.load(os.path.join(curr_dir, "weights", f"clip_image_encoder_{name}.pth"), map_location="cpu")
    missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
    if len(missing_keys) > 0 or len(unexpected_keys) > 0:
        print(f"Missing keys: {missing_keys}")
        print(f"Unexpected keys: {unexpected_keys}")
    else:
        print("All keys matched successfully.")
    return model


def _vit(name: str, features_only: bool = False, input_size: Optional[Union[int, Tuple[int, int]]] = None, **kwargs: Any) -> VisionTransformer:
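    """Build the VisionTransformer image encoder of a CLIP model.

    Args:
        name: ViT variant, e.g. "vit_b_16" or "vit_l_14_336px".
        features_only: If True, the encoder returns intermediate features
            instead of the final image embedding.
        input_size: Optional input resolution, either an int or a
            (height, width) tuple; the positional embedding is adjusted to it.

    Returns:
        A VisionTransformer with pretrained CLIP weights loaded (non-strict).
    """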
    with open(os.path.join(curr_dir, "configs", f"clip_image_encoder_{name}.json"), "r") as f:
        config = json.load(f)
    model = VisionTransformer(
        input_resolution=config["image_resolution"],
        patch_size=config["vision_patch_size"],
        output_dim=config["embed_dim"],
        width=config["vision_width"],
        layers=config["vision_layers"],
        heads=config["vision_heads"],
        features_only=features_only
    )
    state_dict = torch.load(os.path.join(curr_dir, "weights", f"clip_image_encoder_{name}.pth"), map_location="cpu")
    missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
    if len(missing_keys) > 0 or len(unexpected_keys) > 0:
        print(f"Missing keys: {missing_keys}")
        print(f"Unexpected keys: {unexpected_keys}")
    else:
        print("All keys matched successfully.")
    if input_size is not None:
        input_size = (input_size, input_size) if isinstance(input_size, int) else input_size
        model.adjust_pos_embed(*input_size)
    return model


def _text_encoder(name: str) -> CLIPTextEncoder:
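    """Build the text encoder of a CLIP model.

    Args:
        name: Model variant whose text tower to load, e.g. "vit_b_16".

    Returns:
        A CLIPTextEncoder with pretrained CLIP weights loaded (non-strict).
    """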
    with open(os.path.join(curr_dir, "configs", f"clip_text_encoder_{name}.json"), "r") as f:
        config = json.load(f)
    model = CLIPTextEncoder(
        embed_dim=config["embed_dim"],
        context_length=config["context_length"],
        vocab_size=config["vocab_size"],
        transformer_width=config["transformer_width"],
        transformer_heads=config["transformer_heads"],
        transformer_layers=config["transformer_layers"]
    )
    state_dict = torch.load(os.path.join(curr_dir, "weights", f"clip_text_encoder_{name}.pth"), map_location="cpu")
    missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
    if len(missing_keys) > 0 or len(unexpected_keys) > 0:
        print(f"Missing keys: {missing_keys}")
        print(f"Unexpected keys: {unexpected_keys}")
    else:
        print("All keys matched successfully.")
    return model


# CLIP models
def resnet50_clip(input_size: Optional[Union[int, Tuple[int, int]]] = None) -> CLIP:
    return _clip("resnet50", input_size)

def resnet101_clip(input_size: Optional[Union[int, Tuple[int, int]]] = None) -> CLIP:
    return _clip("resnet101", input_size)

def resnet50x4_clip(input_size: Optional[Union[int, Tuple[int, int]]] = None) -> CLIP:
    return _clip("resnet50x4", input_size)

def resnet50x16_clip(input_size: Optional[Union[int, Tuple[int, int]]] = None) -> CLIP:
    return _clip("resnet50x16", input_size)

def resnet50x64_clip(input_size: Optional[Union[int, Tuple[int, int]]] = None) -> CLIP:
    return _clip("resnet50x64", input_size)

def vit_b_32_clip(input_size: Optional[Union[int, Tuple[int, int]]] = None) -> CLIP:
    return _clip("vit_b_32", input_size)

def vit_b_16_clip(input_size: Optional[Union[int, Tuple[int, int]]] = None) -> CLIP:
    return _clip("vit_b_16", input_size)

def vit_l_14_clip(input_size: Optional[Union[int, Tuple[int, int]]] = None) -> CLIP:
    return _clip("vit_l_14", input_size)

def vit_l_14_336px_clip(input_size: Optional[Union[int, Tuple[int, int]]] = None) -> CLIP:
    return _clip("vit_l_14_336px", input_size)
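
# Example (sketch): a full CLIP model can be built at a custom resolution,
# e.g. `model = vit_b_16_clip(input_size=(384, 384))`. Assuming model.py
# follows the reference CLIP interface, images and tokenized text can then be
# encoded with `model.encode_image(...)` and `model.encode_text(...)`.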


# CLIP image encoders
def resnet50_img(features_only: bool = False, out_indices: Optional[Tuple[int, ...]] = None, **kwargs: Any) -> ModifiedResNet:
    return _resnet("resnet50", features_only=features_only, out_indices=out_indices, **kwargs)

def resnet101_img(features_only: bool = False, out_indices: Optional[Tuple[int, ...]] = None, **kwargs: Any) -> ModifiedResNet:
    return _resnet("resnet101", features_only=features_only, out_indices=out_indices, **kwargs)

def resnet50x4_img(features_only: bool = False, out_indices: Optional[Tuple[int, ...]] = None, **kwargs: Any) -> ModifiedResNet:
    return _resnet("resnet50x4", features_only=features_only, out_indices=out_indices, **kwargs)

def resnet50x16_img(features_only: bool = False, out_indices: Optional[Tuple[int, ...]] = None, **kwargs: Any) -> ModifiedResNet:
    return _resnet("resnet50x16", features_only=features_only, out_indices=out_indices, **kwargs)

def resnet50x64_img(features_only: bool = False, out_indices: Optional[Tuple[int, ...]] = None, **kwargs: Any) -> ModifiedResNet:
    return _resnet("resnet50x64", features_only=features_only, out_indices=out_indices, **kwargs)

def vit_b_32_img(features_only: bool = False, input_size: Optional[Union[int, Tuple[int, int]]] = None, **kwargs: Any) -> VisionTransformer:
    return _vit("vit_b_32", features_only=features_only, input_size=input_size, **kwargs)

def vit_b_16_img(features_only: bool = False, input_size: Optional[Union[int, Tuple[int, int]]] = None, **kwargs: Any) -> VisionTransformer:
    return _vit("vit_b_16", features_only=features_only, input_size=input_size, **kwargs)

def vit_l_14_img(features_only: bool = False, input_size: Optional[Union[int, Tuple[int, int]]] = None, **kwargs: Any) -> VisionTransformer:
    return _vit("vit_l_14", features_only=features_only, input_size=input_size, **kwargs)

def vit_l_14_336px_img(features_only: bool = False, input_size: Optional[Union[int, Tuple[int, int]]] = None, **kwargs: Any) -> VisionTransformer:
    return _vit("vit_l_14_336px", features_only=features_only, input_size=input_size, **kwargs)


# CLIP text encoders
def resnet50_txt() -> CLIPTextEncoder:
    return _text_encoder("resnet50")

def resnet101_txt() -> CLIPTextEncoder:
    return _text_encoder("resnet101")

def resnet50x4_txt() -> CLIPTextEncoder:
    return _text_encoder("resnet50x4")

def resnet50x16_txt() -> CLIPTextEncoder:
    return _text_encoder("resnet50x16")

def resnet50x64_txt() -> CLIPTextEncoder:
    return _text_encoder("resnet50x64")

def vit_b_32_txt() -> CLIPTextEncoder:
    return _text_encoder("vit_b_32")

def vit_b_16_txt() -> CLIPTextEncoder:
    return _text_encoder("vit_b_16")

def vit_l_14_txt() -> CLIPTextEncoder:
    return _text_encoder("vit_l_14")

def vit_l_14_336px_txt() -> CLIPTextEncoder:
    return _text_encoder("vit_l_14_336px")


__all__ = [
    # utils
    "tokenize",
    "transform",
    # clip image encoders
    "vit_b_16_img",
    "vit_l_14_img",
    # clip text encoders
    "vit_b_16_txt",
    "vit_l_14_txt",
]
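

# ---------------------------------------------------------------------------
# Usage sketch (runs only when this file is executed directly, never on
# import). The forward signatures below are assumptions: the image encoder is
# assumed to take an (N, 3, H, W) float tensor and the text encoder the
# token-id tensor produced by `tokenize`, mirroring the reference CLIP API.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    with torch.no_grad():
        # Image branch: encode a dummy batch at the default 224x224 resolution.
        image_encoder = vit_b_16_img().eval()
        dummy_images = torch.randn(2, 3, 224, 224)
        image_features = image_encoder(dummy_images)

        # Text branch: encode a couple of prompts.
        text_encoder = vit_b_16_txt().eval()
        tokens = tokenize(["a photo of a cat", "a photo of a dog"])
        text_features = text_encoder(tokens)

        print("image features:", tuple(image_features.shape))
        print("text features:", tuple(text_features.shape))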