Spaces:
Sleeping
Sleeping
File size: 4,080 Bytes
e352103 36be50d fe4fa5b 36be50d e352103 36be50d 0a651e1 e352103 0b32b82 e352103 0a651e1 ad382c8 e352103 e4c787e e352103 fe4fa5b 36be50d fe4fa5b 36be50d fe4fa5b ab94263 36be50d e352103 36be50d e352103 0029ec4 36be50d 0029ec4 fe4fa5b 0029ec4 e352103 fe4fa5b 36be50d 6d64276 36be50d fe4fa5b 36be50d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 |
import gradio as gr
from transformers import AutoProcessor, AutoModelForVision2Seq, TextIteratorStreamer
from transformers.image_utils import load_image
from threading import Thread
import re
import time
import torch
import spaces
# Optional speed-up (currently disabled): install flash-attn at startup, e.g.
# `pip install flash-attn --no-build-isolation` with
# FLASH_ATTENTION_SKIP_CUDA_BUILD=TRUE in the environment, then pass
# `_attn_implementation="flash_attention_2"` to `from_pretrained` below.
_CHECKPOINT = "HuggingFaceTB/SmolVLM-Instruct-250M"

# Processor: provides the chat template and the text/image preprocessing.
processor = AutoProcessor.from_pretrained(_CHECKPOINT)

# Model weights are loaded once at import time in bfloat16 and moved to CUDA.
model = AutoModelForVision2Seq.from_pretrained(
    _CHECKPOINT,
    torch_dtype=torch.bfloat16,
).to("cuda")
@spaces.GPU
def model_inference(
    input_dict, history
):
    """Stream the model's answer for one multimodal chat turn.

    Args:
        input_dict: Value of the multimodal textbox. Only the ``"text"`` and
            ``"files"`` keys are read; every entry of ``"files"`` is passed to
            ``load_image``.
        history: Previous turns supplied by ``gr.ChatInterface``. Unused: the
            prompt is built from ``input_dict`` alone (single-turn chat).

    Yields:
        ``"..."`` once as a placeholder, then the cumulative generated text
        after each chunk produced by the streamer.

    Raises:
        gr.Error: If the text query is empty (with or without images).
    """
    text = input_dict["text"]
    files = input_dict["files"]
    print(files)
    images = [load_image(image) for image in files]

    # gr.Error only reaches the UI when it is raised; merely constructing it
    # (as before) let empty queries fall through to the model.
    if not text:
        if images:
            raise gr.Error("Please input a text query along the image(s).")
        raise gr.Error("Please input a query and optionally image(s).")

    # One image placeholder per uploaded image, followed by the text query.
    resulting_messages = [
        {
            "role": "user",
            "content": [{"type": "image"} for _ in images]
            + [{"type": "text", "text": text}],
        }
    ]
    prompt = processor.apply_chat_template(resulting_messages, add_generation_prompt=True)

    # Text-only turns pass no image batch at all instead of an empty one.
    inputs = processor(
        text=prompt,
        images=[images] if images else None,
        return_tensors="pt",
    )
    # Plain dict of device tensors; used directly as generate() kwargs below.
    inputs = {k: v.to("cuda") for k, v in inputs.items()}

    # generate() runs in a background thread and pushes decoded text into the
    # streamer, which this generator consumes incrementally.
    streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
    generation_args = dict(inputs, streamer=streamer, max_new_tokens=500)
    thread = Thread(target=model.generate, kwargs=generation_args)
    thread.start()

    yield "..."
    buffer = ""
    for new_text in streamer:
        buffer += new_text
        time.sleep(0.01)
        yield buffer
# Clickable example prompts for the chat UI. Each row holds only the
# multimodal textbox value (a dict with "text" and "files"). The interface
# defines no additional inputs, so extra per-row values (decoding strategy,
# temperature, ...) would be silently ignored.
examples = [
    [{"text": "What art era do these artpieces belong to?", "files": ["example_images/rococo.jpg", "example_images/rococo_1.jpg"]}],
    [{"text": "I'm planning a visit to this temple, give me travel tips.", "files": ["example_images/examples_wat_arun.jpg"]}],
    [{"text": "What is the due date and the invoice date?", "files": ["example_images/examples_invoice.png"]}],
    [{"text": "What is this UI about?", "files": ["example_images/s2w_example.png"]}],
    [{"text": "Where do the severe droughts happen according to this diagram?", "files": ["example_images/examples_weather_events.png"]}],
]
# Chat UI wiring: a multimodal textbox (text plus multiple images) feeds the
# streaming `model_inference` generator. Example outputs are not pre-computed
# (cache_examples=False). The stray "]" that previously preceded
# cache_examples made this call a SyntaxError.
demo = gr.ChatInterface(
    fn=model_inference,
    title="SmolVLM: Small yet Mighty 💫",
    description="Play with [HuggingFaceTB/SmolVLM-Instruct](https://huggingface.co/HuggingFaceTB/SmolVLM-Instruct) in this demo. To get started, upload an image and text or try one of the examples. This checkpoint works best with single turn conversations, so clear the conversation after a single turn.",
    examples=examples,
    textbox=gr.MultimodalTextbox(label="Query Input", file_types=["image"], file_count="multiple"),
    stop_btn="Stop Generation",
    multimodal=True,
    cache_examples=False,
)
demo.launch(debug=True)
|