import os, cv2, time, math print("=> Loading libraries...") start = time.time() import requests, torch import gradio as gr from torchvision import transforms from datasets import load_dataset from timm.data import create_transform from timm.models import create_model, load_checkpoint from pytorch_grad_cam import GradCAM from pytorch_grad_cam.utils.image import show_cam_on_image print(f"=> Libraries loaded in {time.time()- start:.2f} sec(s).") print("=> Loading model...") start = time.time() size = "b" img_size = 224 crop_pct = 0.9 IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406) IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225) model = create_model(f"tpmlp_{size}").cuda() load_checkpoint(model, f"../tpmlp_{size}.pth.tar", True) model.eval() response = requests.get("https://git.io/JJkYN") labels = response.text.split("\n") augs = create_transform( input_size=(3, 224, 224), is_training=False, use_prefetcher=False, crop_pct=0.9, ) scale_size = math.floor(img_size / crop_pct) resize = transforms.Compose([ transforms.Resize(scale_size), transforms.CenterCrop(img_size), transforms.ToTensor() ]) normalize = transforms.Normalize(mean=torch.tensor(IMAGENET_DEFAULT_MEAN), std=torch.tensor(IMAGENET_DEFAULT_STD)) def transform(img): img = resize(img.convert("RGB")) tensor = normalize(img) return img, tensor def predict(inp): img, inp = transform(inp) inp = inp.unsqueeze(0) with GradCAM(model=model, target_layers=[model.layers[3]], use_cuda=True) as cam: grayscale_cam, probs = cam(input_tensor=inp, aug_smooth=False, eigen_smooth=False, return_probs=True) # Here grayscale_cam has only one image in the batch grayscale_cam = grayscale_cam[0, :] probs = probs[0, :] cam_image = show_cam_on_image(img.permute(1, 2, 0).detach().cpu().numpy(), grayscale_cam, use_rgb=True, image_weight=0.5, colormap=cv2.COLORMAP_TWILIGHT_SHIFTED) confidences = {labels[i]: float(probs[i]) for i in range(1000)} return confidences, cam_image print(f"=> Model (tpmlp_{size}) loaded in {time.time()- start:.2f} sec(s).") if not os.path.isdir("../example-imgs"): os.mkdir("../example-imgs") print("=> Loading examples.") indices = [ 0, # Coucal 2, # Volcano 7, # Sombrero 9, # Balance beam 10, # Sulphur-crested cockatoo 11, # Shower cap 12, # Petri dish INCORRECTLY CLASSIFIED as lens 14, # Angora rabbit ] ds = load_dataset("imagenet-1k", split="validation", streaming=True) examples = []; idx = 0 start = time.time() for data in ds: if idx == indices: data['image'].save(f"../example-imgs/{idx}.png") idx += 1 if idx == max(indices): break del ds print(f"=> Examples loaded in {time.time()- start:.2f} sec(s).") # demo = gr.Interface( # fn=predict, # inputs=gr.inputs.Image(type="pil"), # outputs=[gr.outputs.Label(num_top_classes=4), gr.outputs.Image(type="numpy")], # examples=[f"../example-imgs/{idx}.png" for idx in indices], # ) with gr.Blocks(theme=gr.themes.Monochrome(font=[gr.themes.GoogleFont("DM Sans"), "sans-serif"])) as demo: gr.HTML("""