import math
import os
import sys

import gradio as gr
import torch
import torchvision
from PIL import Image

# Fetch the PIXEL codebase so its modules can be imported below.
os.system("git clone https://github.com/xplip/pixel.git")
sys.path.append("./pixel")

from transformers import set_seed

from pixel.src.pixel import (
    PIXELConfig,
    PIXELForPreTraining,
    SpanMaskingGenerator,
    PyGameTextRenderer,
    get_transforms,
    resize_model_embeddings,
    truncate_decoder_pos_embeddings,
    get_attention_mask,
)

model_name_or_path = "Team-PIXEL/pixel-base"
max_seq_length = 529

text_renderer = PyGameTextRenderer.from_pretrained(model_name_or_path, max_seq_length=max_seq_length)
config = PIXELConfig.from_pretrained(model_name_or_path)
model = PIXELForPreTraining.from_pretrained(model_name_or_path, config=config)


def clip(x: torch.Tensor):
    # Rescale to [0, 255] and clamp, working in HWC order.
    x = torch.einsum("chw->hwc", x)
    x = torch.clip(x * 255, 0, 255)
    x = torch.einsum("hwc->chw", x)
    return x


def get_image(img: torch.Tensor, do_clip: bool = True):
    # Convert a CHW tensor into a PIL image for display in the Gradio UI.
    if do_clip:
        img = clip(img)
    img = torchvision.utils.make_grid(img, normalize=True)
    image = Image.fromarray(
        img.mul(255).clamp(0, 255).byte().permute(1, 2, 0).cpu().numpy()
    )
    return image


def inference(text: str, mask_ratio: float = 0.25, max_span_length: int = 6, seed: int = 42):
    config.update({"mask_ratio": mask_ratio})

    resize_model_embeddings(model, max_seq_length)
    truncate_decoder_pos_embeddings(model, max_seq_length)

    set_seed(seed)

    transforms = get_transforms(
        do_resize=True,
        size=(text_renderer.pixels_per_patch, text_renderer.pixels_per_patch * text_renderer.max_seq_length),
    )

    # Render the input text to an image and build the attention mask over its patches.
    encoding = text_renderer(text=text)
    attention_mask = get_attention_mask(
        num_text_patches=encoding.num_text_patches, seq_length=text_renderer.max_seq_length
    )

    img = transforms(Image.fromarray(encoding.pixel_values)).unsqueeze(0)
    attention_mask = attention_mask.unsqueeze(0)
    inputs = {"pixel_values": img.float(), "attention_mask": attention_mask}

    # Sample contiguous spans of patches to mask out.
    mask_generator = SpanMaskingGenerator(
        num_patches=text_renderer.max_seq_length,
        num_masking_patches=math.ceil(mask_ratio * text_renderer.max_seq_length),
        max_span_length=max_span_length,
        spacing="span",
    )
    mask = torch.tensor(mask_generator(num_text_patches=(encoding.num_text_patches + 1))).unsqueeze(0)
    inputs.update({"patch_mask": mask})

    model.eval()
    with torch.no_grad():
        outputs = model(**inputs)

    predictions = model.unpatchify(outputs["logits"]).detach().cpu().squeeze()

    # Expand the patch-level mask and attention mask to pixel space.
    mask = outputs["mask"].detach().cpu()
    mask = mask.unsqueeze(-1).repeat(1, 1, text_renderer.pixels_per_patch ** 2 * 3)
    mask = model.unpatchify(mask).squeeze()

    attention_mask = attention_mask.unsqueeze(-1).repeat(1, 1, text_renderer.pixels_per_patch ** 2 * 3)
    attention_mask = model.unpatchify(attention_mask).squeeze()

    original_img = model.unpatchify(model.patchify(img)).squeeze()

    # Masked input, predictions for the masked patches, and the combined reconstruction.
    im_masked = original_img * (1 - (torch.bitwise_and(mask == 1, attention_mask == 1)).long())
    masked_predictions = predictions * mask * attention_mask
    reconstruction = im_masked + masked_predictions

    return [
        get_image(original_img),
        get_image(im_masked),
        get_image(masked_predictions, do_clip=False),
        get_image(reconstruction, do_clip=False),
    ]


demo = gr.Blocks()

with demo:
    gr.Markdown("## PIXEL Masked Autoencoding")
    gr.Markdown(
        "Gradio demo for [PIXEL](https://huggingface.co/Team-PIXEL/pixel-base), introduced in "
        "[Language Modelling with Pixels](https://arxiv.org/abs/2207.06991). To use it, simply "
        "enter a piece of text and click Run. Read more at the links above."
    )
    with gr.Row():
        with gr.Column():
            tb_text = gr.Textbox(label="Text")
            sl_ratio = gr.Slider(0.01, 1.0, step=0.01, value=0.25, label="Span masking ratio")
            sl_len = gr.Slider(1, 6, step=1, value=6, label="Masking max span length")
            sl_seed = gr.Slider(0, 1000, step=1, value=42, label="Random seed")
            btn = gr.Button("Run")
        with gr.Column():
            out_original = gr.Image(label="Original")
            out_masked = gr.Image(label="Masked")
            out_masked_pred = gr.Image(label="Masked Predictions")
            out_reconstruction = gr.Image(label="Reconstruction")

    btn.click(
        fn=inference,
        inputs=[tb_text, sl_ratio, sl_len, sl_seed],
        outputs=[out_original, out_masked, out_masked_pred, out_reconstruction],
    )

demo.launch(debug=True)
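
# The masking pipeline can also be exercised without the Gradio UI by calling
# `inference` directly; it returns [original, masked, masked predictions,
# reconstruction] as PIL images. A minimal sketch (the example text and output
# filename below are illustrative, not part of the original demo):
#
#   original, masked, masked_pred, reconstruction = inference(
#       "Language modelling with pixels.", mask_ratio=0.25, max_span_length=6, seed=42
#   )
#   reconstruction.save("reconstruction.png")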