Spaces: Running on Zero (file size: 4,223 Bytes)
import gradio as gr
from transformers import AutoProcessor, AutoModelForVision2Seq, TextIteratorStreamer
from threading import Thread
import re
import time
from PIL import Image
import torch
import spaces
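
# Load the processor and model once at startup. The processor handles both image
# preprocessing and chat-template tokenization; bfloat16 weights roughly halve the
# memory footprint compared to float32.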
processor = AutoProcessor.from_pretrained("HuggingFaceTB/SmolVLM-Instruct")
model = AutoModelForVision2Seq.from_pretrained(
    "HuggingFaceTB/SmolVLM-Instruct",
    torch_dtype=torch.bfloat16,
).to("cuda")
@spaces.GPU
def model_inference(
    input_dict, history, decoding_strategy, temperature, max_new_tokens,
    repetition_penalty, top_p
):
    text = input_dict["text"]
    print(input_dict["files"])

    # Process input images if provided.
    if len(input_dict["files"]) > 1:
        images = [Image.open(image).convert("RGB") for image in input_dict["files"]]
    elif len(input_dict["files"]) == 1:
        images = [Image.open(input_dict["files"][0]).convert("RGB")]
    else:
        images = []
    # Validate input: a text query is required.
    if text == "" and not images:
        raise gr.Error("Please input a query and optionally image(s).")
    if text == "" and images:
        raise gr.Error("Please input a text query along with the image(s).")
    # Prepare the prompt using the chat template: one image placeholder per
    # attached image, followed by the text query.
    resulting_messages = [{
        "role": "user",
        "content": [{"type": "image"} for _ in range(len(images))] + [
            {"type": "text", "text": text}
        ]
    }]
    prompt = processor.apply_chat_template(resulting_messages, add_generation_prompt=True)
    # Pass None when no image is attached so the call is processed as text-only.
    inputs = processor(text=prompt, images=[images] if images else None, return_tensors="pt")
    inputs = {k: v.to("cuda") for k, v in inputs.items()}
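    # At this point `inputs` holds the tokenized prompt (plus the preprocessed image
    # tensors, if any), all moved to the GPU attached for this call.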

    # Set up generation parameters.
    generation_args = {
        "max_new_tokens": max_new_tokens,
        "repetition_penalty": repetition_penalty,
    }
    assert decoding_strategy in ["Greedy", "Top P Sampling"]
    if decoding_strategy == "Greedy":
        generation_args["do_sample"] = False
    elif decoding_strategy == "Top P Sampling":
        generation_args["temperature"] = temperature
        generation_args["do_sample"] = True
        generation_args["top_p"] = top_p
    generation_args.update(inputs)

    # Generate output with a streaming approach: model.generate runs in a background
    # thread while TextIteratorStreamer exposes the decoded tokens as an iterator.
    # Attaching the streamer here (rather than rebuilding generation_args) keeps the
    # sampling and repetition-penalty settings chosen above.
    streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
    generation_args["streamer"] = streamer
    thread = Thread(target=model.generate, kwargs=generation_args)
    thread.start()
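    # Stream partial output back to the chat UI: each chunk from the streamer is
    # appended to the buffer and re-yielded, so the reply grows token by token.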
yield "..."
buffer = ""
for new_text in streamer:
buffer += new_text
time.sleep(0.01)
yield buffer

# Define the ChatInterface without examples.
demo = gr.ChatInterface(
    fn=model_inference,
    description="# **SmolVLM Video Infer**",
    textbox=gr.MultimodalTextbox(label="Query Input", file_types=["image"], file_count="multiple"),
    stop_btn="Stop Generation",
    multimodal=True,
    additional_inputs=[
        gr.Radio(
            ["Top P Sampling", "Greedy"],
            value="Greedy",
            label="Decoding strategy",
            info="Greedy picks the most likely token at each step; Top P Sampling draws from the smallest set of tokens whose cumulative probability exceeds Top P.",
        ),
        gr.Slider(
            minimum=0.0,
            maximum=5.0,
            value=0.4,
            step=0.1,
            interactive=True,
            label="Sampling temperature",
            info="Higher values will produce more diverse outputs.",
        ),
        gr.Slider(
            minimum=8,
            maximum=1024,
            value=512,
            step=1,
            interactive=True,
            label="Maximum number of new tokens to generate",
        ),
        gr.Slider(
            minimum=0.01,
            maximum=5.0,
            value=1.2,
            step=0.01,
            interactive=True,
            label="Repetition penalty",
            info="1.0 is equivalent to no penalty.",
        ),
        gr.Slider(
            minimum=0.01,
            maximum=0.99,
            value=0.8,
            step=0.01,
            interactive=True,
            label="Top P",
            info="Higher values allow sampling more low-probability tokens.",
        ),
    ],
    cache_examples=False,
)
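
# gr.ChatInterface calls model_inference with the multimodal message and the chat
# history first, followed by each additional_inputs value in the order listed above,
# which matches the function signature.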
demo.launch(debug=True)