import random
from typing import Any, Optional
import numpy as np
import os
import cv2
from glob import glob
from PIL import Image, ImageDraw
from tqdm import tqdm
import kornia
import matplotlib.pyplot as plt
import seaborn as sns
import albumentations as albu
import functools
import math

import torch
import torch.nn as nn
from torch import Tensor
import torchvision as tv
import torchvision.models as models
from torchvision import transforms
from torchvision.transforms import functional as F
from losses import TempCombLoss


######## for loading checkpoint from googledrive
# Maps a checkpoint filename (looked up verbatim, not by basename) to a
# Google Drive direct-download URL understood by `gdown`.
google_drive_paths = {
    "BayesCap_SRGAN.pth": "https://drive.google.com/uc?id=1d_5j1f8-vN79htZTfRUqP1ddHZIYsNvL",
    "BayesCap_ckpt.pth": "https://drive.google.com/uc?id=1Vg1r6gKgQ1J3M51n6BeKXYS8auT9NhA9",
}


def ensure_checkpoint_exists(model_weights_filename):
    """Make sure a checkpoint file is present locally, downloading it if known.

    If ``model_weights_filename`` already exists on disk, nothing happens.
    Otherwise, when it is a key of ``google_drive_paths``, the file is fetched
    from Google Drive with ``gdown`` and saved under that same name. If
    ``gdown`` is not installed, or the filename is not a known checkpoint, a
    message asking the user to download the weights manually is printed.

    Args:
        model_weights_filename: Path of the checkpoint file. It is matched
            verbatim against the keys of ``google_drive_paths``.

    Returns:
        True if the file exists when the function returns, False otherwise.
        (Older callers that ignore the return value are unaffected.)

    Raises:
        Whatever ``gdown.download`` raises on a failed download. Only a
        missing ``gdown`` package is handled here.
    """
    if os.path.isfile(model_weights_filename):
        return True

    gdrive_url = google_drive_paths.get(model_weights_filename)
    if gdrive_url is None:
        print(
            model_weights_filename,
            " not found, you may need to manually download the model weights."
        )
        return False

    # Guard only the import: a ModuleNotFoundError raised *during* the download
    # must not be misreported as "gdown is not installed".
    try:
        from gdown import download as drive_download
    except ModuleNotFoundError:
        print(
            "gdown module not found.",
            "pip3 install gdown or, manually download the checkpoint file:",
            gdrive_url
        )
        return False

    drive_download(gdrive_url, model_weights_filename, quiet=False)

    # Depending on the gdown version, a failed download (e.g. access denied or
    # quota exceeded) may return without raising, so verify the file landed.
    if not os.path.isfile(model_weights_filename):
        print(
            model_weights_filename,
            "could not be downloaded; manually download the checkpoint file:",
            gdrive_url
        )
        return False
    return True