File size: 4,080 Bytes
e352103
36be50d
fe4fa5b
36be50d
e352103
 
 
 
36be50d
0a651e1
e352103
 
0b32b82
 
e352103
0a651e1
ad382c8
e352103
e4c787e
e352103
fe4fa5b
36be50d
 
 
 
fe4fa5b
36be50d
fe4fa5b
ab94263
 
36be50d
 
e352103
 
 
 
 
 
36be50d
e352103
 
0029ec4
 
 
36be50d
0029ec4
 
 
 
 
 
 
 
fe4fa5b
 
 
 
 
 
 
 
0029ec4
e352103
 
fe4fa5b
 
36be50d
 
 
 
6d64276
 
36be50d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fe4fa5b
 
36be50d
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
import gradio as gr
from transformers import AutoProcessor, AutoModelForVision2Seq, TextIteratorStreamer
from transformers.image_utils import load_image
from threading import Thread
import re
import time
import torch
import spaces
#import subprocess
#subprocess.run('pip install flash-attn --no-build-isolation', env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}, shell=True)


# Hub checkpoint shared by the processor and the model so the two never drift apart.
MODEL_ID = "HuggingFaceTB/SmolVLM-Instruct-250M"

# Builds chat prompts and preprocesses text/images into model inputs.
processor = AutoProcessor.from_pretrained(MODEL_ID)

# Loaded once at import time in bfloat16 and placed on the GPU.
# The flash_attention_2 implementation is intentionally left disabled.
model = AutoModelForVision2Seq.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.bfloat16,
    # _attn_implementation="flash_attention_2",
).to("cuda")

@spaces.GPU
def model_inference(
    input_dict, history
):
    """Stream a SmolVLM answer for one multimodal chat turn.

    Args:
        input_dict: Value of the ``gr.MultimodalTextbox``: a dict with a
            ``"text"`` query string and a ``"files"`` list of uploaded
            image references, each of which is passed to ``load_image``.
        history: Previous chat turns supplied by ``gr.ChatInterface``.
            Unused: every call is treated as a single-turn conversation.

    Yields:
        ``"..."`` as a placeholder once generation has started, then the
        cumulative decoded answer after each streamed chunk.

    Raises:
        gr.Error: If the text query is empty (with or without images).
    """
    text = input_dict["text"]
    files = input_dict["files"]
    print(files)
    images = [load_image(image) for image in files]

    # gr.Error only surfaces in the UI when it is *raised*; merely
    # constructing it let empty queries fall through to generation.
    if not text:
        if images:
            raise gr.Error("Please input a text query along the image(s).")
        raise gr.Error("Please input a query and optionally image(s).")

    # One {"type": "image"} placeholder per uploaded image, then the text.
    resulting_messages = [
        {
            "role": "user",
            "content": [{"type": "image"} for _ in images]
            + [{"type": "text", "text": text}],
        }
    ]
    prompt = processor.apply_chat_template(resulting_messages, add_generation_prompt=True)

    # The processor takes one list of images per prompt. For a text-only
    # query pass None rather than an empty nested list.
    inputs = processor(
        text=prompt,
        images=[images] if images else None,
        return_tensors="pt",
    )
    # After this line `inputs` is a plain dict, so it is only ever accessed
    # by key / unpacked (attribute access like inputs.input_ids would fail).
    inputs = {k: v.to("cuda") for k, v in inputs.items()}

    # model.generate runs in a background thread and pushes decoded text
    # into the streamer, which this generator drains below.
    streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
    generation_args = dict(inputs, streamer=streamer, max_new_tokens=500)
    thread = Thread(target=model.generate, kwargs=generation_args)
    thread.start()

    yield "..."
    buffer = ""
    for new_text in streamer:
        buffer += new_text
        # Small pause so the UI renders incremental updates smoothly.
        time.sleep(0.01)
        yield buffer

    # The streamer is exhausted once generate() finishes; reap the thread.
    thread.join()


# Each example row holds a single MultimodalTextbox value. The interface
# defines no additional inputs, so rows carry no extra columns.
examples = [
    [{"text": "What art era do these artpieces belong to?", "files": ["example_images/rococo.jpg", "example_images/rococo_1.jpg"]}],
    [{"text": "I'm planning a visit to this temple, give me travel tips.", "files": ["example_images/examples_wat_arun.jpg"]}],
    [{"text": "What is the due date and the invoice date?", "files": ["example_images/examples_invoice.png"]}],
    [{"text": "What is this UI about?", "files": ["example_images/s2w_example.png"]}],
    [{"text": "Where do the severe droughts happen according to this diagram?", "files": ["example_images/examples_weather_events.png"]}],
]

# Multimodal chat UI wired to the streaming generator `model_inference`.
demo = gr.ChatInterface(
    fn=model_inference,
    title="SmolVLM: Small yet Mighty 💫",
    description="Play with [HuggingFaceTB/SmolVLM-Instruct](https://huggingface.co/HuggingFaceTB/SmolVLM-Instruct) in this demo. To get started, upload an image and text or try one of the examples. This checkpoint works best with single turn conversations, so clear the conversation after a single turn.",
    examples=examples,
    textbox=gr.MultimodalTextbox(label="Query Input", file_types=["image"], file_count="multiple"),
    stop_btn="Stop Generation",
    multimodal=True,
    cache_examples=False,
)

demo.launch(debug=True)