merterbak committed on
Commit
7125168
·
verified ·
1 Parent(s): cd9a7ab
Files changed (1) hide show
  1. app.py +179 -0
app.py ADDED
@@ -0,0 +1,179 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import cv2
3
+ import torch
4
+ from PIL import Image
5
+ from pathlib import Path
6
+ from threading import Thread
7
+ from transformers import AutoProcessor, Gemma3ForConditionalGeneration, TextIteratorStreamer
8
+ import spaces
9
+ import time
10
+
11
# Model configuration.
# Both Gemma 3 checkpoints (and their processors) are loaded at import time,
# 12B first and then 4B, in bfloat16 with automatic device placement.
model_12b_name = "google/gemma-3-12b-it"
model_4b_name = "google/gemma-3-4b-it"

model_12b = Gemma3ForConditionalGeneration.from_pretrained(
    model_12b_name, device_map="auto", torch_dtype=torch.bfloat16
)
model_12b.eval()  # inference only; eval() returns the same module
processor_12b = AutoProcessor.from_pretrained(model_12b_name)

model_4b = Gemma3ForConditionalGeneration.from_pretrained(
    model_4b_name, device_map="auto", torch_dtype=torch.bfloat16
)
model_4b.eval()  # inference only; eval() returns the same module
processor_4b = AutoProcessor.from_pretrained(model_4b_name)
26
# TODO: attach timestamps to the sampled frames.
def extract_video_frames(video_path, num_frames=8):
    """Sample up to ``num_frames`` evenly spaced frames from a video.

    Args:
        video_path: Location of the video file. It is converted with
            ``str()`` before being handed to OpenCV.
        num_frames: Maximum number of frames to sample. Values <= 0 yield
            an empty list without opening the video.

    Returns:
        A list of RGB ``PIL.Image`` objects in sampling order. The list is
        shorter than ``num_frames`` (possibly empty) when some reads fail.
    """
    if num_frames <= 0:
        # Avoids a division by zero below and a pointless capture open.
        return []

    cap = cv2.VideoCapture(str(video_path))
    frames = []
    try:
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        # A step of at least 1 keeps short videos from sampling frame 0 repeatedly.
        step = max(total_frames // num_frames, 1)

        for index in range(num_frames):
            cap.set(cv2.CAP_PROP_POS_FRAMES, index * step)
            ret, frame = cap.read()
            if ret:
                # OpenCV decodes to BGR; convert to RGB before wrapping in PIL.
                frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                frames.append(Image.fromarray(frame))
    finally:
        # Release the capture handle even if decoding or conversion raises.
        cap.release()
    return frames
41
+
42
def format_message(content, files):
    """Build the multimodal content list for one user turn.

    The text is split on the literal ``<image>`` marker. Each non-empty,
    stripped text piece becomes a ``{"type": "text"}`` entry, and each marker
    consumes the next pending file (opened with PIL) while files remain.
    Files left over afterwards are appended by suffix: ``.jpg``/``.jpeg``/
    ``.png`` as a single image, ``.mp4``/``.mov`` as the frames returned by
    ``extract_video_frames``; any other suffix is skipped.

    Args:
        content: The user's text; may be empty or ``None``.
        files: Iterable of file paths (``str``) or objects exposing ``.name``;
            may be empty or ``None``. The caller's list is never modified.

    Returns:
        A list of ``{"type": "text", ...}`` / ``{"type": "image", ...}`` dicts.
    """
    message_content = []
    # Work on a private copy: popping from the caller's list would silently
    # empty the uploaded-files payload (or a shared example list).
    pending_files = list(files) if files else []

    if content:
        parts = content.split('<image>')
        last_index = len(parts) - 1
        for i, part in enumerate(parts):
            stripped = part.strip()
            if stripped:
                message_content.append({"type": "text", "text": stripped})
            # A marker sits between parts[i] and parts[i + 1].
            if i < last_index and pending_files:
                img = Image.open(pending_files.pop(0))
                message_content.append({"type": "image", "image": img})

    for file in pending_files:
        file_path = file if isinstance(file, str) else file.name
        suffix = Path(file_path).suffix.lower()
        if suffix in ['.jpg', '.jpeg', '.png']:
            img = Image.open(file_path)
            message_content.append({"type": "image", "image": img})
        elif suffix in ['.mp4', '.mov']:
            # Videos are passed to the model as a sequence of still frames.
            for frame in extract_video_frames(file_path):
                message_content.append({"type": "image", "image": frame})
    return message_content
64
+
65
def format_conversation_history(chat_history):
    """Convert chat history into chat-template messages.

    Consecutive user entries are merged into a single user message; each
    assistant entry closes the pending user message and is emitted as one
    text part. Entries with any other role are ignored.

    Args:
        chat_history: Iterable of dicts with ``"role"`` and ``"content"``
            keys, or ``None``/empty for no history.

    Returns:
        A list of ``{"role": ..., "content": [parts]}`` dicts.
    """
    messages = []
    current_user_content = []
    # Treat a missing history the same as an empty one instead of crashing.
    for item in chat_history or ():
        role = item["role"]
        content = item["content"]
        if role == "user":
            if isinstance(content, str):
                current_user_content.append({"type": "text", "text": content})
            elif isinstance(content, list):
                # Assumed to already be a list of content parts.
                current_user_content.extend(content)
            else:
                # NOTE(review): any other content (e.g. a file reference from
                # an earlier turn) is stringified and sent as text - confirm
                # whether earlier media should be re-attached instead.
                current_user_content.append({"type": "text", "text": str(content)})
        elif role == "assistant":
            if current_user_content:
                messages.append({"role": "user", "content": current_user_content})
                current_user_content = []
            messages.append({"role": "assistant", "content": [{"type": "text", "text": str(content)}]})
    # Flush a trailing user turn that has no assistant reply yet.
    if current_user_content:
        messages.append({"role": "user", "content": current_user_content})
    return messages
86
+
87
@spaces.GPU
def generate_response(input_data, chat_history, model_choice, max_new_tokens, system_prompt, temperature, top_p, top_k, repetition_penalty):
    """Stream a Gemma 3 reply to the latest user message.

    Args:
        input_data: Either a dict with a ``"text"`` key and an optional
            ``"files"`` list, or any other value, which is stringified and
            treated as plain text.
        chat_history: Previous turns as role/content dicts.
        model_choice: ``"Gemma 3 12B"`` selects the 12B model; any other
            value selects the 4B model.
        max_new_tokens: Upper bound on generated tokens (cast to ``int``).
        system_prompt: Optional system prompt; skipped when empty.
        temperature: Sampling temperature.
        top_p: Nucleus sampling threshold.
        top_k: Top-k sampling cutoff (cast to ``int``).
        repetition_penalty: Repetition penalty passed to ``generate``.

    Yields:
        The accumulated response text after each streamed chunk.

    Raises:
        Exception: Whatever ``model.generate`` raised in the worker thread,
            re-raised once the stream has been drained.
    """
    if isinstance(input_data, dict) and "text" in input_data:
        text = input_data["text"]
        files = input_data.get("files", [])
    else:
        text = str(input_data)
        files = []

    new_message = {"role": "user", "content": format_message(text, files)}
    system_message = [{"role": "system", "content": [{"type": "text", "text": system_prompt}]}] if system_prompt else []
    messages = system_message + format_conversation_history(chat_history)
    # Fold the new content into a trailing user turn so the template never
    # sees two consecutive user messages.
    if messages and messages[-1]["role"] == "user":
        messages[-1]["content"].extend(new_message["content"])
    else:
        messages.append(new_message)

    if model_choice == "Gemma 3 12B":
        model, processor = model_12b, processor_12b
    else:
        model, processor = model_4b, processor_4b

    inputs = processor.apply_chat_template(
        messages,
        add_generation_prompt=True,
        tokenize=True,
        return_tensors="pt",
        return_dict=True
    ).to(model.device)
    streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
    generation_kwargs = dict(
        inputs,
        streamer=streamer,
        # Slider values may arrive as floats; these two must be integers.
        max_new_tokens=int(max_new_tokens),
        do_sample=True,
        temperature=temperature,
        top_p=top_p,
        top_k=int(top_k),
        repetition_penalty=repetition_penalty
    )

    errors = []

    def _run_generation():
        """Run generation; on failure, unblock the consumer loop."""
        try:
            model.generate(**generation_kwargs)
        except Exception as exc:
            # Without the stop signal the streamer loop below would wait
            # forever on a generation that has already died.
            errors.append(exc)
            streamer.end()

    thread = Thread(target=_run_generation)
    thread.start()

    outputs = []
    for chunk in streamer:
        outputs.append(chunk)
        yield "".join(outputs)

    # The stream has ended, so the worker is done or about to be.
    thread.join()
    if errors:
        raise errors[0]
136
+
137
# Chat UI. The additional inputs are passed to generate_response after the
# message and history, in this order: model choice, max new tokens, system
# prompt, temperature, top-p, top-k, repetition penalty.
demo = gr.ChatInterface(
    fn=generate_response,
    additional_inputs=[
        gr.Dropdown(
            label="Model",
            choices=["Gemma 3 12B", "Gemma 3 4B"],
            value="Gemma 3 12B"
        ),
        gr.Slider(label="Max new tokens", minimum=100, maximum=2000, step=1, value=512),
        gr.Textbox(
            label="System Prompt",
            value="You are a friendly chatbot. ",
            lines=4,
            placeholder="Change system prompt"
        ),
        gr.Slider(label="Temperature", minimum=0.1, maximum=2.0, step=0.1, value=0.7),
        gr.Slider(label="Top-p", minimum=0.05, maximum=1.0, step=0.05, value=0.9),
        gr.Slider(label="Top-k", minimum=1, maximum=100, step=1, value=50),
        gr.Slider(label="Repetition Penalty", minimum=1.0, maximum=2.0, step=0.05, value=1.0),
    ],
    examples=[
        [{"text": "Explain this image", "files": ["examples/image1.jpg"]}],
    ],
    cache_examples=False,
    type="messages",
    # Kept flush-left so Markdown renders a heading plus a paragraph rather
    # than an indented code block; "# " (with the space) is needed for a heading.
    description="""
# Gemma 3
You can pick your model 12B or 4B, upload images or videos, and adjust settings below to customize your experience.
""",
    fill_height=True,
    textbox=gr.MultimodalTextbox(
        label="Query Input",
        file_types=["image", "video"],
        file_count="multiple",
        placeholder="Type your message or upload media"
    ),
    stop_btn="Stop Generation",
    multimodal=True,
    theme=gr.themes.Soft(),
)

if __name__ == "__main__":
    demo.launch()