Last commit not found
import importlib
import os
import subprocess
import sys

# Make sure OpenCV is importable before anything else runs, installing it into
# the environment of the *current* interpreter if it is missing.
try:
    import cv2
except ImportError:
    print("Package cv2 not found. Attempting installation.")
    # Run pip in the foreground through this interpreter (not whatever `pip`
    # is on PATH). A shell string like "pip ... &> /dev/null" is unsafe here:
    # os.system uses /bin/sh, where `&` backgrounds pip, so a retry loop would
    # keep spawning installs and never terminate if installation fails.
    subprocess.run(
        [sys.executable, "-m", "pip", "install", "-U", "opencv-python"],
        stdout=subprocess.DEVNULL,
        stderr=subprocess.DEVNULL,
        check=False,
    )
    # Packages installed after startup may be hidden by the import system's
    # cached directory listings until the caches are invalidated.
    importlib.invalidate_caches()
    import cv2  # still raises ImportError if the installation failed
import os
import cv2
import time
import math

print("=> Loading libraries...")
start = time.time()
import requests
import torch
import argparse
import gradio as gr
from torchvision import transforms
from datasets import load_dataset
from timm.data import create_transform
from timm.models import create_model, load_checkpoint
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.image import show_cam_on_image

# Command line: --local switches to a local checkout (no Hugging Face login,
# example images under ../example-imgs).
parser = argparse.ArgumentParser(description="CS-Mixer interactive Grad-CAM demo.")
parser.add_argument(
    "--local",
    action="store_true",
    help="run from a local checkout: skip the Hugging Face login and use ../example-imgs",
)
args = parser.parse_args()
if not args.local:
    print("=> Logging into huggingface...")
    from huggingface_hub import login

    hf_token = os.environ.get("HF_TOKEN")
    if not hf_token:
        # Fail with an actionable message instead of a bare KeyError traceback.
        raise SystemExit("HF_TOKEN is not set. Export a Hugging Face access token or pass --local.")
    login(token=hf_token)
# Run the model on the GPU when one is available.
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"=> Libraries loaded in {time.time()- start:.2f} sec(s).")
print("=> Loading model...")
start = time.time()

# Model variant and preprocessing constants shared by the rest of the script.
size = "b"
img_size = 224
crop_pct = 0.9
IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)

model = create_model(f"tpmlp_{size}").to(device)
_ckpt_file = f"tpmlp_{size}.pth.tar"
# Look for the checkpoint in the parent directory first, then in the working
# directory; a FileNotFoundError from the second attempt propagates.
try:
    load_checkpoint(model, f"../{_ckpt_file}", use_ema=True)
except FileNotFoundError:
    load_checkpoint(model, _ckpt_file, use_ema=True)
model.eval()
# Class names, one per line; predict() pairs the i-th name with the i-th
# model probability.
# NOTE(review): git.io is GitHub's legacy URL shortener - consider vendoring
# this file alongside the checkpoint.
response = requests.get("https://git.io/JJkYN", timeout=30)
# Fail loudly rather than treating an HTTP error page as the label list.
response.raise_for_status()
labels = response.text.splitlines()
# timm's evaluation transform for this input size / crop ratio, built from the
# same constants as the hand-made pipeline below so the two cannot drift apart.
# NOTE(review): `augs` is not referenced elsewhere in this file; predict() goes
# through `resize` + `normalize` instead.
augs = create_transform(
    input_size=(3, img_size, img_size),
    is_training=False,
    use_prefetcher=False,
    crop_pct=crop_pct,
)
# Resize the short side so that a center crop of `img_size` keeps `crop_pct`
# of it, crop, and convert to a [0, 1] tensor. Normalization is a separate step
# so the un-normalized crop can be reused for the Grad-CAM overlay.
scale_size = math.floor(img_size / crop_pct)
resize = transforms.Compose([
    transforms.Resize(scale_size),
    transforms.CenterCrop(img_size),
    transforms.ToTensor(),
])
normalize = transforms.Normalize(
    mean=torch.tensor(IMAGENET_DEFAULT_MEAN),
    std=torch.tensor(IMAGENET_DEFAULT_STD),
)
def transform(img):
    """Prepare a PIL image for the model.

    Returns a pair: the RGB image after ``resize`` (a tensor, reused for the
    Grad-CAM overlay) and that same tensor after ``normalize`` (the model input).
    """
    cropped = resize(img.convert("RGB"))
    return cropped, normalize(cropped)
