import gradio as gr
import cv2
import torch
from PIL import Image
from pathlib import Path
from threading import Thread
from transformers import AutoProcessor, Gemma3ForConditionalGeneration, TextIteratorStreamer
import spaces

# Model configuration: load both Gemma 3 checkpoints up front so the user can
# switch between them at inference time without reloading.
model_12b_name = "google/gemma-3-12b-it"
model_4b_name = "google/gemma-3-4b-it"

model_12b = Gemma3ForConditionalGeneration.from_pretrained(
    model_12b_name, device_map="auto", torch_dtype=torch.bfloat16
).eval()
processor_12b = AutoProcessor.from_pretrained(model_12b_name)

model_4b = Gemma3ForConditionalGeneration.from_pretrained(
    model_4b_name, device_map="auto", torch_dtype=torch.bfloat16
).eval()
processor_4b = AutoProcessor.from_pretrained(model_4b_name)


# TODO: attach per-frame timestamps to the prompt.
def extract_video_frames(video_path, num_frames=8):
    """Sample `num_frames` evenly spaced frames from a video as PIL images."""
    cap = cv2.VideoCapture(video_path)
    frames = []
    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    step = max(total_frames // num_frames, 1)
    for i in range(num_frames):
        cap.set(cv2.CAP_PROP_POS_FRAMES, i * step)
        ret, frame = cap.read()
        if ret:
            # OpenCV returns BGR; PIL expects RGB.
            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            frames.append(Image.fromarray(frame))
    cap.release()
    return frames


def format_message(content, files):
    """Build the multimodal content list for a single user turn."""
    message_content = []
    if content:
        # NOTE: the inline placeholder token was lost in the source; "<image>"
        # is assumed here. Each occurrence in the text consumes the next
        # uploaded file, in order.
        parts = content.split("<image>")
        for i, part in enumerate(parts):
            if part.strip():
                message_content.append({"type": "text", "text": part.strip()})
            if i < len(parts) - 1 and files:
                next_file = files.pop(0)
                file_path = next_file if isinstance(next_file, str) else next_file.name
                message_content.append({"type": "image", "image": Image.open(file_path)})
    # Files not consumed by inline placeholders are appended at the end.
    for file in files:
        file_path = file if isinstance(file, str) else file.name
        suffix = Path(file_path).suffix.lower()
        if suffix in ['.jpg', '.jpeg', '.png']:
            message_content.append({"type": "image", "image": Image.open(file_path)})
        elif suffix in ['.mp4', '.mov']:
            for frame in extract_video_frames(file_path):
                message_content.append({"type": "image", "image": frame})
    return message_content


def format_conversation_history(chat_history):
    """Convert Gradio's flat message history into role-grouped multimodal messages."""
    messages = []
    current_user_content = []
    for item in chat_history:
        role = item["role"]
        content = item["content"]
        if role == "user":
            if isinstance(content, str):
                current_user_content.append({"type": "text", "text": content})
            elif isinstance(content, list):
                current_user_content.extend(content)
            else:
                current_user_content.append({"type": "text", "text": str(content)})
        elif role == "assistant":
            # Flush any buffered user content before the assistant turn.
            if current_user_content:
                messages.append({"role": "user", "content": current_user_content})
                current_user_content = []
            messages.append({"role": "assistant", "content": [{"type": "text", "text": str(content)}]})
    if current_user_content:
        messages.append({"role": "user", "content": current_user_content})
    return messages

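# Both helpers above emit entries in the multimodal chat format that
# `processor.apply_chat_template` expects. For illustration (values are
# hypothetical), a two-turn history looks like:
#
#   [
#       {"role": "user", "content": [
#           {"type": "text", "text": "Explain this image"},
#           {"type": "image", "image": <PIL.Image.Image>},
#       ]},
#       {"role": "assistant", "content": [{"type": "text", "text": "..."}]},
#   ]
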
@spaces.GPU(duration=120)
def generate_response(input_data, chat_history, model_choice, max_new_tokens,
                      system_prompt, temperature, top_p, top_k, repetition_penalty):
    if isinstance(input_data, dict) and "text" in input_data:
        text = input_data["text"]
        files = input_data.get("files", [])
    else:
        text = str(input_data)
        files = []

    new_message = {"role": "user", "content": format_message(text, files)}
    system_message = (
        [{"role": "system", "content": [{"type": "text", "text": system_prompt}]}]
        if system_prompt else []
    )
    messages = system_message + format_conversation_history(chat_history)
    # Merge into the trailing user turn if one exists; otherwise start a new one.
    if messages and messages[-1]["role"] == "user":
        messages[-1]["content"].extend(new_message["content"])
    else:
        messages.append(new_message)

    if model_choice == "Gemma 3 12B":
        model, processor = model_12b, processor_12b
    else:
        model, processor = model_4b, processor_4b

    inputs = processor.apply_chat_template(
        messages,
        add_generation_prompt=True,
        tokenize=True,
        return_tensors="pt",
        return_dict=True,
    ).to(model.device)

    streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
    generation_kwargs = dict(
        inputs,
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        temperature=temperature,
        top_p=top_p,
        top_k=top_k,
        repetition_penalty=repetition_penalty,
    )
    # Run generation on a background thread so tokens can be streamed as they arrive.
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()

    outputs = []
    for new_text in streamer:  # renamed from `text` to avoid shadowing the user input
        outputs.append(new_text)
        yield "".join(outputs)


demo = gr.ChatInterface(
    fn=generate_response,
    additional_inputs=[
        gr.Dropdown(
            label="Model",
            choices=["Gemma 3 12B", "Gemma 3 4B"],
            value="Gemma 3 12B",
        ),
        gr.Slider(label="Max new tokens", minimum=100, maximum=2000, step=1, value=512),
        gr.Textbox(
            label="System Prompt",
            value="You are a friendly chatbot.",
            lines=4,
            placeholder="Change the system prompt",
        ),
        gr.Slider(label="Temperature", minimum=0.1, maximum=2.0, step=0.1, value=0.7),
        gr.Slider(label="Top-p", minimum=0.05, maximum=1.0, step=0.05, value=0.9),
        gr.Slider(label="Top-k", minimum=1, maximum=100, step=1, value=50),
        gr.Slider(label="Repetition Penalty", minimum=1.0, maximum=2.0, step=0.05, value=1.0),
    ],
    examples=[
        [{"text": "Explain this image", "files": ["examples/image1.jpg"]}],
    ],
    cache_examples=False,
    type="messages",
    description="""
# Gemma 3
Pick the 12B or 4B model, upload images or videos, and adjust the settings below to customize your experience.
""",
    fill_height=True,
    textbox=gr.MultimodalTextbox(
        label="Query Input",
        file_types=["image", "video"],
        file_count="multiple",
        placeholder="Type your message or upload media",
    ),
    stop_btn="Stop Generation",
    multimodal=True,
    theme=gr.themes.Soft(),
)

if __name__ == "__main__":
    demo.launch()
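
# Rough local-run sketch (assumptions, not part of the Space config): outside a
# Hugging Face ZeroGPU Space, `@spaces.GPU` acts as a pass-through, so running
# `python app.py` locally assumes a CUDA GPU with enough memory for the chosen
# checkpoint, plus `gradio`, `transformers`, `accelerate`, and `opencv-python`
# installed.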