File size: 4,173 Bytes
1e83452
 
11c1c09
 
 
 
 
c3e97a0
4014a39
fb37d1d
c07895d
dcf8e95
 
0407a7f
dcf8e95
1e83452
 
 
 
11c1c09
903237b
 
 
 
 
 
2303f25
1e83452
013667b
11c1c09
0dbf527
 
 
 
 
0407a7f
 
7c77dd1
1b792af
11c1c09
 
22a8481
87ee71d
 
 
 
 
 
 
11c1c09
1e83452
87ee71d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1e83452
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
import gradio as gr
from huggingface_hub import InferenceClient
import spaces
import torch
import os
from huggingface_hub import login
from PIL import Image
from threading import Thread
import platform
from transformers import AutoProcessor, Gemma3ForConditionalGeneration, TextIteratorStreamer
import time
from transformers import pipeline




# Multimodal chat demo for google/gemma-3-4b-it served through Gradio.
# Startup: log the runtime environment, authenticate with the Hugging Face Hub,
# then load the model and processor once so every request can reuse them.

# Environment report, useful when debugging Space builds.
print(f"Is CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    # get_device_name raises when no CUDA device exists, so only query it when one does.
    print(f"CUDA device: {torch.cuda.get_device_name(torch.cuda.current_device())}")
print(f"CUDA version: {torch.version.cuda}")
print(f"Python version: {platform.python_version()}")
print(f"Pytorch version: {torch.__version__}")
print(f"Gradio version: {gr.__version__}")

# GPU time (seconds) requested per call by the @spaces.GPU decorator on bot_streaming.
duration = 10

# The token is read from the `gemma` environment variable (e.g. a Space secret).
# login(token=None) would fall back to an interactive prompt, which cannot be
# answered in a headless Space, so only log in when a token is actually set.
hf_token = os.getenv('gemma')
if hf_token:
    login(token=hf_token)

ckpt = "google/gemma-3-4b-it"
# Weights are loaded in bfloat16 and moved to the GPU once at import time.
model = Gemma3ForConditionalGeneration.from_pretrained(ckpt, torch_dtype=torch.bfloat16).to("cuda")
processor = AutoProcessor.from_pretrained(ckpt)

def _load_rgb_image(file):
    """Open an uploaded file and return it as an in-memory RGB ``PIL.Image``.

    ``file`` is either a path string (the examples / history case) or a dict
    carrying the path under ``"path"`` (regular MultimodalTextbox uploads).
    """
    path = file if isinstance(file, str) else file["path"]
    # The context manager closes the file handle; convert() returns a fully
    # loaded copy, so the result stays valid after the file is closed.
    with Image.open(path) as img:
        return img.convert("RGB")


def _history_to_messages(history):
    """Turn tuple-format chat history into chat-template messages plus images.

    Each history entry is a ``[user, assistant]`` pair. An uploaded file is its
    own entry whose user part is a tuple ``(path, ...)``. The text typed
    alongside it, together with the model reply, follows in the next text
    entry. NOTE(review): when an image is sent without text, Gradio presumably
    attaches the reply to the file entry itself; that case is handled with
    empty user text. Confirm if the history format ever changes (e.g.
    ``type="messages"``).

    Returns ``(messages, images)``. ``images`` is in the same order as the
    ``{"type": "image"}`` placeholders in ``messages``.
    """
    messages = []
    images = []
    pending = 0  # images uploaded since the last user text turn
    for user_turn, bot_turn in history:
        if isinstance(user_turn, tuple):
            images.append(_load_rgb_image(user_turn[0]))
            pending += 1
            if bot_turn is None:
                # The text (and reply) for this upload arrive in a later entry.
                continue
            text = ""  # image sent without text: reply sits on the file entry
        else:
            text = user_turn
        content = [{"type": "text", "text": text}]
        content.extend({"type": "image"} for _ in range(pending))
        pending = 0
        messages.append({"role": "user", "content": content})
        # Always emit an assistant turn so user/assistant roles keep alternating.
        reply = bot_turn if bot_turn is not None else ""
        messages.append({"role": "assistant", "content": [{"type": "text", "text": reply}]})
    if pending:
        # Uploads never followed by a text turn have no placeholder in the
        # prompt; drop them so image count and placeholders stay in sync.
        del images[-pending:]
    return messages, images


@spaces.GPU(duration=duration)
def bot_streaming(message, history, max_new_tokens=250):
    """Stream a Gemma 3 reply to a (possibly image-bearing) chat message.

    Args:
        message: MultimodalTextbox value, a dict with ``"text"`` and ``"files"``.
            Each file is a path string (examples) or a dict with ``"path"``.
        history: previous turns in Gradio's tuple format (see
            ``_history_to_messages``).
        max_new_tokens: generation budget; coerced to ``int`` because it comes
            from a UI slider.

    Yields:
        The reply text accumulated so far, growing as tokens are generated.
    """
    messages, images = _history_to_messages(history)

    # Current turn: the text first, then one placeholder per uploaded image
    # (every uploaded file is used, not only a single one).
    files = message.get("files") or []
    images.extend(_load_rgb_image(f) for f in files)
    content = [{"type": "text", "text": message["text"]}]
    content.extend({"type": "image"} for _ in files)
    messages.append({"role": "user", "content": content})

    prompt = processor.apply_chat_template(messages, add_generation_prompt=True)
    if images:
        inputs = processor(text=prompt, images=images, return_tensors="pt").to("cuda")
    else:
        inputs = processor(text=prompt, return_tensors="pt").to("cuda")

    streamer = TextIteratorStreamer(processor, skip_special_tokens=True, skip_prompt=True)
    generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=int(max_new_tokens))

    # generate() blocks, so run it on a worker thread and consume tokens from
    # the streamer here, yielding partial output to the UI.
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()

    buffer = ""
    for new_text in streamer:
        buffer += new_text
        time.sleep(0.01)
        yield buffer
    thread.join()


# UI wiring: a multimodal textbox (text plus image uploads) feeds bot_streaming,
# and the slider below is forwarded as bot_streaming's third argument
# (max_new_tokens).
chat_textbox = gr.MultimodalTextbox()
max_tokens_slider = gr.Slider(
    minimum=10,
    maximum=500,
    value=250,
    step=10,
    label="Maximum number of new tokens to generate",
)

demo = gr.ChatInterface(
    fn=bot_streaming,
    title="Multimodal Gemma 3 Model by Google",
    textbox=chat_textbox,
    additional_inputs=[max_tokens_slider],
    cache_examples=False,
    description="Upload an image, and start chatting about it, or just enter any text into the prompt to start.",
    stop_btn="Stop Generation",
    fill_height=True,
    multimodal=True,
)

demo.launch(debug=True)