Spaces:
Runtime error
Runtime error
import random | |
from typing import Any, Optional | |
import numpy as np | |
import os | |
import cv2 | |
from glob import glob | |
from PIL import Image, ImageDraw | |
from tqdm import tqdm | |
import kornia | |
import matplotlib.pyplot as plt | |
import seaborn as sns | |
import albumentations as albu | |
import functools | |
import math | |
import torch | |
import torch.nn as nn | |
from torch import Tensor | |
import torchvision as tv | |
import torchvision.models as models | |
from torchvision import transforms | |
from torchvision.transforms import functional as F | |
from losses import TempCombLoss | |
# Checkpoints that can be fetched automatically from Google Drive, keyed by the
# exact filename that ``ensure_checkpoint_exists`` is asked for.
google_drive_paths = {
    "BayesCap_SRGAN.pth": "https://drive.google.com/uc?id=1d_5j1f8-vN79htZTfRUqP1ddHZIYsNvL",
    "BayesCap_ckpt.pth": "https://drive.google.com/uc?id=1Vg1r6gKgQ1J3M51n6BeKXYS8auT9NhA9",
}


def ensure_checkpoint_exists(model_weights_filename):
    """Make sure a model checkpoint file is present on local disk.

    If ``model_weights_filename`` already names an existing file, nothing
    happens. Otherwise, when the filename is a key of ``google_drive_paths``,
    the checkpoint is downloaded with ``gdown`` to that path. Every problem is
    reported with ``print`` rather than raised, so callers must still be
    prepared for the file to be missing when they load it.

    Args:
        model_weights_filename: Local path of the checkpoint. It is also used
            verbatim as the lookup key in ``google_drive_paths``, so only the
            bare filenames listed there can be downloaded automatically.

    Returns:
        None.
    """
    if os.path.isfile(model_weights_filename):
        return

    gdrive_url = google_drive_paths.get(model_weights_filename)
    if gdrive_url is None:
        # Unknown checkpoint: nothing to download, tell the user.
        print(
            f"{model_weights_filename} not found, "
            "you may need to manually download the model weights."
        )
        return

    # Only the import is guarded. A ModuleNotFoundError raised later, from
    # inside the download itself, is then not misreported as "gdown missing".
    try:
        from gdown import download as drive_download
    except ModuleNotFoundError:
        print(
            "gdown module not found.",
            "pip3 install gdown or, manually download the checkpoint file:",
            gdrive_url
        )
        return

    drive_download(gdrive_url, model_weights_filename, quiet=False)

    # NOTE(review): some gdown versions report failures (quota, permissions)
    # without raising; verify the file really landed so the user is told.
    if not os.path.isfile(model_weights_filename):
        print(
            f"Failed to download {model_weights_filename} from {gdrive_url}; "
            "you may need to manually download the model weights."
        )