"""Gradio demo visualizing FIRe super-features (SFs) on a pair of images.

For a fixed list of SF ids, the attention map of each SF is binarized and
painted, one colour per SF, on top of each input image. The two overlays are
returned stacked in a single matplotlib figure.
"""
import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import torch
from matplotlib import cm
from matplotlib import colors
from mpl_toolkits.axes_grid1 import ImageGrid
from PIL import Image
from torchvision import transforms

import fire_network

# Possible scales for multiscale inference. The demo's "Scale" input is a
# 0-based index into this list; the selected value is forwarded to
# net.get_superfeatures(scales=[...]).
scales = [2.0, 1.414, 1.0, 0.707, 0.5, 0.353, 0.25]

device = 'cpu'

# Load net.
# NOTE: torch.load unpickles the checkpoint -- only ever load trusted files.
state = torch.load('fire.pth', map_location='cpu')
state['net_params']['pretrained'] = None  # no need for imagenet pretrained model
net = fire_network.init_network(**state['net_params']).to(device)
net.load_state_dict(state['state_dict'])
# Inference only: put layers such as BatchNorm/Dropout in evaluation mode.
net.eval()

transform = transforms.Compose([
    transforms.Resize(1024),
    transforms.ToTensor(),
    transforms.Normalize(**dict(zip(["mean", "std"], net.runtime['mean_std'])))
])

# Ids of the super-features whose attention maps are visualized (one colour each).
sf_idx_ = [55, 14, 5, 4, 52, 57, 40, 9]

col = plt.get_cmap('tab10')


def _sf_color(j):
    """Return the RGB uint8 colour used to paint the j-th visualized SF.

    Assignment within tab10: SF 0 uses colour 9, SFs 1-6 use colours 1-6, and
    later SFs are shifted by one so colour 7 is skipped. The index wraps
    around if more SFs than colours are requested.
    """
    if j == 0:
        idx = 9
    elif j < 7:
        idx = j
    else:
        idx = j + 1
    rgb = colors.to_rgb(col.colors[idx % len(col.colors)])
    # Round instead of truncating so e.g. 30.9999 does not become 30.
    return np.round(255 * np.asarray(rgb)).astype(np.uint8)


def _attention_masks(img, scale, threshold):
    """Compute one boolean mask per id in ``sf_idx_`` for an RGB PIL image.

    Each SF attention map is normalized to [0, 255] by its own maximum and
    binarized with ``> threshold``. It is then resized (nearest neighbour)
    to the image size. The returned masks have shape (H, W).
    """
    tensor = transform(img).unsqueeze(0).to(device)
    with torch.no_grad():
        output = net.get_superfeatures(tensor, scales=[scale])
    # output[1][0] holds the attention maps. They are indexed [0, sf_id, :, :],
    # presumably (1, n_superfeatures, h, w) -- TODO confirm against fire_network.
    attns = output[1][0]

    masks = []
    for sf_id in sf_idx_:
        att = attns[0, sf_id, :, :].cpu().numpy().astype(np.float32)
        peak = att.max()
        if peak > 0:
            att_u8 = np.uint8(att / peak * 255.0)
            binary = (att_u8 > threshold).astype(np.uint8) * 255
        else:
            # All-zero attention: nothing to highlight (and avoids 0/0).
            binary = np.zeros(att.shape, dtype=np.uint8)
        # PIL sizes are (W, H); NEAREST keeps the mask strictly binary.
        resized = Image.fromarray(binary).resize(img.size, Image.NEAREST)
        masks.append(np.asarray(resized) == 255)
    return masks


def _overlay(img, masks):
    """Paint each mask in its SF colour on an RGB copy of ``img``.

    Masks are applied in order, so later SFs overwrite earlier ones where
    they overlap.
    """
    canvas = np.array(img)  # (H, W, 3) uint8 RGB copy
    for j, mask in enumerate(masks):
        canvas[mask] = _sf_color(j)
    return canvas


def generate_matching_superfeatures(im1, im2, scale_id=6, threshold=50):
    """Overlay the attention masks of the selected SFs on two images.

    Args:
        im1, im2: input PIL images (any mode; converted to RGB).
        scale_id: 0-based index into ``scales`` selecting the inference scale.
        threshold: cutoff on the per-SF attention map normalized to [0, 255].
            Pixels strictly above it are painted.

    Returns:
        A matplotlib Figure with both overlaid images stacked vertically.

    Raises:
        ValueError: if an image is missing or ``scale_id`` is out of range.
    """
    if im1 is None or im2 is None:
        raise ValueError("Two images are required.")
    # Sliders may deliver floats; negative indices must not wrap silently.
    idx = int(scale_id)
    if not 0 <= idx < len(scales):
        raise ValueError(f"scale_id must be in [0, {len(scales) - 1}], got {scale_id}")
    scale = scales[idx]

    fin_img = []
    for im in (im1, im2):
        # RGBA/greyscale inputs would break Normalize and the 3-channel overlay.
        im = im.convert('RGB')
        fin_img.append(_overlay(im, _attention_masks(im, scale, threshold)))

    fig = plt.figure(figsize=(12, 25))
    grid = ImageGrid(fig, 111, nrows_ncols=(2, 1), axes_pad=0.1)
    for ax, img in zip(grid, fin_img):
        ax.imshow(img)  # already RGB, no channel swap needed
        ax.axis('scaled')
        ax.axis('off')
    fig.tight_layout()
    fig.suptitle("Matching SFs", fontsize=16)
    # Detach from pyplot's figure registry so a long-running server does not
    # accumulate figures; the Figure object itself stays usable for saving.
    plt.close(fig)
    return fig


# GRADIO APP
title = "Visualizing Super-features"
description = "TBD"
article = ""  # TODO: optional HTML shown below the demo

iface = gr.Interface(
    fn=generate_matching_superfeatures,
    inputs=[
        gr.inputs.Image(shape=(240, 240), type="pil"),
        gr.inputs.Image(shape=(240, 240), type="pil"),
        # 0-based index into `scales`; default 2 selects scale 1.0.
        gr.inputs.Slider(minimum=0, maximum=len(scales) - 1, step=1, default=2, label="Scale"),
        # step=1 so that the default (50) lies on the slider grid.
        gr.inputs.Slider(minimum=1, maximum=255, step=1, default=50, label="Binarization Threshold")],
    outputs="plot",
    enable_queue=True,
    title=title,
    description=description,
    article=article,
    examples=[["chateau_1.png", "chateau_2.png", 6, 50]],
)
iface.launch()