# Page-header residue captured with this file (not part of the program):
# Spaces: Runtime error
import os
import random
import subprocess
import sys
from collections import Counter
from typing import List

import cv2
import gradio as gr
import imagehash
import numpy as np
import spaces
import torch
from datasets import load_dataset, concatenate_datasets
from PIL import Image
from transformers import AutoProcessor, AutoModelForCausalLM
from transformers.image_transforms import resize, to_channel_dimension_format
from transformers.image_utils import to_numpy_array, PILImageResampling, ChannelDimension
# Best-effort install of flash-attn at start-up. The return code is deliberately
# not checked, matching the original behaviour. The child environment EXTENDS
# the current one: passing only the single variable as `env` would drop PATH,
# HOME, proxy settings, etc. for the pip process. `sys.executable -m pip`
# targets the interpreter that is running this app and needs no shell.
subprocess.run(
    [sys.executable, "-m", "pip", "install", "flash-attn", "--no-build-isolation"],
    env={**os.environ, "FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
    check=False,
)

DEVICE = torch.device("cuda")

# Read once; raises KeyError at import time when the variable is not set.
_HF_AUTH_TOKEN = os.environ["HF_AUTH_TOKEN"]
_MODEL_ID = "HuggingFaceM4/idefics2_raven_finetuned"

PROCESSOR = AutoProcessor.from_pretrained(
    _MODEL_ID,
    token=_HF_AUTH_TOKEN,
)
MODEL = AutoModelForCausalLM.from_pretrained(
    _MODEL_ID,
    trust_remote_code=True,
    torch_dtype=torch.bfloat16,
    token=_HF_AUTH_TOKEN,
).to(DEVICE)

# Number of `<image>` placeholder tokens inserted into the prompt per image.
if MODEL.config.use_resampler:
    image_seq_len = MODEL.config.perceiver_config.resampler_n_latents
else:
    image_seq_len = (
        MODEL.config.vision_config.image_size // MODEL.config.vision_config.patch_size
    ) ** 2

BOS_TOKEN = PROCESSOR.tokenizer.bos_token
# Token ids of the image markers; handed to `generate` as `bad_words_ids`.
BAD_WORDS_IDS = PROCESSOR.tokenizer(
    ["<image>", "<fake_token_around_image>"],
    add_special_tokens=False,
).input_ids

DATASET = load_dataset(
    "HuggingFaceM4/RAVEN_rendered",
    split="validation",
    token=_HF_AUTH_TOKEN,
)
## Utils | |
def convert_to_rgb(image):
    """Return `image` in RGB mode, flattening any transparency onto white.

    An image whose mode is already "RGB" is returned as-is (same object).
    Any other image is converted to RGBA, alpha-composited over a
    (255, 255, 255) canvas of the same size, and converted to RGB.
    """
    # A plain `image.convert("RGB")` would only work for .jpg images, as it
    # creates a wrong background for transparent images; compositing onto an
    # explicit canvas handles that case.
    if image.mode != "RGB":
        rgba = image.convert("RGBA")
        white_canvas = Image.new("RGBA", rgba.size, (255, 255, 255))
        return Image.alpha_composite(white_canvas, rgba).convert("RGB")
    return image
# The processor is the same as the Idefics processor except for the BICUBIC interpolation inside siglip, | |
# so this is a hack in order to redefine ONLY the transform method | |
def custom_transform(x):
    """Turn one image into a normalised, channels-first `torch` tensor.

    Steps, in order: convert to RGB, convert to a numpy array, resize to
    960x960 with bilinear resampling, rescale by 1/255, normalise with the
    image processor's own mean/std, move channels first, wrap in a tensor.
    """
    image_processor = PROCESSOR.image_processor

    pixels = to_numpy_array(convert_to_rgb(x))
    pixels = resize(pixels, (960, 960), resample=PILImageResampling.BILINEAR)
    pixels = image_processor.rescale(pixels, scale=1 / 255)
    pixels = image_processor.normalize(
        pixels,
        mean=image_processor.image_mean,
        std=image_processor.image_std,
    )
    pixels = to_channel_dimension_format(pixels, ChannelDimension.FIRST)
    return torch.tensor(pixels)
def pixel_difference(image1, image2):
    """Return True when at least two of three similarity checks pass.

    The three checks compare `image1` and `image2` on:

    * dominant non-255 pixel value (see `color`),
    * perceptual hash of their Canny edge maps (see `phash`),
    * number of non-255 array entries (see `surface`).

    The original if/elif ladder evaluated exactly this "two out of three"
    vote; it is written directly here so each measure is computed once and
    the result is always a plain `bool`.
    """
    threshold = 10
    # NOTE(review): 160 * 160 is presumably the pixel area of one puzzle
    # panel, and 7 an empirical scaling of the hash distance -- confirm.
    reference_area = 160 * 160
    hash_scale = 7

    def color(im):
        """Most frequent value in the image array, skipping 255 if possible."""
        counts = Counter(np.array(im).flatten().tolist())
        for value, _ in counts.most_common(2):
            if value != 255:
                return value
        # Only 255 (or nothing) present: previously this raised IndexError.
        return 255

    def canny_edges(im):
        """Canny edge map of the image with every non-zero entry set to 255."""
        edges = cv2.Canny(np.array(im), 50, 100)
        edges[edges != 0] = 255
        return Image.fromarray(edges)

    def phash(im):
        """Perceptual hash (hash_size=32) of the image's edge map."""
        return imagehash.phash(canny_edges(im), hash_size=32)

    def surface(im):
        """Count of array entries that differ from 255."""
        return int((np.array(im) != 255).sum())

    color_close = abs(color(image1) - color(image2)) < threshold

    hash_diff = phash(image1) - phash(image2)
    shape_close = int(hash_diff / hash_scale) < threshold

    surface_diff = abs(surface(image1) - surface(image2))
    surface_close = int(surface_diff / reference_area * 100) < threshold

    return (int(color_close) + int(shape_close) + int(surface_close)) >= 2
# End of Utils | |
def load_sample(dataset=None):
    """Pick a random puzzle and return the values that reset the UI.

    Args:
        dataset: Indexable collection of samples, each a mapping with
            "image" and "label" keys. Defaults to the module-level DATASET.

    Returns:
        A 5-tuple ``(image, label, "", "", "")``; the three empty strings
        clear the outputs wired after the image and the solution.

    Raises:
        gr.Error: If the dataset contains no samples.
    """
    if dataset is None:
        dataset = DATASET

    n = len(dataset)
    if n == 0:
        raise gr.Error("The puzzle dataset is empty.")

    # `random.randint(0, n)` includes `n` itself and so could index one past
    # the end; `randrange(n)` draws from 0..n-1.
    sample = dataset[random.randrange(n)]
    return sample["image"], sample["label"], "", "", ""
def model_inference(
    image,
):
    """Ask the model which option completes the puzzle shown in `image`.

    Builds the text prompt (with `image_seq_len` `<image>` placeholders),
    tokenizes it, attaches the pixel values produced with `custom_transform`,
    moves everything to DEVICE and runs `MODEL.generate`.

    Returns:
        The last character of the first decoded sequence.

    Raises:
        gr.Error: If no image has been loaded yet (`image` is None).
    """
    if image is None:
        raise gr.Error("Load a new sample first!")

    prompt = f"{BOS_TOKEN}User:<fake_token_around_image>{'<image>' * image_seq_len}<fake_token_around_image>Which figure should complete the logical sequence?<end_of_utterance>\nAssistant:"
    encoded = PROCESSOR.tokenizer(
        prompt,
        return_tensors="pt",
        add_special_tokens=False,
    )
    encoded["pixel_values"] = PROCESSOR.image_processor(
        [image],
        transform=custom_transform,
    )

    # Move every input tensor onto the model's device.
    model_inputs = {}
    for name, tensor in encoded.items():
        model_inputs[name] = tensor.to(DEVICE)

    # Regular (non-streaming) generation.
    output_ids = MODEL.generate(
        **model_inputs,
        bad_words_ids=BAD_WORDS_IDS,
        max_length=128,
    )
    decoded = PROCESSOR.batch_decode(output_ids, skip_special_tokens=True)
    return decoded[0][-1]
# Output widgets are built here, outside any layout, and positioned later in
# the file through their `.render()` method.
model_prediction = gr.Label(visible=True, label="AI's guess")
user_prediction = gr.Label(visible=True, label="Your guess")

# Single-line, read-only text area holding the outcome message.
result = gr.TextArea(
    interactive=False,
    lines=1,
    max_lines=1,
    visible=True,
    label="Win or lose?",
)

# Custom stylesheet handed to `gr.Blocks(css=...)`.
css = """
.gradio-container{max-width: 1000px!important}
h1{display: flex;align-items: center;justify-content: center;gap: .25em}
*{transition: width 0.5s ease, flex-grow 0.5s ease}
"""
# Page layout.
# NOTE(review): this file had lost its indentation, so the Row/Column nesting
# below is a reconstruction (image on the left; answer buttons, predictions
# and outcome on the right) -- confirm against the intended design.
# NOTE(review): the emoji closing the intro text was mis-encoded ("π§") in
# the source; the glyph used below is a best-guess restoration -- confirm.
with gr.Blocks(title="Beat the AI", theme=gr.themes.Base(), css=css) as demo:
    gr.Markdown(
        """
# Can you beat the AI at RAVEN puzzles?
*This demo features an early fine-tuned version of our forthcoming Idefics2 model (read about idefics1 [here](https://huggingface.co/HuggingFaceM4/idefics-80b-instruct)). The model was specifically fine-tuned on the [RAVEN](https://huggingface.co/datasets/HuggingFaceM4/RAVEN) dataset and reached 91% accuracy on the validation set.*
RAVEN Progressive Matrices are abstract visual reasoning puzzles. The panels describe logical sequences of shapes and colors (row by row). One is asked to find the option that completes the 3rd sequence following the same logic described by the first two sequences. We recommend looking at the images on a full screen with enough brightness given that some options differ by small differences in sizes and nuances of colors.
To get started, load a new puzzle. 🧐
"""
    )
    load_new_sample = gr.Button(value="Load a new puzzle")

    with gr.Row(equal_height=True):
        # Left: the puzzle image. `sources=[]` offers no way to supply an
        # image from the UI; it is filled by the "load" button instead.
        with gr.Column(scale=4, min_width=250) as upload_area:
            imagebox = gr.Image(
                image_mode="L",
                type="pil",
                visible=True,
                sources=[],
            )

        # Right: the eight answer buttons, both guesses and the outcome.
        with gr.Column(scale=4):
            with gr.Row():
                a = gr.Button(value="A", min_width=1)
                b = gr.Button(value="B", min_width=1)
                c = gr.Button(value="C", min_width=1)
                d = gr.Button(value="D", min_width=1)
            with gr.Row():
                e = gr.Button(value="E", min_width=1)
                f = gr.Button(value="F", min_width=1)
                g = gr.Button(value="G", min_width=1)
                h = gr.Button(value="H", min_width=1)
            with gr.Row():
                user_prediction.render()
                model_prediction.render()
                # Hidden field that carries the sample's label between events.
                solution = gr.TextArea(
                    label="Solution",
                    visible=False,
                    lines=1,
                    max_lines=1,
                    interactive=False,
                )
            with gr.Row():
                result.render()
def result_string(model_pred, user_pred, solution):
    """Build the outcome message comparing both guesses with the solution.

    Args:
        model_pred: Letter guessed by the model.
        user_pred: Letter guessed by the user.
        solution: Zero-based index of the correct option (int or numeric
            string); 0 maps to "A", 1 to "B", and so on. An empty string or
            None means no puzzle is loaded.

    Returns:
        "" when there is no solution yet, otherwise
        "<emoji> <comparison sentence> The correct answer is <letter>."
    """
    if solution is None or solution == "":
        return ""

    solution_letter = chr(ord("A") + int(solution))
    user_correct = user_pred == solution_letter
    model_correct = model_pred == solution_letter

    # Keyed by (user is right, model is right).
    comparisons = {
        (True, True): "Both you and the AI got it correctly!",
        (True, False): "You beat the AI!",
        (False, False): "Both you and the AI got it wrong!",
        (False, True): "The AI beat you!",
    }

    # NOTE(review): both emoji were mis-encoded in the source ("π₯" / "π");
    # these glyphs are best-guess restorations -- confirm the intended ones.
    win_or_lose = "🥇" if user_correct else "😔"

    comparison = comparisons[(user_correct, model_correct)]
    return f"{win_or_lose} {comparison} The correct answer is {solution_letter}."
def _always(letter):
    """Return a zero-argument callable that yields `letter`."""
    return lambda: letter


# Event listeners have to be registered inside a Blocks context, so the
# wiring runs under `with demo:`.
with demo:
    # Loading a puzzle fills the image and solution and clears the rest.
    load_new_sample.click(
        fn=load_sample,
        inputs=[],
        outputs=[imagebox, solution, model_prediction, user_prediction, result],
    )

    answer_buttons = {
        "A": a,
        "B": b,
        "C": c,
        "D": d,
        "E": e,
        "F": f,
        "G": g,
        "H": h,
    }

    # Any answer button runs the model, then builds the outcome message.
    gr.on(
        triggers=[button.click for button in answer_buttons.values()],
        fn=model_inference,
        inputs=[imagebox],
        outputs=[model_prediction],
    ).then(
        fn=result_string,
        inputs=[model_prediction, user_prediction, solution],
        outputs=[result],
    )

    # Each button also records its own letter as the user's guess.
    for letter, button in answer_buttons.items():
        button.click(fn=_always(letter), inputs=[], outputs=[user_prediction])

    demo.load()

demo.queue(max_size=40, api_open=False)
demo.launch(max_threads=400)