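# Gradio demo of masked autoencoding with PIXEL (Team-PIXEL/pixel-base).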
import math
import os
import sys

import gradio as gr
import torch
import torchvision
from PIL import Image
# Clone the PIXEL repo at runtime so its helper modules can be imported below.
os.system("git clone https://github.com/xplip/pixel.git")
sys.path.append("./pixel")
from transformers import set_seed
from pixel.src.pixel import (
    PIXELConfig,
    PIXELForPreTraining,
    SpanMaskingGenerator,
    PyGameTextRenderer,
    get_transforms,
    resize_model_embeddings,
    truncate_decoder_pos_embeddings,
    get_attention_mask,
)
model_name_or_path = "Team-PIXEL/pixel-base"
max_seq_length = 529  # 529 = 23 x 23 patches, so the sequence unpatchifies into a square image

text_renderer = PyGameTextRenderer.from_pretrained(model_name_or_path, max_seq_length=max_seq_length)
config = PIXELConfig.from_pretrained(model_name_or_path)
model = PIXELForPreTraining.from_pretrained(model_name_or_path, config=config)
def clip(x: torch.Tensor):
    """Scale a CHW float tensor into the [0, 255] pixel range."""
    x = torch.einsum("chw->hwc", x)
    x = torch.clip(x * 255, 0, 255)
    x = torch.einsum("hwc->chw", x)
    return x
def get_image(img: torch.Tensor, do_clip: bool = True):
    """Convert a CHW tensor into a PIL image for display in the UI."""
    if do_clip:
        img = clip(img)
    img = torchvision.utils.make_grid(img, normalize=True)
    image = Image.fromarray(
        img.mul(255).clamp(0, 255).byte().permute(1, 2, 0).cpu().numpy()
    )
    return image
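# End-to-end demo pass: render text to an image, mask a fraction of patches in
# contiguous spans, let the model reconstruct them, and return four PIL images
# (original, masked, masked predictions, reconstruction) for the UI.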
def inference(text: str, mask_ratio: float = 0.25, max_span_length: int = 6, seed: int = 42):
    config.update({"mask_ratio": mask_ratio})
    resize_model_embeddings(model, max_seq_length)
    truncate_decoder_pos_embeddings(model, max_seq_length)
    set_seed(seed)

    # Render the input text to an image strip and build its attention mask.
    transforms = get_transforms(
        do_resize=True,
        size=(text_renderer.pixels_per_patch, text_renderer.pixels_per_patch * text_renderer.max_seq_length),
    )
    encoding = text_renderer(text=text)
    attention_mask = get_attention_mask(
        num_text_patches=encoding.num_text_patches, seq_length=text_renderer.max_seq_length
    )
    img = transforms(Image.fromarray(encoding.pixel_values)).unsqueeze(0)
    attention_mask = attention_mask.unsqueeze(0)
    inputs = {"pixel_values": img.float(), "attention_mask": attention_mask}

    # Sample contiguous spans of patches to mask.
    mask_generator = SpanMaskingGenerator(
        num_patches=text_renderer.max_seq_length,
        num_masking_patches=math.ceil(mask_ratio * text_renderer.max_seq_length),
        max_span_length=max_span_length,
        spacing="span",
    )
    mask = torch.tensor(mask_generator(num_text_patches=(encoding.num_text_patches + 1))).unsqueeze(0)
    inputs.update({"patch_mask": mask})

    # Reconstruct the masked patches.
    model.eval()
    with torch.no_grad():
        outputs = model(**inputs)

    # Convert patch-level outputs back into image-shaped tensors.
    predictions = model.unpatchify(outputs["logits"]).detach().cpu().squeeze()
    mask = outputs["mask"].detach().cpu()
    mask = mask.unsqueeze(-1).repeat(1, 1, text_renderer.pixels_per_patch ** 2 * 3)
    mask = model.unpatchify(mask).squeeze()
    attention_mask = attention_mask.unsqueeze(-1).repeat(1, 1, text_renderer.pixels_per_patch ** 2 * 3)
    attention_mask = model.unpatchify(attention_mask).squeeze()

    # Compose the four views: original, original with masked patches blanked,
    # model predictions for the masked patches, and the stitched reconstruction.
    original_img = model.unpatchify(model.patchify(img)).squeeze()
    im_masked = original_img * (1 - (torch.bitwise_and(mask == 1, attention_mask == 1)).long())
    masked_predictions = predictions * mask * attention_mask
    reconstruction = im_masked + masked_predictions

    return [
        get_image(original_img),
        get_image(im_masked),
        get_image(masked_predictions, do_clip=False),
        get_image(reconstruction, do_clip=False),
    ]
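# Calling inference() directly (outside the UI) returns the four PIL images in
# order; the input sentence below is just an illustrative placeholder:
#
#   original, masked, masked_pred, reconstruction = inference(
#       "PIXEL renders text as images.", mask_ratio=0.25, max_span_length=6, seed=42
#   )
#   reconstruction.save("reconstruction.png")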
demo = gr.Blocks()

with demo:
    gr.Markdown("## PIXEL Masked Autoencoding")
    gr.Markdown(
        "Gradio demo for [PIXEL](https://huggingface.co/Team-PIXEL/pixel-base), introduced in "
        "[Language Modelling with Pixels](https://arxiv.org/abs/2207.06991). To use it, enter a "
        "piece of text, adjust the masking parameters if desired, and click Run."
    )
    with gr.Row():
        with gr.Column():
            tb_text = gr.Textbox(label="Text")
            sl_ratio = gr.Slider(0.01, 1.0, step=0.01, value=0.25, label="Span masking ratio")
            sl_len = gr.Slider(1, 6, step=1, value=6, label="Masking max span length")
            sl_seed = gr.Slider(0, 1000, step=1, value=42, label="Random seed")
            btn = gr.Button("Run")
        with gr.Column():
            out_original = gr.Image(label="Original")
            out_masked = gr.Image(label="Masked")
            out_masked_pred = gr.Image(label="Masked Predictions")
            out_reconstruction = gr.Image(label="Reconstruction")
    btn.click(
        fn=inference,
        inputs=[tb_text, sl_ratio, sl_len, sl_seed],
        outputs=[out_original, out_masked, out_masked_pred, out_reconstruction],
    )
demo.launch(debug=True)
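# Note: launch(debug=True) blocks and streams server logs. On a hosted Space the
# app is served automatically; run locally, Gradio serves on http://127.0.0.1:7860
# by default.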