|
""" |
|
Model Factory for loading a checkpoint from gcloud, then initializing the model and the configuration |
|
""" |
|
import os |
|
import requests |
|
import torch |
|
import tqdm |
|
|
|
from cwm.model import model_pretrain |
|
|
|
GCLOUD_BUCKET_NAME = "stanford_neuroai_models"

# Public HTTPS prefix of the bucket, i.e.
# "https://storage.googleapis.com/stanford_neuroai_models".
GCLOUD_URL_NAME = f"https://storage.googleapis.com/{GCLOUD_BUCKET_NAME}"

# Local checkpoint cache directory: "$CACHE/stanford_neuroai_models" whenever the
# CACHE environment variable is set (even to an empty string), otherwise a
# directory relative to the current working directory.
CACHE_PATH = (
    ".cache/stanford_neuroai_models"
    if os.getenv("CACHE") is None
    else f"{os.getenv('CACHE')}/stanford_neuroai_models"
)

# Registry of downloadable models: model name -> checkpoint path inside the bucket
# ("path") and the zero-argument constructor that builds the model before its
# weights are loaded ("init_fn").
_model_catalogue = {
    name: {"path": path, "init_fn": init_fn}
    for name, path, init_fn in (
        ("vitb_8x8patch_3frames", "cwm/3frame_cwm_8x8.pth", model_pretrain.vitb_8x8patch_3frames),
        ("vitb_4x4patch_2frames", "cwm/2frame_cwm_4x4.pth", model_pretrain.vitb_4x4patch_2frames),
        ("vitb_8x8patch_2frames", "cwm/2frame_cwm_8x8.pth", model_pretrain.vitb_8x8patch_2frames),
    )
}
|
|
|
|
|
class ModelFactory:
    """
    Download pretrained checkpoints from a public GCS bucket and build models from them.

    Checkpoints are cached under ``CACHE_PATH`` and reused on later calls unless a
    fresh download is forced.
    """

    # Public HTTPS endpoint serving objects of publicly readable GCS buckets.
    _GCS_PUBLIC_URL = "https://storage.googleapis.com"

    # Chunk size in bytes used when streaming a checkpoint to disk.
    _DOWNLOAD_CHUNK_SIZE = 1 << 20

    def __init__(self, bucket_name: str = GCLOUD_BUCKET_NAME):
        """
        Args:
            bucket_name: str
                Name of the GCS bucket that checkpoints are downloaded from
        """
        self.bucket_name = bucket_name

    def get_catalog(self):
        """
        Get the names of the available models

        Returns:
            A view of the model-catalogue keys (names accepted by ``load_model``)
        """
        return _model_catalogue.keys()

    def load_model(self, model_name: str, force_download=False):
        """
        Load the model given the name

        Args:
            model_name: str
                Name of the model to load (one of ``get_catalog()``)
            force_download: bool (optional)
                Whether to force the download of the freshest weights from gcloud,
                even if a cached checkpoint already exists

        Returns:
            model: torch.nn.Module
                Model initialized from the checkpoint

        Raises:
            KeyError: if ``model_name`` is not in the catalogue
            requests.HTTPError: if the checkpoint download fails
        """
        try:
            entry = _model_catalogue[model_name]
        except KeyError:
            raise KeyError(
                f"Unknown model {model_name!r}; available models: {sorted(_model_catalogue)}"
            ) from None

        checkpoint_path = os.path.join(CACHE_PATH, entry["path"])
        os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)

        # Only touch the network when the checkpoint is missing or a refresh is
        # requested, so cached models can be loaded offline.
        if force_download or not os.path.exists(checkpoint_path):
            self._download(self._remote_url(entry["path"]), checkpoint_path)

        print("checkpoint_path", checkpoint_path)
        # NOTE(review): torch.load unpickles the file; only load checkpoints from a
        # bucket you trust.
        ckpt = torch.load(checkpoint_path, map_location="cpu")

        model = entry["init_fn"]()
        model.load_state_dict(ckpt['model'], strict=True)
        print('Model loaded successfully')

        return model

    def _remote_url(self, relative_path: str) -> str:
        """Build the public download URL of *relative_path* inside this factory's bucket."""
        # URLs always use '/', so os.path.join (which uses '\\' on Windows) is avoided.
        return f"{self._GCS_PUBLIC_URL}/{self.bucket_name}/{relative_path.lstrip('/')}"

    def _download(self, url: str, destination: str) -> None:
        """Stream *url* to *destination*, creating/replacing it only after a complete download."""
        partial_path = destination + ".part"
        print(f"Saving model to cache: {CACHE_PATH}")
        try:
            with requests.get(url, stream=True) as response:
                # Fail loudly instead of caching an HTTP error page as a "checkpoint".
                response.raise_for_status()
                total_size_in_bytes = int(response.headers.get('content-length', 0))
                # total=None lets tqdm show an indeterminate bar when the size is unknown.
                with tqdm.tqdm(total=total_size_in_bytes or None, unit='iB', unit_scale=True) as progress_bar, \
                        open(partial_path, 'wb') as file:
                    for data in response.iter_content(self._DOWNLOAD_CHUNK_SIZE):
                        progress_bar.update(len(data))
                        file.write(data)
            # Atomic rename: an interrupted download never leaves a truncated file
            # at the cache path, where it would later be mistaken for a valid checkpoint.
            os.replace(partial_path, destination)
        except BaseException:
            # Clean up the incomplete file (also on Ctrl-C), then propagate the error.
            if os.path.exists(partial_path):
                os.remove(partial_path)
            raise
|
|
|
# Shared module-level factory bound to the default checkpoint bucket.
model_factory = ModelFactory(bucket_name=GCLOUD_BUCKET_NAME)