prithivMLmods committed
Commit 9dc7658 · verified · 1 Parent(s): 86a82e4

Update app.py

Files changed (1)
  1. app.py +122 -52
app.py CHANGED
@@ -6,10 +6,20 @@ import time
 import torch
 import spaces
 import subprocess
+import uuid
+import cv2
+import numpy as np
+from PIL import Image
+from io import BytesIO
 
-# Install flash-attn with no CUDA build
-subprocess.run('pip install flash-attn --no-build-isolation', env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}, shell=True)
+# Install flash-attn
+subprocess.run(
+    'pip install flash-attn --no-build-isolation',
+    env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"},
+    shell=True
+)
 
+# Load processor and model.
 processor = AutoProcessor.from_pretrained("HuggingFaceTB/SmolVLM2-2.2B-Instruct")
 model = AutoModelForImageTextToText.from_pretrained(
     "HuggingFaceTB/SmolVLM2-2.2B-Instruct",
@@ -17,66 +27,112 @@ model = AutoModelForImageTextToText.from_pretrained(
     torch_dtype=torch.bfloat16
 ).to("cuda:0")
 
+def downsample_video(video_path):
+    """
+    Extracts 10 evenly spaced frames from the video at video_path.
+    Each frame is converted from BGR to RGB and returned as a PIL Image.
+    """
+    vidcap = cv2.VideoCapture(video_path)
+    total_frames = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
+    fps = vidcap.get(cv2.CAP_PROP_FPS)
+    frames = []
+    if total_frames <= 0 or fps <= 0:
+        vidcap.release()
+        return frames
+    frame_indices = np.linspace(0, total_frames - 1, 10, dtype=int)
+    for i in frame_indices:
+        vidcap.set(cv2.CAP_PROP_POS_FRAMES, i)
+        success, frame = vidcap.read()
+        if success:
+            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
+            pil_image = Image.fromarray(frame)
+            frames.append((pil_image, round(i / fps, 2)))
+    vidcap.release()
+    return frames
+
 @spaces.GPU
-def model_inference(input_dict, history, max_tokens):
-    text = input_dict.get("text", "").strip()
-    media_queue = []
+def model_inference(input_dict, history, max_tokens):
+    text = input_dict["text"]
     user_content = []
-
-    # Process uploaded media files
+    media_queue = []
+
+    # Process input files.
     for file in input_dict.get("files", []):
         if file.endswith((".png", ".jpg", ".jpeg", ".gif", ".bmp")):
            media_queue.append({"type": "image", "path": file})
         elif file.endswith((".mp4", ".mov", ".avi", ".mkv", ".flv")):
-            media_queue.append({"type": "video", "path": file})
-
-    # Construct user content with placeholders
-    if "<image>" in text or "<video>" in text:
-        parts = re.split(r'(<image>|<video>)', text)
-        for part in parts:
-            if part == "<image>" and media_queue:
-                user_content.append(media_queue.pop(0))
-            elif part == "<video>" and media_queue:
-                user_content.append(media_queue.pop(0))
-            elif part.strip():
-                user_content.append({"type": "text", "text": part.strip()})
+            # Extract frames from video using OpenCV.
+            frames = downsample_video(file)
+            for frame, timestamp in frames:
+                temp_file = f"video_frame_{uuid.uuid4().hex}.png"
+                frame.save(temp_file)
+                media_queue.append({"type": "image", "path": temp_file})
+
+    # Build the conversation messages.
+    if not history:
+        text = text.strip()
+        # Use only the "<image>" token for inserting images.
+        if "<image>" in text:
+            parts = re.split(r'(<image>)', text)
+            for part in parts:
+                if part == "<image>" and media_queue:
+                    user_content.append(media_queue.pop(0))
+                elif part.strip():
+                    user_content.append({"type": "text", "text": part.strip()})
+        else:
+            user_content.append({"type": "text", "text": text})
+            for media in media_queue:
+                user_content.append(media)
+        resulting_messages = [{"role": "user", "content": user_content}]
     else:
-        user_content.append({"type": "text", "text": text})
-        user_content.extend(media_queue)
-
-    resulting_messages = [{"role": "user", "content": user_content}]
-
-    # Process history
-    if history:
+        resulting_messages = []
+        user_content = []
+        media_queue = []
+        # Process history: now only image files are expected.
         for hist in history:
-            if hist["role"] == "user":
-                if isinstance(hist["content"], tuple) and len(hist["content"]) > 0:
-                    file_name = hist["content"][0]
-                    if file_name.endswith((".png", ".jpg", ".jpeg")):
-                        media_queue.append({"type": "image", "path": file_name})
-                    elif file_name.endswith(".mp4"):
-                        media_queue.append({"type": "video", "path": file_name})
-
+            if hist["role"] == "user" and isinstance(hist["content"], tuple):
+                file_name = hist["content"][0]
+                if file_name.endswith((".png", ".jpg", ".jpeg", ".gif", ".bmp")):
+                    media_queue.append({"type": "image", "path": file_name})
+        for hist in history:
+            if hist["role"] == "user" and isinstance(hist["content"], str):
+                text = hist["content"]
+                parts = re.split(r'(<image>)', text)
+                for part in parts:
+                    if part == "<image>" and media_queue:
+                        user_content.append(media_queue.pop(0))
+                    elif part.strip():
+                        user_content.append({"type": "text", "text": part.strip()})
             elif hist["role"] == "assistant":
-                resulting_messages.append({"role": "assistant", "content": [{"type": "text", "text": hist["content"]}]})
-
-    if not text and not media_queue:
-        gr.Warning("Please provide text or an image/video.")
-
-    # Process inputs
+                resulting_messages.append({
+                    "role": "user",
+                    "content": user_content
+                })
+                resulting_messages.append({
+                    "role": "assistant",
+                    "content": [{"type": "text", "text": hist["content"]}]
+                })
+                user_content = []
+
+    if text == "":
+        gr.Error("Please input a query and optionally image(s).")
+
+    print("resulting_messages", resulting_messages)
     inputs = processor.apply_chat_template(
         resulting_messages,
         add_generation_prompt=True,
         tokenize=True,
         return_dict=True,
-        return_tensors="pt"
-    ).to(model.device, dtype=torch.bfloat16)  # Ensure dtype consistency
-
-    # Generate output
-    streamer = TextIteratorStreamer(processor.tokenizer, skip_prompt=True, skip_special_tokens=True)
-    thread = Thread(target=model.generate, kwargs={"input_ids": inputs["input_ids"], "streamer": streamer, "max_new_tokens": max_tokens})
+        return_tensors="pt",
+    )
+    inputs = inputs.to(model.device)
+
+    # Generate response with streaming.
+    streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
+    generation_args = dict(inputs, streamer=streamer, max_new_tokens=max_tokens)
+    thread = Thread(target=model.generate, kwargs=generation_args)
     thread.start()
-
+
     yield "..."
     buffer = ""
     for new_text in streamer:
@@ -84,16 +140,30 @@ def model_inference(input_dict, history, max_tokens):
         time.sleep(0.01)
         yield buffer
 
+examples = [
+    [{"text": "Where do the severe droughts happen according to this diagram?", "files": ["example_images/examples_weather_events.png"]}],
+    [{"text": "What art era does this artpiece <image> belong to?", "files": ["example_images/rococo.jpg"]}],
+    [{"text": "Describe this image.", "files": ["example_images/mosque.jpg"]}],
+    [{"text": "When was this purchase made and how much did it cost?", "files": ["example_images/fiche.jpg"]}],
+    [{"text": "What is the date in this document?", "files": ["example_images/document.jpg"]}],
+    [{"text": "What is happening in the video?", "files": ["example_images/short.mp4"]}],
+]
+
 demo = gr.ChatInterface(
     fn=model_inference,
-    title="SmolVLM2: The Smollest Video Model Ever 📺",
-    description="Play with [SmolVLM2-2.2B-Instruct](https://huggingface.co/HuggingFaceTB/SmolVLM2-2.2B-Instruct) in this demo. To get started, upload an image and text.",
-    textbox=gr.MultimodalTextbox(label="Query Input", file_types=["image", ".mp4"], file_count="multiple"),
-    stop_btn="Stop Generation",
+    title="SmolVLM2: The Smollest Video Model Ever 📺",
+    description=(
+        "Play with [SmolVLM2-2.2B-Instruct](https://huggingface.co/HuggingFaceTB/SmolVLM2-2.2B-Instruct) in this demo. "
+        "To get started, upload an image and text or try one of the examples. "
+        "This demo doesn't use history for the chat, so every chat you start is a new conversation."
+    ),
+    examples=examples,
+    textbox=gr.MultimodalTextbox(label="Query Input", file_types=["image", "video"], file_count="multiple"),
+    stop_btn="Stop Generation",
     multimodal=True,
     cache_examples=False,
     additional_inputs=[gr.Slider(minimum=100, maximum=500, step=50, value=200, label="Max Tokens")],
     type="messages"
 )
 
-demo.launch(debug=True, share=True)
+demo.launch(debug=True)
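
For reference, a minimal sketch (not part of this commit) of how the message structure assembled by model_inference is consumed by processor.apply_chat_template and model.generate. It assumes the processor and model objects defined in app.py are in scope, and the image path is one of the demo's example assets, used only for illustration.

# Sketch: standalone use of the same chat-template call as in model_inference above.
# Assumes `processor` and `model` from app.py are already loaded; the image path is
# taken from the examples list and is illustrative only.
import torch

messages = [{
    "role": "user",
    "content": [
        {"type": "text", "text": "Describe this image."},
        {"type": "image", "path": "example_images/mosque.jpg"},
    ],
}]

inputs = processor.apply_chat_template(
    messages,
    add_generation_prompt=True,
    tokenize=True,
    return_dict=True,
    return_tensors="pt",
).to(model.device, dtype=torch.bfloat16)  # cast floating-point inputs to the model's dtype

generated_ids = model.generate(**inputs, max_new_tokens=200)
print(processor.batch_decode(generated_ids, skip_special_tokens=True)[0])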