prithivMLmods committed (verified)
Commit a9be97c · 1 Parent(s): 0109e78

Update app.py

Files changed (1): app.py (+90 -16)
app.py CHANGED
@@ -1,16 +1,38 @@
 import gradio as gr
 from transformers import AutoProcessor, AutoModelForVision2Seq, TextIteratorStreamer
 from threading import Thread
-import re
 import time
 from PIL import Image
 import torch
 import spaces
+import cv2
+import numpy as np
 
+# Helper: Downsample video to extract a fixed number of frames.
+def downsample_video(video_path, num_frames=10):
+    cap = cv2.VideoCapture(video_path)
+    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
+    fps = cap.get(cv2.CAP_PROP_FPS)
+    # Calculate evenly spaced frame indices.
+    frame_indices = np.linspace(0, total_frames - 1, num_frames, dtype=int)
+    frames = []
+    for idx in frame_indices:
+        cap.set(cv2.CAP_PROP_POS_FRAMES, idx)
+        ret, frame = cap.read()
+        if ret:
+            # Convert BGR to RGB and then to a PIL image.
+            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
+            frame = Image.fromarray(frame)
+            frames.append(frame)
+    cap.release()
+    return frames
+
+# Load processor and model.
 processor = AutoProcessor.from_pretrained("HuggingFaceTB/SmolVLM-Instruct")
-model = AutoModelForVision2Seq.from_pretrained("HuggingFaceTB/SmolVLM-Instruct",
-                                               torch_dtype=torch.bfloat16,
-                                               ).to("cuda")
+model = AutoModelForVision2Seq.from_pretrained(
+    "HuggingFaceTB/SmolVLM-Instruct",
+    torch_dtype=torch.bfloat16,
+).to("cuda")
 
 @spaces.GPU
 def model_inference(
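The new downsample_video helper picks num_frames evenly spaced frame indices with np.linspace, seeks to each index via cv2.CAP_PROP_POS_FRAMES, and returns the successfully read frames as RGB PIL images. A minimal standalone check of the helper, assuming a local clip at a placeholder path such as clip.mp4, might look like:

    # "clip.mp4" is a placeholder path, not a file shipped with this Space.
    frames = downsample_video("clip.mp4", num_frames=10)
    print(len(frames))                      # up to 10; fewer if some frame reads fail
    print(frames[0].size, frames[0].mode)   # e.g. (1280, 720) 'RGB'
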
@@ -18,8 +40,62 @@ def model_inference(
     repetition_penalty, top_p
 ):
     text = input_dict["text"]
-    print(input_dict["files"])
 
+    # --- Video Inference Branch ---
+    if text.lower().startswith("@video-infer"):
+        # Remove the command prefix to get the prompt.
+        prompt_text = text[len("@video-infer"):].strip()
+        if not input_dict["files"]:
+            yield "Error: Please provide a video file for @video-infer."
+            return
+        # Assume the first file is a video.
+        video_file = input_dict["files"][0]
+        frames = downsample_video(video_file)
+        if not frames:
+            yield "Error: Could not extract frames from the video."
+            return
+        # Build the chat content: the user prompt followed by each labeled frame.
+        content = [{"type": "text", "text": prompt_text}]
+        for idx, frame in enumerate(frames):
+            content.append({"type": "text", "text": f"Frame {idx+1}:"})
+            content.append({"type": "image", "image": frame})
+        resulting_messages = [{
+            "role": "user",
+            "content": content
+        }]
+        prompt = processor.apply_chat_template(resulting_messages, add_generation_prompt=True)
+        # Process the extracted frames as images.
+        inputs = processor(text=prompt, images=[frames], return_tensors="pt")
+        inputs = {k: v.to("cuda") for k, v in inputs.items()}
+
+        # Setup generation parameters.
+        generation_args = {
+            "max_new_tokens": max_new_tokens,
+            "repetition_penalty": repetition_penalty,
+        }
+        assert decoding_strategy in ["Greedy", "Top P Sampling"]
+        if decoding_strategy == "Greedy":
+            generation_args["do_sample"] = False
+        elif decoding_strategy == "Top P Sampling":
+            generation_args["temperature"] = temperature
+            generation_args["do_sample"] = True
+            generation_args["top_p"] = top_p
+
+        generation_args.update(inputs)
+
+        streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
+        generation_args = dict(inputs, streamer=streamer, max_new_tokens=max_new_tokens)
+        buffer = ""
+        thread = Thread(target=model.generate, kwargs=generation_args)
+        thread.start()
+        yield "..."
+        for new_text in streamer:
+            buffer += new_text
+            time.sleep(0.01)
+            yield buffer
+        return
+
+    # --- Default Image Inference Branch ---
     # Process input images if provided.
     if len(input_dict["files"]) > 1:
         images = [Image.open(image).convert("RGB") for image in input_dict["files"]]
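The branch added above runs only when the message text begins with the @video-infer prefix; the first attached file is treated as the video, and each sampled frame is appended to the chat content under a "Frame N:" label. For illustration, the input_dict that the multimodal textbox hands to model_inference for such a request would be shaped roughly like this (the file path is hypothetical):

    # Hypothetical payload; the path is whatever temporary path Gradio assigns to the upload.
    example_input = {
        "text": "@video-infer Describe what happens in this clip.",
        "files": ["/tmp/gradio/clip.mp4"],
    }
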
@@ -28,13 +104,12 @@ def model_inference(
     else:
         images = []
 
-    # Validate input
+    # Validate input.
     if text == "" and not images:
         gr.Error("Please input a query and optionally image(s).")
     if text == "" and images:
         gr.Error("Please input a text query along with the image(s).")
 
-    # Prepare prompt using the chat template.
     resulting_messages = [{
         "role": "user",
         "content": [{"type": "image"} for _ in range(len(images))] + [
@@ -45,7 +120,6 @@ def model_inference(
     inputs = processor(text=prompt, images=[images], return_tensors="pt")
     inputs = {k: v.to("cuda") for k, v in inputs.items()}
 
-    # Setup generation parameters.
     generation_args = {
         "max_new_tokens": max_new_tokens,
         "repetition_penalty": repetition_penalty,
@@ -60,26 +134,26 @@
 
     generation_args.update(inputs)
 
-    # Generate output with a streaming approach.
     streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
     generation_args = dict(inputs, streamer=streamer, max_new_tokens=max_new_tokens)
-    generated_text = ""
-
+    buffer = ""
     thread = Thread(target=model.generate, kwargs=generation_args)
     thread.start()
-
     yield "..."
-    buffer = ""
     for new_text in streamer:
         buffer += new_text
         time.sleep(0.01)
         yield buffer
 
-# Define the ChatInterface without examples.
+# Gradio ChatInterface: Allow both image and video file types.
 demo = gr.ChatInterface(
     fn=model_inference,
     description="# **SmolVLM Video Infer**",
-    textbox=gr.MultimodalTextbox(label="Query Input", file_types=["image"], file_count="multiple"),
+    textbox=gr.MultimodalTextbox(
+        label="Query Input",
+        file_types=["image", "video"],
+        file_count="multiple"
+    ),
     stop_btn="Stop Generation",
     multimodal=True,
     additional_inputs=[
@@ -128,4 +202,4 @@ demo = gr.ChatInterface(
     cache_examples=False
 )
 
-demo.launch(debug=True)
+demo.launch(debug=True)
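Outside the Gradio UI, the same video path can be exercised without streaming. A minimal sketch, assuming processor, model, and downsample_video are defined exactly as in the new app.py and that sample.mp4 is a placeholder path (batch_decode is the standard transformers way to turn generated token ids back into text):

    # "sample.mp4" is a placeholder path; processor, model, and downsample_video come from app.py above.
    frames = downsample_video("sample.mp4", num_frames=10)
    content = [{"type": "text", "text": "Describe what happens in this clip."}]
    for idx, frame in enumerate(frames):
        content.append({"type": "text", "text": f"Frame {idx+1}:"})
        content.append({"type": "image", "image": frame})
    messages = [{"role": "user", "content": content}]
    prompt = processor.apply_chat_template(messages, add_generation_prompt=True)
    inputs = processor(text=prompt, images=[frames], return_tensors="pt")
    inputs = {k: v.to("cuda") for k, v in inputs.items()}
    # Greedy decoding with an illustrative token budget; app.py streams tokens with TextIteratorStreamer instead.
    generated_ids = model.generate(**inputs, max_new_tokens=512, do_sample=False)
    print(processor.batch_decode(generated_ids, skip_special_tokens=True)[0])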
 