prithivMLmods committed (verified)
Commit cc33eaf · 1 Parent(s): 18bfaa0

Update app.py

Files changed (1):
  1. app.py +117 -142
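
The main functional change in this commit: the vLLM-based path (a blocking `llm.chat(...)` call whose output was re-chunked to simulate streaming) is replaced by plain Transformers generation streamed through a `TextIteratorStreamer` while `model.generate` runs on a background thread. Below is a minimal, self-contained sketch of that streaming pattern; the `gpt2` checkpoint and the hard-coded prompt are placeholders chosen only so the sketch runs anywhere, they are not part of the committed app.py.

    from threading import Thread
    from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

    # Placeholder model so the sketch runs on CPU; the committed app.py loads
    # mistralai/Mistral-Small-3.1-24B-Instruct-2503 instead.
    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    model = AutoModelForCausalLM.from_pretrained("gpt2")

    inputs = tokenizer("System: Answer briefly.\nUser: Hello!\nAssistant: ", return_tensors="pt")
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

    # generate() blocks, so it runs in a worker thread while the caller
    # iterates over the streamer and receives text as it is produced.
    thread = Thread(target=model.generate, kwargs={
        "input_ids": inputs["input_ids"],
        "max_new_tokens": 64,
        "do_sample": True,
        "streamer": streamer,
    })
    thread.start()

    buffer = ""
    for new_text in streamer:
        buffer += new_text
        print(new_text, end="", flush=True)
    thread.join()
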
app.py CHANGED
@@ -1,33 +1,31 @@
  import os
- # Disable chunked prefill and asynchronous output before importing vllm.
- os.environ["VLLM_ENABLE_CHUNKED_PREFILL"] = "False"
- os.environ["VLLM_ENABLE_ASYNC_OUTPUT"] = "False"
-
- import re
  import uuid
  import json
  import time
- import random
- import asyncio
- import cv2
- from datetime import datetime, timedelta
  from threading import Thread

- import torch
  import gradio as gr
- import spaces
  import numpy as np
  from PIL import Image
  from huggingface_hub import hf_hub_download
- from vllm import LLM
- from vllm.sampling_params import SamplingParams

  # -----------------------------------------------------------------------------
- # Helper functions
  # -----------------------------------------------------------------------------

  def progress_bar_html(label: str) -> str:
-     """Return an HTML snippet for a progress bar."""
      return f'''
      <div style="display: flex; align-items: center;">
          <span style="margin-right: 10px; font-size: 14px;">{label}</span>
@@ -43,186 +41,163 @@ def progress_bar_html(label: str) -> str:
      </style>
      '''

- def downsample_video(video_path: str, num_frames: int = 10):
      """
-     Downsample a video to extract a set number of evenly spaced frames.
-     Returns a list of tuples (PIL.Image, timestamp in seconds).
      """
      vidcap = cv2.VideoCapture(video_path)
      total_frames = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
      fps = vidcap.get(cv2.CAP_PROP_FPS)
      frames = []
-     if total_frames <= 0 or fps <= 0:
-         vidcap.release()
-         return frames
-     # Get evenly spaced frame indices.
-     frame_indices = np.linspace(0, total_frames - 1, num_frames, dtype=int)
-     for i in frame_indices:
-         vidcap.set(cv2.CAP_PROP_POS_FRAMES, i)
-         success, image = vidcap.read()
-         if success:
-             # Convert BGR to RGB and then to a PIL Image.
-             image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
-             pil_image = Image.fromarray(image)
-             timestamp = round(i / fps, 2)
-             frames.append((pil_image, timestamp))
      vidcap.release()
      return frames

- def load_system_prompt(repo_id: str, filename: str) -> str:
      """
-     Load the system prompt from the given Hugging Face Hub repo file,
-     and format it with the model name and current dates.
      """
-     file_path = hf_hub_download(repo_id=repo_id, filename=filename)
-     with open(file_path, "r") as file:
-         system_prompt = file.read()
-     today = datetime.today().strftime("%Y-%m-%d")
-     yesterday = (datetime.today() - timedelta(days=1)).strftime("%Y-%m-%d")
-     model_name = repo_id.split("/")[-1]
-     return system_prompt.format(name=model_name, today=today, yesterday=yesterday)

  # -----------------------------------------------------------------------------
- # Global Settings and Model Initialization
  # -----------------------------------------------------------------------------
-
- # Model details (adjust as needed)
  MODEL_ID = "mistralai/Mistral-Small-3.1-24B-Instruct-2503"
- # Load the system prompt from HF Hub (ensure SYSTEM_PROMPT.txt exists in the repo)
  SYSTEM_PROMPT = load_system_prompt(MODEL_ID, "SYSTEM_PROMPT.txt")
- # Alternatively, you can hardcode the system prompt:
- # SYSTEM_PROMPT = "You are a conversational agent that always answers straight to the point, and ends with an ASCII cat."
-
- # Set the device explicitly.
- device = "cuda" if torch.cuda.is_available() else "cpu"

- # Initialize the Mistral LLM via vllm.
- # The enforce_eager flag ensures synchronous (eager) output.
- llm = LLM(model=MODEL_ID, tokenizer_mode="mistral", device=device, enforce_eager=True)

  # -----------------------------------------------------------------------------
  # Main Generation Function
  # -----------------------------------------------------------------------------
- @spaces.GPU
  def generate(
      input_dict: dict,
      chat_history: list,
-     max_new_tokens: int = 512,
-     temperature: float = 0.15,
      top_p: float = 0.9,
      top_k: int = 50,
  ):
-     """
-     The main generation function for the Mistral chatbot.
-     It supports:
-       - Text-only inference.
-       - Image inference (attaches image file paths).
-       - Video inference (extracts and attaches sampled video frames).
-     """
-     text = input_dict["text"]
      files = input_dict.get("files", [])
-     # Prepare the conversation with a system prompt.
-     messages = [
-         {"role": "system", "content": SYSTEM_PROMPT}
-     ]

-     # Check if any file is provided.
      video_extensions = (".mp4", ".mov", ".avi", ".mkv", ".webm")
-     if files:
-         # If any file is a video, use video inference branch.
-         if any(str(f).lower().endswith(video_extensions) for f in files):
-             # Remove any @video-infer tag if present.
-             prompt_clean = re.sub(r"@video-infer", "", text, flags=re.IGNORECASE).strip().strip('"')
-             video_path = files[0]  # currently process the first video file
-             frames = downsample_video(video_path)
-             # Build a list that contains the prompt plus each frame information.
-             user_content = [{"type": "text", "text": prompt_clean}]
-             for frame in frames:
-                 image, timestamp = frame
-                 # Save the frame to a temporary file.
-                 image_path = f"video_frame_{uuid.uuid4().hex}.png"
-                 image.save(image_path)
-                 user_content.append({"type": "text", "text": f"Frame at {timestamp} seconds:"})
-                 user_content.append({"type": "image_path", "image_path": image_path})
-             messages.append({"role": "user", "content": user_content})
-         else:
-             # Assume provided files are images.
-             prompt_clean = re.sub(r"@mistral", "", text, flags=re.IGNORECASE).strip().strip('"')
-             user_content = [{"type": "text", "text": prompt_clean}]
-             for file in files:
-                 try:
-                     image = Image.open(file)
-                     image_path = f"image_{uuid.uuid4().hex}.png"
-                     image.save(image_path)
-                     user_content.append({"type": "image_path", "image_path": image_path})
-                 except Exception as e:
-                     user_content.append({"type": "text", "text": f"Could not open file {file}"})
-             messages.append({"role": "user", "content": user_content})
-     else:
-         # Text-only branch.
-         messages.append({"role": "user", "content": [{"type": "text", "text": text}]})

-     # Show a progress bar before generating.
-     yield progress_bar_html("Processing with Mistral")

-     # Set up sampling parameters.
-     sampling_params = SamplingParams(
-         max_tokens=max_new_tokens,
-         temperature=temperature,
-         top_p=top_p,
-         top_k=top_k
-     )
-     # Run the chat (synchronously) using vllm.
-     outputs = llm.chat(messages, sampling_params=sampling_params)
-     final_response = outputs[0].outputs[0].text
-
-     # Simulate streaming output by chunking the result.
      buffer = ""
-     chunk_size = 20  # number of characters per chunk
-     for i in range(0, len(final_response), chunk_size):
-         buffer = final_response[: i + chunk_size]
          yield buffer
-         time.sleep(0.05)
-     return

  # -----------------------------------------------------------------------------
- # Gradio Interface Setup
  # -----------------------------------------------------------------------------
-
  demo = gr.ChatInterface(
      fn=generate,
      additional_inputs=[
-         gr.Slider(label="Max new tokens", minimum=1, maximum=1024, step=1, value=512),
-         gr.Slider(label="Temperature", minimum=0.05, maximum=2.0, step=0.05, value=0.15),
-         gr.Slider(label="Top-p", minimum=0.05, maximum=1.0, step=0.05, value=0.9),
          gr.Slider(label="Top-k", minimum=1, maximum=1000, step=1, value=50),
      ],
      examples=[
-         # Example with text only.
-         ["Explain the significance of today in the context of current events."],
-         # Example with image files (ensure you have valid image paths).
-         [{
-             "text": "Describe what you see in the image.",
-             "files": ["examples/3.jpg"]
-         }],
-         # Example with video file (ensure you have a valid video file).
-         [{
-             "text": "@video-infer Summarize the events shown in the video.",
-             "files": ["examples/sample_video.mp4"]
-         }],
      ],
      cache_examples=False,
      type="messages",
-     description="# **Mistral Multimodal Chatbot** \nSupports text, image (by reference) and video inference. Use @video-infer in your query when providing a video.",
      fill_height=True,
      textbox=gr.MultimodalTextbox(
          label="Query Input",
          file_types=["image", "video"],
          file_count="multiple",
-         placeholder="Enter your query here. Tag with @video-infer if using a video file."
      ),
      stop_btn="Stop Generation",
-     examples_per_page=3,
  )

  if __name__ == "__main__":
 
  import os
+ import random
  import uuid
  import json
  import time
+ import re
  from threading import Thread
+ from datetime import datetime, timedelta

  import gradio as gr
+ import torch
  import numpy as np
  from PIL import Image
+ import cv2
+ from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
  from huggingface_hub import hf_hub_download

  # -----------------------------------------------------------------------------
+ # Constants & Device Setup
  # -----------------------------------------------------------------------------
+ MAX_MAX_NEW_TOKENS = 2048
+ DEFAULT_MAX_NEW_TOKENS = 1024
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

+ # -----------------------------------------------------------------------------
+ # Helper Functions
+ # -----------------------------------------------------------------------------
  def progress_bar_html(label: str) -> str:
      return f'''
      <div style="display: flex; align-items: center;">
          <span style="margin-right: 10px; font-size: 14px;">{label}</span>
      </style>
      '''

+ def load_system_prompt(repo_id: str, filename: str) -> str:
+     """
+     Download and load a system prompt template from the given Hugging Face repo.
+     The template may include placeholders (e.g. {name}, {today}, {yesterday}) that get formatted.
      """
+     file_path = hf_hub_download(repo_id=repo_id, filename=filename)
+     with open(file_path, "r") as file:
+         system_prompt = file.read()
+     today = datetime.today().strftime("%Y-%m-%d")
+     yesterday = (datetime.today() - timedelta(days=1)).strftime("%Y-%m-%d")
+     model_name = repo_id.split("/")[-1]
+     return system_prompt.format(name=model_name, today=today, yesterday=yesterday)
+
+ def downsample_video(video_path: str):
+     """
+     Extracts 10 evenly spaced frames from the video.
+     Returns a list of tuples (PIL.Image, timestamp_in_seconds).
      """
      vidcap = cv2.VideoCapture(video_path)
      total_frames = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
      fps = vidcap.get(cv2.CAP_PROP_FPS)
      frames = []
+     if total_frames > 0 and fps > 0:
+         frame_indices = np.linspace(0, total_frames - 1, 10, dtype=int)
+         for i in frame_indices:
+             vidcap.set(cv2.CAP_PROP_POS_FRAMES, i)
+             success, image = vidcap.read()
+             if success:
+                 image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
+                 pil_image = Image.fromarray(image)
+                 timestamp = round(i / fps, 2)
+                 frames.append((pil_image, timestamp))
      vidcap.release()
      return frames

+ def build_prompt(chat_history, current_input_text, video_frames=None, image_files=None):
      """
+     Build a conversation prompt string.
+     The system prompt is added first, then previous chat history, and finally the current input.
+     If video_frames or image_files are provided, a note is added in the prompt.
      """
+     prompt = f"System: {SYSTEM_PROMPT}\n"
+     # Append chat history (if any)
+     for msg in chat_history:
+         role = msg.get("role", "").capitalize()
+         content = msg.get("content", "")
+         prompt += f"{role}: {content}\n"
+     prompt += f"User: {current_input_text}\n"
+     if video_frames:
+         for _, timestamp in video_frames:
+             prompt += f"[Video Frame at {timestamp} sec]\n"
+     if image_files:
+         for _ in image_files:
+             prompt += "[Image Input]\n"
+     prompt += "Assistant: "
+     return prompt

  # -----------------------------------------------------------------------------
+ # Load Mistral Model & System Prompt
  # -----------------------------------------------------------------------------
  MODEL_ID = "mistralai/Mistral-Small-3.1-24B-Instruct-2503"
  SYSTEM_PROMPT = load_system_prompt(MODEL_ID, "SYSTEM_PROMPT.txt")

+ tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
+ model = AutoModelForCausalLM.from_pretrained(
+     MODEL_ID,
+     torch_dtype=torch.float16,
+     device_map="auto"
+ ).to(device)
+ model.eval()

  # -----------------------------------------------------------------------------
  # Main Generation Function
  # -----------------------------------------------------------------------------
  def generate(
      input_dict: dict,
      chat_history: list,
+     max_new_tokens: int = DEFAULT_MAX_NEW_TOKENS,
+     temperature: float = 0.6,
      top_p: float = 0.9,
      top_k: int = 50,
+     repetition_penalty: float = 1.2,
  ):
+     text = input_dict.get("text", "")
      files = input_dict.get("files", [])

+     # Separate video files from images based on file extension.
      video_extensions = (".mp4", ".mov", ".avi", ".mkv", ".webm")
+     video_files = [f for f in files if str(f).lower().endswith(video_extensions)]
+     image_files = [f for f in files if not str(f).lower().endswith(video_extensions)]

+     video_frames = None
+     if video_files:
+         # Process the first video file.
+         video_path = video_files[0]
+         video_frames = downsample_video(video_path)
+
+     # Build the full prompt from the system prompt, chat history, current text, and file inputs.
+     prompt = build_prompt(chat_history, text, video_frames, image_files)
+
+     # Tokenize the prompt.
+     inputs = tokenizer(prompt, return_tensors="pt").to(device)
+
+     # Set up a streamer for incremental output.
+     streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True, timeout=20.0)
+
+     generation_kwargs = {
+         "input_ids": inputs["input_ids"],
+         "max_new_tokens": max_new_tokens,
+         "do_sample": True,
+         "temperature": temperature,
+         "top_p": top_p,
+         "top_k": top_k,
+         "repetition_penalty": repetition_penalty,
+         "streamer": streamer,
+     }
+
+     # Launch generation in a separate thread.
+     thread = Thread(target=model.generate, kwargs=generation_kwargs)
+     thread.start()

      buffer = ""
+     yield progress_bar_html("Processing with Mistral")
+     for new_text in streamer:
+         buffer += new_text
+         time.sleep(0.01)
          yield buffer

  # -----------------------------------------------------------------------------
+ # Gradio Interface
  # -----------------------------------------------------------------------------
  demo = gr.ChatInterface(
      fn=generate,
      additional_inputs=[
+         gr.Slider(label="Max new tokens", minimum=1, maximum=MAX_MAX_NEW_TOKENS, step=1, value=DEFAULT_MAX_NEW_TOKENS),
+         gr.Slider(label="Temperature", minimum=0.1, maximum=4.0, step=0.1, value=0.6),
+         gr.Slider(label="Top-p (nucleus sampling)", minimum=0.05, maximum=1.0, step=0.05, value=0.9),
          gr.Slider(label="Top-k", minimum=1, maximum=1000, step=1, value=50),
+         gr.Slider(label="Repetition penalty", minimum=1.0, maximum=2.0, step=0.05, value=1.2),
      ],
      examples=[
+         [{"text": "Describe the content of the video.", "files": ["examples/sample_video.mp4"]}],
+         [{"text": "Explain what is in this image.", "files": ["examples/sample_image.jpg"]}],
+         ["Tell me a fun fact about space."],
      ],
      cache_examples=False,
      type="messages",
+     description="# **Mistral Chatbot with Video Inference**\nA chatbot built with Mistral (via Transformers) that supports text, image, and video (frame extraction) inputs.",
      fill_height=True,
      textbox=gr.MultimodalTextbox(
          label="Query Input",
          file_types=["image", "video"],
          file_count="multiple",
+         placeholder="Type your message here. Optionally attach images or video."
      ),
      stop_btn="Stop Generation",
+     multimodal=True,
  )

  if __name__ == "__main__":