import math
import os
import sys
import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
import torch
import torchvision
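
# Clone the PIXEL reference implementation at startup so its helpers can be
# imported from the local checkout (the Space depends on modules in that repo).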
os.system("git clone https://github.com/xplip/pixel.git")
sys.path.append('./pixel')
from transformers import set_seed
from pixel.src.pixel import (
    PIXELConfig,
    PIXELForPreTraining,
    SpanMaskingGenerator,
    PyGameTextRenderer,
    get_transforms,
    resize_model_embeddings,
    truncate_decoder_pos_embeddings,
    get_attention_mask
)
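
# Load the pretrained PIXEL checkpoint: the text renderer rasterizes an input
# string into a sequence of image patches (up to max_seq_length patches), and
# PIXELForPreTraining is the masked autoencoder that reconstructs masked patches.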
model_name_or_path = "Team-PIXEL/pixel-base"
max_seq_length = 529
text_renderer = PyGameTextRenderer.from_pretrained(model_name_or_path, max_seq_length=max_seq_length)
config = PIXELConfig.from_pretrained(model_name_or_path)
model = PIXELForPreTraining.from_pretrained(model_name_or_path, config=config)

def clip(x: torch.Tensor):
    """Scale a CHW image tensor to the 0-255 range and clamp it."""
    x = torch.einsum("chw->hwc", x)
    x = torch.clip(x * 255, 0, 255)
    x = torch.einsum("hwc->chw", x)
    return x

def get_image(img: torch.Tensor, do_clip: bool = True):
    """Convert a CHW tensor into a PIL image, optionally clipping to 0-255 first."""
    if do_clip:
        img = clip(img)
    img = torchvision.utils.make_grid(img, normalize=True)
    image = Image.fromarray(
        img.mul(255).clamp(0, 255).byte().permute(1, 2, 0).cpu().numpy()
    )
    return image

def inference(text: str, mask_ratio: float = 0.25, max_span_length: int = 6, seed: int = 42):
    config.update({"mask_ratio": mask_ratio})
    resize_model_embeddings(model, max_seq_length)
    truncate_decoder_pos_embeddings(model, max_seq_length)
    set_seed(seed)

    # Render the input text as an image and build the matching attention mask.
    transforms = get_transforms(
        do_resize=True,
        size=(text_renderer.pixels_per_patch, text_renderer.pixels_per_patch * text_renderer.max_seq_length),
    )
    encoding = text_renderer(text=text)
    attention_mask = get_attention_mask(
        num_text_patches=encoding.num_text_patches, seq_length=text_renderer.max_seq_length
    )

    img = transforms(Image.fromarray(encoding.pixel_values)).unsqueeze(0)
    attention_mask = attention_mask.unsqueeze(0)
    inputs = {"pixel_values": img.float(), "attention_mask": attention_mask}

    # Sample spans of patches to mask according to the requested ratio and span length.
    mask_generator = SpanMaskingGenerator(
        num_patches=text_renderer.max_seq_length,
        num_masking_patches=math.ceil(mask_ratio * text_renderer.max_seq_length),
        max_span_length=max_span_length,
        spacing="span"
    )
    mask = torch.tensor(mask_generator(num_text_patches=(encoding.num_text_patches + 1))).unsqueeze(0)
    inputs.update({"patch_mask": mask})

    # Run the masked autoencoder without gradient tracking.
    model.eval()
    with torch.no_grad():
        outputs = model(**inputs)

    # Turn the predicted patches, the patch mask, and the attention mask back into full images.
    predictions = model.unpatchify(outputs["logits"]).detach().cpu().squeeze()
    mask = outputs["mask"].detach().cpu()
    mask = mask.unsqueeze(-1).repeat(1, 1, text_renderer.pixels_per_patch ** 2 * 3)
    mask = model.unpatchify(mask).squeeze()
    attention_mask = attention_mask.unsqueeze(-1).repeat(1, 1, text_renderer.pixels_per_patch ** 2 * 3)
    attention_mask = model.unpatchify(attention_mask).squeeze()

    # Compose the visualizations: original, masked input, predictions for the masked patches,
    # and the reconstruction (unmasked pixels from the original, masked pixels from the model).
    original_img = model.unpatchify(model.patchify(img)).squeeze()
    im_masked = original_img * (1 - (torch.bitwise_and(mask == 1, attention_mask == 1)).long())
    masked_predictions = predictions * mask * attention_mask
    reconstruction = im_masked + masked_predictions

    return [
        get_image(original_img),
        get_image(im_masked),
        get_image(masked_predictions, do_clip=False),
        get_image(reconstruction, do_clip=False),
    ]
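
# Illustrative sketch (not part of the Gradio UI): inference() can also be called
# directly and returns four PIL images -- original rendering, masked input,
# masked-patch predictions, and the full reconstruction -- in that order.
# original, masked, masked_pred, reconstruction = inference(
#     "The quick brown fox jumps over the lazy dog.", mask_ratio=0.25, max_span_length=6, seed=42
# )
# reconstruction.save("reconstruction.png")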

demo = gr.Blocks()

with demo:
    gr.Markdown("## PIXEL Masked Autoencoding")
    gr.Markdown("Gradio demo for [PIXEL](https://huggingface.co/Team-PIXEL/pixel-base), introduced in [Language Modelling with Pixels](https://arxiv.org/abs/2207.06991). To use it, simply input your piece of text or click one of the examples to load them. Read more at the links below.")
    with gr.Row():
        with gr.Column():
            # Input controls.
            tb_text = gr.Textbox(label="Text")
            sl_ratio = gr.Slider(0.01, 1.0, step=0.01, value=0.25, label="Span masking ratio")
            sl_len = gr.Slider(1, 6, step=1, value=6, label="Masking max span length")
            sl_seed = gr.Slider(0, 1000, step=1, value=42, label="Random seed")
            btn = gr.Button("Run")
        with gr.Column():
            # Output panels.
            out_original = gr.Image(label="Original")
            out_masked = gr.Image(label="Masked")
            out_masked_pred = gr.Image(label="Masked Predictions")
            out_reconstruction = gr.Image(label="Reconstruction")
    btn.click(
        fn=inference,
        inputs=[tb_text, sl_ratio, sl_len, sl_seed],
        outputs=[out_original, out_masked, out_masked_pred, out_reconstruction],
    )

demo.launch(debug=True)