jonaschua committed
Commit 87ee71d · verified · 1 Parent(s): 10c78b8

Update app.py

Files changed (1)
  1. app.py +74 -64
app.py CHANGED
@@ -18,72 +18,82 @@ duration=None
 login(token = os.getenv('gemma'))
 
 ckpt = "google/gemma-3-4b-it"
-model = Gemma3ForConditionalGeneration.from_pretrained(
-    ckpt, device_map="auto", torch_dtype=torch.bfloat16,
-)
+model = Gemma3ForConditionalGeneration.from_pretrained(ckpt, torch_dtype=torch.bfloat16).to("cuda")
 processor = AutoProcessor.from_pretrained(ckpt)
 
-# image = Image.open(requests.get(url, stream=True).raw)
-# prompt = "<start_of_image> in this image, there is"
-# model_inputs = processor(text=prompt, images=image, return_tensors="pt")
-# input_len = model_inputs["input_ids"].shape[-1]
-
-# with torch.inference_mode():
-#     generation = model.generate(**model_inputs, max_new_tokens=100, do_sample=False)
-#     generation = generation[0][input_len:]
-
-@spaces.GPU(duration=duration)
-def respond(message, history: list[tuple[str, str]], system_message, max_tokens, temperature, top_p):
-    # messages = [{"role": "system", "content": system_message}]
-
-    messages = [{
-        "role": "user",
-        "content": [
-            {"type": "image", "url": "https://huggingface.co/spaces/big-vision/paligemma-hf/resolve/main/examples/password.jpg"},
-            {"type": "text", "text": "What is the password?"}
-        ]}]
-
-    for val in history:
-        if val[0]:
-            messages.append({"role": "user", "content": val[0]})
-        if val[1]:
-            messages.append({"role": "assistant", "content": val[1]})
-
-    messages.append({"role": "user", "content": message})
-
-    response = ""
-
-    # for message in client.chat_completion(messages, max_tokens=max_tokens, stream=True, temperature=temperature, top_p=top_p):
-    #     token = message.choices[0].delta.content
-    #     response += token
-    #     yield response
-
-"""
-For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
-"""
-
-demo = gr.ChatInterface(
-    respond,
-    textbox=gr.MultimodalTextbox(),
-    multimodal=True,
-    stop_btn="Stop generation",
-    additional_inputs=[
-        gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
-        gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
-        gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
-        gr.Slider(
-            minimum=0.1,
-            maximum=1.0,
-            value=0.95,
-            step=0.05,
-            label="Top-p (nucleus sampling)",
-        ),
-    ],
-)
-
-if __name__ == "__main__":
-    demo.launch()
+@spaces.GPU
+def bot_streaming(message, history, max_new_tokens=250):
+    txt = message["text"]
+    ext_buffer = f"{txt}"
+
+    messages = []
+    images = []
+
+    for i, msg in enumerate(history):
+        if isinstance(msg[0], tuple):
+            messages.append({"role": "user", "content": [{"type": "text", "text": history[i+1][0]}, {"type": "image"}]})
+            messages.append({"role": "assistant", "content": [{"type": "text", "text": history[i+1][1]}]})
+            images.append(Image.open(msg[0][0]).convert("RGB"))
+        elif isinstance(history[i-1], tuple) and isinstance(msg[0], str):
+            # messages are already handled
+            pass
+        elif isinstance(history[i-1][0], str) and isinstance(msg[0], str):  # text-only turn
+            messages.append({"role": "user", "content": [{"type": "text", "text": msg[0]}]})
+            messages.append({"role": "assistant", "content": [{"type": "text", "text": msg[1]}]})
+
+    # add the current message
+    if len(message["files"]) == 1:
+        if isinstance(message["files"][0], str):  # examples
+            image = Image.open(message["files"][0]).convert("RGB")
+        else:  # regular input
+            image = Image.open(message["files"][0]["path"]).convert("RGB")
+        images.append(image)
+        messages.append({"role": "user", "content": [{"type": "text", "text": txt}, {"type": "image"}]})
+    else:
+        messages.append({"role": "user", "content": [{"type": "text", "text": txt}]})
+
+    texts = processor.apply_chat_template(messages, add_generation_prompt=True)
+
+    if images == []:
+        inputs = processor(text=texts, return_tensors="pt").to("cuda")
+    else:
+        inputs = processor(text=texts, images=images, return_tensors="pt").to("cuda")
+    streamer = TextIteratorStreamer(processor, skip_special_tokens=True, skip_prompt=True)
+
+    generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=max_new_tokens)
+    generated_text = ""
+
+    thread = Thread(target=model.generate, kwargs=generation_kwargs)
+    thread.start()
+    buffer = ""
+
+    for new_text in streamer:
+        buffer += new_text
+        generated_text_without_prompt = buffer
+        time.sleep(0.01)
+        yield buffer
+
+
+demo = gr.ChatInterface(
+    fn=bot_streaming,
+    title="Multimodal Gemma 3 Model by Google",
+    textbox=gr.MultimodalTextbox(),
+    additional_inputs=[
+        gr.Slider(
+            minimum=10,
+            maximum=500,
+            value=250,
+            step=10,
+            label="Maximum number of new tokens to generate",
+        ),
+    ],
+    cache_examples=False,
+    description="Upload an image and start chatting about it, or just enter any text into the prompt to start.",
+    stop_btn="Stop Generation",
+    fill_height=True,
+    multimodal=True,
+)
+
+demo.launch(debug=True)
99