def predict(inp):
    """Classify an image and compute a Grad-CAM heat-map for it.

    Args:
        inp: PIL image from the Gradio ``Image`` component; ``None`` when the
            user presses "Predict" without providing an image.

    Returns:
        ``(confidences, cam_image)``: a mapping of class name to probability
        (for ``gr.Label``) and the heat-map blended onto the cropped input as
        an RGB array (for ``gr.Image(type="numpy")``).

    Raises:
        gr.Error: if no image was provided.
    """
    if inp is None:
        # Show a readable message in the UI instead of an AttributeError.
        raise gr.Error("Please upload an image first.")
    img, inp = transform(inp)
    # Add a batch dimension and keep the input on the model's device.
    inp = inp.unsqueeze(0).to(device)
    with GradCAM(model=model, target_layers=[model.layers[3]], use_cuda=device == "cuda") as cam:
        # NOTE(review): `return_probs` is not part of upstream pytorch-grad-cam's
        # call signature; this relies on a patched/forked version of the package.
        grayscale_cam, probs = cam(input_tensor=inp, aug_smooth=False, eigen_smooth=False, return_probs=True)
    # The batch holds a single image.
    grayscale_cam = grayscale_cam[0, :]
    probs = probs[0, :]
    cam_image = show_cam_on_image(
        img.permute(1, 2, 0).detach().cpu().numpy(),  # CHW tensor -> HWC array
        grayscale_cam,
        use_rgb=True,
        image_weight=0.5,
        colormap=cv2.COLORMAP_TWILIGHT_SHIFTED,
    )
    # Pair names with probabilities; zip stops at the shorter sequence, so a
    # trailing blank line in the label file (or a short file) cannot raise.
    confidences = {label: float(p) for label, p in zip(labels, probs)}
    return confidences, cam_image
print(f"=> Model (tpmlp_{size}) loaded in {time.time()- start:.2f} sec(s).")
# Directory the example images are written to and later served from.
base = "../example-imgs" if args.local else "."
print("=> Loading examples.")
# Positions, in streaming order, of the ImageNet-1k validation images offered
# as clickable examples in the UI.
indices = [
    0,  # Coucal
    2,  # Volcano
    7,  # Sombrero
    9,  # Balance beam
    10,  # Sulphur-crested cockatoo
    11,  # Shower cap
    12,  # Petri dish INCORRECTLY CLASSIFIED as lens
    14,  # Angora rabbit
]
ds = load_dataset("imagenet-1k", split="validation", streaming=True)
start = time.time()
os.makedirs(base, exist_ok=True)  # ../example-imgs may not exist yet
wanted = set(indices)
last = max(indices)
for idx, data in enumerate(ds):
    # Membership test against the requested positions (an int never equals
    # the whole list, so `idx == indices` would save nothing).
    if idx in wanted:
        data["image"].save(f"{base}/{idx}.png")
    # Stop only after the last wanted position has been handled.
    if idx >= last:
        break
del ds
print(f"=> Examples loaded in {time.time()- start:.2f} sec(s).")
# demo = gr.Interface( | |
# fn=predict, | |
# inputs=gr.inputs.Image(type="pil"), | |
# outputs=[gr.outputs.Label(num_top_classes=4), gr.outputs.Image(type="numpy")], | |
# examples=[f"../example-imgs/{idx}.png" for idx in indices], | |
# ) | |
# Page layout: header, input / predictions / heat-map row, action buttons,
# and the example gallery that runs predict() when an example is clicked.
with gr.Blocks(theme=gr.themes.Monochrome(font=[gr.themes.GoogleFont("DM Sans"), "sans-serif"])) as demo:
    gr.HTML("""
<h1 align="center">Interactive Demo</h1>
<h2 align="center">CS-Mixer: A Cross-Scale Vision MLP Model with Spatial–Channel Mixing</h2>
<br><br>
""")
    with gr.Row():
        input_image = gr.Image(type="pil", min_width=300, label="Input Image")
        softmax = gr.Label(num_top_classes=4, min_width=200, label="Model Predictions")
        grad_cam = gr.Image(type="numpy", min_width=300, label="Grad-CAM")
    with gr.Row():
        gr.Button("Predict").click(fn=predict, inputs=input_image, outputs=[softmax, grad_cam])
        # Clear the outputs along with the input so stale predictions and the
        # previous heat-map do not linger next to an empty image.
        gr.ClearButton([input_image, softmax, grad_cam])
    with gr.Row():
        gr.Examples(
            [f"{base}/{idx}.png" for idx in indices],
            inputs=input_image,
            outputs=[softmax, grad_cam],
            fn=predict,
            run_on_click=True,
        )
# Serve on every interface at port 8000, without a public share link, and let
# Gradio serve files from the example-image directory.
demo.launch(
    server_name="0.0.0.0",
    server_port=8000,
    share=False,
    debug=False,
    allowed_paths=[base],
    # ssl_verify=False,
    # ssl_certfile="/workspace/openssl/cert.pem", ssl_keyfile="/workspace/openssl/key.pem"
)