# BayesCap / utils.py
# udion's picture
# fixed utils for cuda->cpu
# aea9234
# raw
# history blame
# 1.66 kB
import random
from typing import Any, Optional
import numpy as np
import os
import cv2
from glob import glob
from PIL import Image, ImageDraw
from tqdm import tqdm
import kornia
import matplotlib.pyplot as plt
import seaborn as sns
import albumentations as albu
import functools
import math
import torch
import torch.nn as nn
from torch import Tensor
import torchvision as tv
import torchvision.models as models
from torchvision import transforms
from torchvision.transforms import functional as F
from losses import TempCombLoss
######## for loading checkpoints from Google Drive
# Maps a checkpoint filename to the Google Drive URL it can be fetched from.
# `ensure_checkpoint_exists` looks filenames up in this table by exact match.
google_drive_paths = {
    # Keys are matched verbatim against the filename passed by the caller.
    "BayesCap_SRGAN.pth": "https://drive.google.com/uc?id=1d_5j1f8-vN79htZTfRUqP1ddHZIYsNvL",
    "BayesCap_ckpt.pth": "https://drive.google.com/uc?id=1Vg1r6gKgQ1J3M51n6BeKXYS8auT9NhA9",
}
def ensure_checkpoint_exists(model_weights_filename):
    """Make a best-effort attempt to have a checkpoint file on disk.

    Nothing happens when ``model_weights_filename`` already exists as a file.
    Otherwise:

    * If the name is a key of ``google_drive_paths``, the file is fetched
      from the associated URL with ``gdown``. When ``gdown`` is not
      installed, installation / manual-download instructions are printed
      instead.
    * If the name is not a key of ``google_drive_paths``, a message asking
      for a manual download is printed.

    Args:
        model_weights_filename: Path of the checkpoint. It is used both as
            the lookup key into ``google_drive_paths`` (exact match) and as
            the destination path handed to ``gdown``.

    Returns:
        None. The function never raises for a missing file; it only prints.
    """
    # Fast path: the checkpoint is already present.
    if os.path.isfile(model_weights_filename):
        return

    if model_weights_filename in google_drive_paths:
        checkpoint_url = google_drive_paths[model_weights_filename]
        try:
            # Imported lazily so gdown stays an optional dependency.
            from gdown import download as drive_download

            drive_download(checkpoint_url, model_weights_filename, quiet=False)
        except ModuleNotFoundError:
            print(
                "gdown module not found.",
                "pip3 install gdown or, manually download the checkpoint file:",
                checkpoint_url,
            )
        # Known checkpoints get no further message, whether or not the
        # download produced a file.
        return

    # Unknown checkpoint name: there is no URL to fetch it from.
    print(
        model_weights_filename,
        " not found, you may need to manually download the model weights.",
    )