ai_raven / app.py
VictorSanh's picture
Update app.py
ed2c0ea verified
import torch
import gradio as gr
import random
import numpy as np
from PIL import Image
import imagehash
import cv2
import os
import spaces
import subprocess
from transformers import AutoProcessor, AutoModelForCausalLM
from transformers.image_utils import to_numpy_array, PILImageResampling, ChannelDimension
from transformers.image_transforms import resize, to_channel_dimension_format
from typing import List
from PIL import Image
from collections import Counter
from datasets import load_dataset, concatenate_datasets
# Install flash-attn at startup while skipping its CUDA build
# (FLASH_ATTENTION_SKIP_CUDA_BUILD). The current environment is *extended*
# rather than replaced: passing only that one variable as `env` would run pip
# without PATH, HOME or any virtualenv settings.
subprocess.run(
    ["pip", "install", "flash-attn", "--no-build-isolation"],
    env={**os.environ, "FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
)
# Everything runs on the GPU: the model is moved there below and the
# inference inputs are moved to the same device.
DEVICE = torch.device("cuda")

# Checkpoint shared by the processor and the model.
_RAVEN_CHECKPOINT = "HuggingFaceM4/idefics2_raven_finetuned"

PROCESSOR = AutoProcessor.from_pretrained(
    _RAVEN_CHECKPOINT,
    token=os.environ["HF_AUTH_TOKEN"],
)

# Weights are loaded in bfloat16. trust_remote_code=True presumably means the
# checkpoint ships its own modeling code — verify on the Hub repo.
MODEL = AutoModelForCausalLM.from_pretrained(
    _RAVEN_CHECKPOINT,
    token=os.environ["HF_AUTH_TOKEN"],
    torch_dtype=torch.bfloat16,
    trust_remote_code=True,
).to(DEVICE)
# Number of `<image>` placeholder tokens written into the prompt per image.
# With a perceiver resampler this is its latent count; otherwise it is one
# token per vision patch, i.e. (image_size // patch_size) ** 2.
image_seq_len = (
    MODEL.config.perceiver_config.resampler_n_latents
    if MODEL.config.use_resampler
    else (MODEL.config.vision_config.image_size // MODEL.config.vision_config.patch_size) ** 2
)
# The prompt is tokenized with add_special_tokens=False, so the BOS token is
# prepended to it by hand.
BOS_TOKEN = PROCESSOR.tokenizer.bos_token

# Token-id sequences passed to `generate` as `bad_words_ids` so the model
# never emits the image placeholder markers.
BAD_WORDS_IDS = PROCESSOR.tokenizer(
    ["<image>", "<fake_token_around_image>"],
    add_special_tokens=False,
).input_ids

# Pool of puzzles that `load_sample` draws from.
DATASET = load_dataset(
    "HuggingFaceM4/RAVEN_rendered",
    split="validation",
    token=os.environ["HF_AUTH_TOKEN"],
)
## Utils
def convert_to_rgb(image):
    """Return `image` in RGB mode, compositing any transparency onto white.

    A plain `image.convert("RGB")` is only correct for images without
    transparency (e.g. .jpg); for transparent images it produces a wrong
    background. Non-RGB inputs are therefore promoted to RGBA and
    alpha-composited over a same-sized white canvas before dropping alpha.
    Images already in RGB mode are returned unchanged (same object).
    """
    if image.mode != "RGB":
        rgba = image.convert("RGBA")
        white_canvas = Image.new("RGBA", rgba.size, (255, 255, 255))
        return Image.alpha_composite(white_canvas, rgba).convert("RGB")
    return image
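# The processor is the same as the Idefics processor except for the BICUBIC
# interpolation inside SigLIP, so this hack redefines ONLY the transform step.
# NOTE(review): the transform below resizes with BILINEAR, not BICUBIC —
# confirm which interpolation the checkpoint expects.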
def custom_transform(x):
    """Turn one image into a normalized, channels-first tensor.

    The pipeline is:
    1. Flatten transparency onto white via `convert_to_rgb`.
    2. Resize to 960x960 with BILINEAR resampling.
    3. Rescale pixel values by 1/255.
    4. Normalize with the image processor's `image_mean` / `image_std`.
    5. Move channels first and wrap the result in `torch.tensor`.
    """
    image_processor = PROCESSOR.image_processor
    pixels = to_numpy_array(convert_to_rgb(x))
    pixels = resize(pixels, (960, 960), resample=PILImageResampling.BILINEAR)
    pixels = image_processor.rescale(pixels, scale=1 / 255)
    pixels = image_processor.normalize(
        pixels,
        mean=image_processor.image_mean,
        std=image_processor.image_std,
    )
    channels_first = to_channel_dimension_format(pixels, ChannelDimension.FIRST)
    return torch.tensor(channels_first)
def pixel_difference(image1, image2):
    """Return True when two images agree on at least two of three cues.

    The cues are:
    * color: the most frequent pixel value, ignoring 255 (white) when
      another value is present; close when the difference is < 10;
    * shape: Hamming distance between 32-bit-side perceptual hashes of the
      Canny edge maps; close when int(distance / 7) < 10;
    * surface: number of pixels that are not 255; close when the difference,
      as a percentage of 160 * 160 pixels, truncates to < 10.

    The images are judged similar by majority vote (>= 2 close cues).
    NOTE(review): 160 * 160 presumably is the panel size in pixels — confirm
    against the rendered dataset.

    Returns:
        A plain Python bool.
    """
    def color(im):
        # Most frequent pixel value, skipping the white (255) background when
        # a second value exists. A uniformly white image has only one
        # distinct value, so it falls back to 255 instead of raising
        # IndexError.
        most_common = Counter(np.array(im).flatten().tolist()).most_common(2)
        if most_common[0][0] == 255 and len(most_common) > 1:
            return most_common[1][0]
        return most_common[0][0]

    def canny_edges(im):
        # Canny edge map binarized to {0, 255}.
        edges = cv2.Canny(np.array(im), 50, 100)
        edges[edges != 0] = 255
        return Image.fromarray(edges)

    def phash(im):
        return imagehash.phash(canny_edges(im), hash_size=32)

    def surface(im):
        # Number of pixels that are not pure white.
        return (np.array(im) != 255).sum()

    color_diff = np.abs(color(image1) - color(image2))
    hash_diff = phash(image1) - phash(image2)
    surface_diff = np.abs(surface(image1) - surface(image2))

    hash_close = int(hash_diff / 7) < 10
    color_close = color_diff < 10
    surface_close = int(surface_diff / (160 * 160) * 100) < 10
    # Majority vote: this is exactly what the nested if/elif decision
    # computed, without its unreachable sub-conditions.
    return sum(bool(cue) for cue in (hash_close, color_close, surface_close)) >= 2
# End of Utils
def load_sample(*, dataset=None):
    """Pick a uniformly random puzzle and reset the answer widgets.

    Args:
        dataset: indexable collection whose items are mappings with "image"
            and "label" keys. Defaults to the module-level `DATASET`.

    Returns:
        A 5-tuple `(image, label, "", "", "")` matching the outputs wired to
        the "Load a new puzzle" button: the puzzle image, the hidden solution
        label, and three empty strings that clear the AI guess, the user
        guess and the result box.

    Raises:
        ValueError: if the dataset is empty.
    """
    if dataset is None:
        dataset = DATASET
    # randrange excludes its upper bound; randint(0, n) is inclusive and
    # could pick index n, one past the end.
    idx = random.randrange(len(dataset))
    sample = dataset[idx]
    return sample["image"], sample["label"], "", "", ""
@spaces.GPU(duration=180)
def model_inference(
    image,
):
    """Ask the fine-tuned model which option completes the puzzle.

    Args:
        image: the puzzle as a PIL image (the `imagebox` component uses
            type="pil"), or None when no puzzle has been loaded yet.

    Returns:
        The last non-whitespace character of the decoded generation, which
        the UI treats as the model's answer letter.

    Raises:
        gr.Error: if no image is loaded, or if the model produced no text.
    """
    if image is None:
        raise gr.Error("Load a new sample first!")
    # The prompt is built by hand: BOS is prepended explicitly (hence
    # add_special_tokens=False), and the image is represented by
    # `image_seq_len` <image> placeholders wrapped in
    # <fake_token_around_image> markers.
    inputs = PROCESSOR.tokenizer(
        f"{BOS_TOKEN}User:<fake_token_around_image>{'<image>' * image_seq_len}<fake_token_around_image>Which figure should complete the logical sequence?<end_of_utterance>\nAssistant:",
        return_tensors="pt",
        add_special_tokens=False,
    )
    # Pixel values are produced with `custom_transform` (BILINEAR resize)
    # instead of the processor's default preprocessing.
    inputs["pixel_values"] = PROCESSOR.image_processor(
        [image],
        transform=custom_transform,
    )
    inputs = {k: v.to(DEVICE) for k, v in inputs.items()}

    generated_ids = MODEL.generate(
        **inputs,
        bad_words_ids=BAD_WORDS_IDS,
        max_length=128,
    )
    generated_text = PROCESSOR.batch_decode(
        generated_ids,
        skip_special_tokens=True,
    )[0]
    # The answer letter is expected at the very end of the transcript. Strip
    # trailing whitespace first so a newline/space is not returned as the
    # "answer", and fail clearly instead of raising IndexError on empty text.
    answer = generated_text.rstrip()
    if not answer:
        raise gr.Error("The model did not produce an answer, please try again.")
    return answer[-1]
# Result widgets are created outside the Blocks layout and placed into it
# later with `.render()`.
model_prediction = gr.Label(visible=True, label="AI's guess")
user_prediction = gr.Label(visible=True, label="Your guess")

# Read-only, single-line box for the win/lose message.
result = gr.TextArea(
    interactive=False,
    label="Win or lose?",
    lines=1,
    max_lines=1,
    visible=True,
)
# Page CSS handed to gr.Blocks: caps the container width, centers the
# contents of h1 headings with a flex layout, and animates width/flex-grow
# changes on every element. One rule per line; the leading and trailing
# empty entries reproduce the surrounding newlines.
css = "\n".join(
    [
        "",
        ".gradio-container{max-width: 1000px!important}",
        "h1{display: flex;align-items: center;justify-content: center;gap: .25em}",
        "*{transition: width 0.5s ease, flex-grow 0.5s ease}",
        "",
    ]
)
# "Beat the AI" UI: a puzzle image on the left and eight answer buttons on
# the right. Clicking a button records the user's guess, runs the model on
# the same puzzle, then reports who got it right.
with gr.Blocks(title="Beat the AI", theme=gr.themes.Base(), css=css) as demo:
    gr.Markdown(
        """
        # Can you beat the AI at RAVEN puzzles?

        *This demo features an early fine-tuned version of our forthcoming Idefics2 model (read about Idefics1 [here](https://huggingface.co/HuggingFaceM4/idefics-80b-instruct)). The model was specifically fine-tuned on the [RAVEN](https://huggingface.co/datasets/HuggingFaceM4/RAVEN) dataset and reached 91% accuracy on the validation set.*

        Raven's Progressive Matrices are abstract visual reasoning puzzles. The panels describe logical sequences of shapes and colors (row by row). You are asked to find the option that completes the 3rd sequence following the same logic as the first two sequences. We recommend viewing the images full screen with enough brightness, since some options differ only by small variations in size or subtle nuances of color.

        To get started, load a new puzzle. 🧠
        """
    )
    load_new_sample = gr.Button(value="Load a new puzzle")
    with gr.Row(equal_height=True):
        with gr.Column(scale=4, min_width=250) as upload_area:
            # sources=[] disables user uploads: the image only comes from
            # `load_sample`.
            imagebox = gr.Image(
                image_mode="L",
                type="pil",
                visible=True,
                sources=[],
            )
        with gr.Column(scale=4):
            with gr.Row():
                a = gr.Button(value="A", min_width=1)
                b = gr.Button(value="B", min_width=1)
                c = gr.Button(value="C", min_width=1)
                d = gr.Button(value="D", min_width=1)
            with gr.Row():
                e = gr.Button(value="E", min_width=1)
                f = gr.Button(value="F", min_width=1)
                g = gr.Button(value="G", min_width=1)
                h = gr.Button(value="H", min_width=1)
            with gr.Row():
                user_prediction.render()
                model_prediction.render()
            # Hidden holder for the current puzzle's label.
            solution = gr.TextArea(
                label="Solution",
                visible=False,
                lines=1,
                max_lines=1,
                interactive=False,
            )
            with gr.Row():
                result.render()

    def result_string(model_pred, user_pred, solution):
        """Build the win/lose message shown after a guess.

        Args:
            model_pred: letter returned by `model_inference`.
            user_pred: letter of the button the user clicked.
            solution: value of the hidden `solution` box, i.e. the 0-based
                index of the correct option (0 -> "A"), or "" when no puzzle
                is loaded.

        Returns:
            "" when no puzzle is loaded; otherwise an emoji (medal if the user
            was right, monkey if not), a sentence comparing the user with the
            AI, and the correct letter.
        """
        if solution == "":
            return ""
        solution_letter = chr(ord("A") + int(solution))
        solution_string = f"The correct answer is {solution_letter}."
        user_correct = user_pred == solution_letter
        model_correct = model_pred == solution_letter
        win_or_lose = "🥇" if user_correct else "🙈"
        if user_correct and model_correct:
            comparison_string = "Both you and the AI got it right!"
        elif user_correct:
            comparison_string = "You beat the AI!"
        elif model_correct:
            comparison_string = "The AI beat you!"
        else:
            comparison_string = "Both you and the AI got it wrong!"
        return f"{win_or_lose} {comparison_string} {solution_string}"

    load_new_sample.click(
        fn=load_sample,
        inputs=[],
        outputs=[imagebox, solution, model_prediction, user_prediction, result],
    )

    letter_buttons = [a, b, c, d, e, f, g, h]
    # `.success` (not `.then`): if inference fails, the result is not scored
    # against the stale/reset model prediction, which would otherwise report
    # "You beat the AI!" for an AI that never answered.
    # NOTE(review): result_string reads `user_prediction`, which is set by the
    # separate per-button listeners below; this relies on that trivial
    # listener finishing before the (much slower) inference does.
    gr.on(
        triggers=[button.click for button in letter_buttons],
        fn=model_inference,
        inputs=[imagebox],
        outputs=[model_prediction],
    ).success(
        fn=result_string,
        inputs=[model_prediction, user_prediction, solution],
        outputs=[result],
    )

    def make_choice(letter):
        """Return a no-argument handler that always outputs `letter`."""
        # A factory avoids the late-binding pitfall of lambdas built in a loop.
        return lambda: letter

    for letter, button in zip("ABCDEFGH", letter_buttons):
        button.click(fn=make_choice(letter), inputs=[], outputs=[user_prediction])
# Enable the request queue, with at most 40 pending events. api_open=False
# keeps requests from bypassing the queue through the REST routes. Then start
# the server. (The former bare `demo.load()` registered a load event with no
# function, i.e. a no-op, and has been dropped.)
demo.queue(max_size=40, api_open=False)
demo.launch(max_threads=400)