File size: 2,861 Bytes
14e2513
 
216b96d
 
6cae924
 
 
 
 
 
14e2513
 
 
 
 
216b96d
f9fa47c
 
 
216b96d
5589f0c
 
 
 
 
 
6cae924
 
 
 
 
 
 
 
216b96d
 
6cae924
 
 
 
216b96d
 
 
6cae924
216b96d
 
 
 
 
b679c08
216b96d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b679c08
216b96d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
import os

from vllm import LLM, SamplingParams
import gradio as gr

from PIL import Image
from io import BytesIO
import base64
import requests

from huggingface_hub import login
import os

# Authenticate against the Hugging Face Hub; raises KeyError if HF_TOKEN is unset,
# which fails fast before the (slow) model download below.
login(os.environ["HF_TOKEN"])

# NOTE(review): `repo_id` is assigned but never used — `LLM(model=...)` below
# hard-codes a *different* repo ("mistralai/Pixtral-12B-2409"). Presumably one of
# the two should be removed or `model=repo_id` was intended; confirm which model
# is meant before changing, as the two repos are distinct.
repo_id = "mistral-community/pixtral-12b-240910" #Replace to the model you would like to use
# Generation settings shared by every request (see infer()).
sampling_params = SamplingParams(max_tokens=8192, temperature=0.7)
# Budget: up to 5 images per message, 4096 tokens each; their product caps
# max_num_batched_tokens so one multi-image prompt always fits in a batch.
max_tokens_per_img = 4096
max_img_per_msg = 5

# Load the model once at import time (slow; requires a GPU large enough for 12B).
# NOTE(review): `import os` appears twice in the import block above (L218/L229);
# harmless but one should be dropped.
llm = LLM(model="mistralai/Pixtral-12B-2409",
          tokenizer_mode="mistral",
          max_model_len=65536,
          max_num_batched_tokens=max_img_per_msg * max_tokens_per_img,
          limit_mm_per_prompt={"image": max_img_per_msg})  # Name or path of your model

def encode_image(image: Image.Image, image_format="PNG") -> str:
    """Serialize *image* in *image_format* and return it as a base64 string.

    The result is suitable for embedding in a ``data:image/...;base64,`` URL.
    """
    buffer = BytesIO()
    image.save(buffer, format=image_format)
    return base64.b64encode(buffer.getvalue()).decode("utf-8")


# @spaces.GPU #[uncomment to use ZeroGPU]
# @spaces.GPU #[uncomment to use ZeroGPU]
def infer(image_url, prompt, progress=gr.Progress(track_tqdm=True)):
    """Download an image, send it with *prompt* to Pixtral, and return the reply text.

    Args:
        image_url: HTTP(S) URL of the image to describe.
        prompt: User question/instruction about the image.
        progress: Gradio progress tracker (must be a default arg for Gradio to
            inject it; evaluated once at import time by design).

    Returns:
        The generated text of the first completion.

    Raises:
        requests.HTTPError: if the image URL returns a non-2xx status.
        requests.Timeout: if the download exceeds the timeout.
    """
    # Timeout prevents the worker from hanging forever on a dead URL;
    # raise_for_status fails fast instead of handing error-page bytes to PIL.
    response = requests.get(image_url, timeout=30)
    response.raise_for_status()
    image = Image.open(BytesIO(response.content))
    # NOTE(review): fixed (3844, 2408) resize ignores the source aspect ratio —
    # presumably chosen to bound the image-token count; confirm intended.
    image = image.resize((3844, 2408))
    # Re-encode as a base64 data URL so vLLM receives the image inline.
    new_image_url = f"data:image/png;base64,{encode_image(image, image_format='PNG')}"

    messages = [
        {
            "role": "user",
            "content": [{"type": "text", "text": prompt}, {"type": "image_url", "image_url": {"url": new_image_url}}]
        },
    ]

    outputs = llm.chat(messages, sampling_params=sampling_params)

    # One request in, one completion out: take the first candidate's text.
    return outputs[0].outputs[0].text


# One-click examples shown under the inputs (picked up by gr.Examples below).
example_images = ["https://picsum.photos/id/237/200/300"]
example_prompts = ["What do you see in this image?"]

# Center the UI column and cap its width.
css = """
#col-container {
    margin: 0 auto;
    max-width: 640px;
}
"""

# Declarative UI: components created inside the context managers are attached
# to `demo` in creation order.
with gr.Blocks(css=css) as demo:
    with gr.Column(elem_id="col-container"):
        # NOTE(review): f-string has no placeholders; a plain string would do.
        gr.Markdown(f"""
        # Mistral Pixtral 12B
        """)

        with gr.Row():
            # Free-text prompt for the model.
            prompt = gr.Text(
                label="Prompt",
                show_label=False,
                max_lines=2,
                placeholder="Enter your prompt",
                container=False,
            )

            # URL of the image to analyze (downloaded in infer()).
            image_url = gr.Text(
                label="Image URL",
                show_label=False,
                max_lines=1,
                placeholder="Enter your image URL",
                container=False,
            )

            run_button = gr.Button("Run", scale=0)

        # Model output lands here.
        result = gr.Textbox(
            show_label=False
        )

        # Two separate Examples widgets: clicking one fills only its own input.
        gr.Examples(
            examples=example_images,
            inputs=[image_url]
        )

        gr.Examples(
            examples=example_prompts,
            inputs=[prompt]
        )
    # Single handler wired to the button click and Enter in either text box.
    gr.on(
        triggers=[run_button.click, image_url.submit, prompt.submit],
        fn=infer,
        inputs=[image_url, prompt],
        outputs=[result]
    )

# queue() serializes requests so the single LLM instance is not called concurrently.
demo.queue().launch()