prithivMLmods committed on
Commit 9a23baa · verified · 1 Parent(s): 7f39c2a

Update app.py

Files changed (1)
  1. app.py +102 -137
app.py CHANGED
@@ -1,138 +1,75 @@
 import gradio as gr
-from transformers import AutoProcessor, AutoModelForImageTextToText, TextIteratorStreamer
+from transformers import AutoProcessor, AutoModelForVision2Seq, TextIteratorStreamer
 from threading import Thread
 import re
 import time
+from PIL import Image
 import torch
 import spaces
-import subprocess
-import uuid
-import cv2
-import numpy as np
-from PIL import Image
-from io import BytesIO
-
-# Install flash-attn
-subprocess.run(
-    'pip install flash-attn --no-build-isolation',
-    env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"},
-    shell=True
-)
-
-# Load processor and model.
-processor = AutoProcessor.from_pretrained("HuggingFaceTB/SmolVLM2-2.2B-Instruct")
-model = AutoModelForImageTextToText.from_pretrained(
-    "HuggingFaceTB/SmolVLM2-2.2B-Instruct",
-    _attn_implementation="flash_attention_2",
-    torch_dtype=torch.bfloat16
-).to("cuda:0")
 
-def downsample_video(video_path):
-    """
-    Extracts 10 evenly spaced frames from the video at video_path.
-    Each frame is converted from BGR to RGB and returned as a PIL Image.
-    """
-    vidcap = cv2.VideoCapture(video_path)
-    total_frames = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
-    fps = vidcap.get(cv2.CAP_PROP_FPS)
-    frames = []
-    if total_frames <= 0 or fps <= 0:
-        vidcap.release()
-        return frames
-    frame_indices = np.linspace(0, total_frames - 1, 10, dtype=int)
-    for i in frame_indices:
-        vidcap.set(cv2.CAP_PROP_POS_FRAMES, i)
-        success, frame = vidcap.read()
-        if success:
-            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
-            pil_image = Image.fromarray(frame)
-            frames.append((pil_image, round(i / fps, 2)))
-    vidcap.release()
-    return frames
+# Load processor and model
+processor = AutoProcessor.from_pretrained("HuggingFaceTB/SmolVLM-Instruct")
+model = AutoModelForVision2Seq.from_pretrained(
+    "HuggingFaceTB/SmolVLM-Instruct",
+    torch_dtype=torch.bfloat16,
+).to("cuda")
 
 @spaces.GPU
-def model_inference(input_dict, history, max_tokens):
+def model_inference(
+    input_dict, history, decoding_strategy, temperature, max_new_tokens,
+    repetition_penalty, top_p
+):
     text = input_dict["text"]
-    user_content = []
-    media_queue = []
-
-    # Process input files.
-    for file in input_dict.get("files", []):
-        if file.endswith((".png", ".jpg", ".jpeg", ".gif", ".bmp")):
-            media_queue.append({"type": "image", "path": file})
-        elif file.endswith((".mp4", ".mov", ".avi", ".mkv", ".flv")):
-            # Extract frames from video using OpenCV.
-            frames = downsample_video(file)
-            for frame, timestamp in frames:
-                temp_file = f"video_frame_{uuid.uuid4().hex}.png"
-                frame.save(temp_file)
-                media_queue.append({"type": "image", "path": temp_file})
+    print(input_dict["files"])
 
-    # Build the conversation messages.
-    if not history:
-        text = text.strip()
-        # Use only the "<image>" token for inserting images.
-        if "<image>" in text:
-            parts = re.split(r'(<image>)', text)
-            for part in parts:
-                if part == "<image>" and media_queue:
-                    user_content.append(media_queue.pop(0))
-                elif part.strip():
-                    user_content.append({"type": "text", "text": part.strip()})
-        else:
-            user_content.append({"type": "text", "text": text})
-            for media in media_queue:
-                user_content.append(media)
-        resulting_messages = [{"role": "user", "content": user_content}]
+    # Process input images if provided.
+    if len(input_dict["files"]) > 1:
+        images = [Image.open(image).convert("RGB") for image in input_dict["files"]]
+    elif len(input_dict["files"]) == 1:
+        images = [Image.open(input_dict["files"][0]).convert("RGB")]
     else:
-        resulting_messages = []
-        user_content = []
-        media_queue = []
-        # Process history: now only image files are expected.
-        for hist in history:
-            if hist["role"] == "user" and isinstance(hist["content"], tuple):
-                file_name = hist["content"][0]
-                if file_name.endswith((".png", ".jpg", ".jpeg", ".gif", ".bmp")):
-                    media_queue.append({"type": "image", "path": file_name})
-        for hist in history:
-            if hist["role"] == "user" and isinstance(hist["content"], str):
-                text = hist["content"]
-                parts = re.split(r'(<image>)', text)
-                for part in parts:
-                    if part == "<image>" and media_queue:
-                        user_content.append(media_queue.pop(0))
-                    elif part.strip():
-                        user_content.append({"type": "text", "text": part.strip()})
-            elif hist["role"] == "assistant":
-                resulting_messages.append({
-                    "role": "user",
-                    "content": user_content
-                })
-                resulting_messages.append({
-                    "role": "assistant",
-                    "content": [{"type": "text", "text": hist["content"]}]
-                })
-                user_content = []
+        images = []
 
-    if text == "":
+    # Validate input
+    if text == "" and not images:
         gr.Error("Please input a query and optionally image(s).")
+    if text == "" and images:
+        gr.Error("Please input a text query along with the image(s).")
 
-    print("resulting_messages", resulting_messages)
-    inputs = processor.apply_chat_template(
-        resulting_messages,
-        add_generation_prompt=True,
-        tokenize=True,
-        return_dict=True,
-        return_tensors="pt",
-    )
-    inputs = inputs.to(model.device)
+    # Prepare prompt using the chat template.
+    resulting_messages = [{
+        "role": "user",
+        "content": [{"type": "image"} for _ in range(len(images))] + [
+            {"type": "text", "text": text}
+        ]
+    }]
+    prompt = processor.apply_chat_template(resulting_messages, add_generation_prompt=True)
+    inputs = processor(text=prompt, images=[images], return_tensors="pt")
+    inputs = {k: v.to("cuda") for k, v in inputs.items()}
 
-    # Generate response with streaming.
+    # Setup generation parameters.
+    generation_args = {
+        "max_new_tokens": max_new_tokens,
+        "repetition_penalty": repetition_penalty,
+    }
+    assert decoding_strategy in ["Greedy", "Top P Sampling"]
+    if decoding_strategy == "Greedy":
+        generation_args["do_sample"] = False
+    elif decoding_strategy == "Top P Sampling":
+        generation_args["temperature"] = temperature
+        generation_args["do_sample"] = True
+        generation_args["top_p"] = top_p
+
+    generation_args.update(inputs)
+
+    # Generate output with a streaming approach.
     streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
-    generation_args = dict(inputs, streamer=streamer, max_new_tokens=max_tokens)
+    generation_args = dict(inputs, streamer=streamer, max_new_tokens=max_new_tokens)
+    generated_text = ""
+
     thread = Thread(target=model.generate, kwargs=generation_args)
     thread.start()
-
+
     yield "..."
     buffer = ""
     for new_text in streamer:
@@ -140,30 +77,58 @@ def model_inference(input_dict, history, max_tokens):
         time.sleep(0.01)
         yield buffer
 
-examples = [
-    [{"text": "Where do the severe droughts happen according to this diagram?", "files": ["example_images/examples_weather_events.png"]}],
-    [{"text": "What art era does this artpiece <image> belong to?", "files": ["example_images/rococo.jpg"]}],
-    [{"text": "Describe this image.", "files": ["example_images/mosque.jpg"]}],
-    [{"text": "When was this purchase made and how much did it cost?", "files": ["example_images/fiche.jpg"]}],
-    [{"text": "What is the date in this document?", "files": ["example_images/document.jpg"]}],
-    [{"text": "What is happening in the video?", "files": ["example_images/short.mp4"]}],
-]
-
+# Define the ChatInterface without examples.
 demo = gr.ChatInterface(
     fn=model_inference,
-    title="SmolVLM2: The Smollest Video Model Ever 📺",
-    description=(
-        "Play with [SmolVLM2-2.2B-Instruct](https://huggingface.co/HuggingFaceTB/SmolVLM2-2.2B-Instruct) in this demo. "
-        "To get started, upload an image and text or try one of the examples. "
-        "This demo doesn't use history for the chat, so every chat you start is a new conversation."
-    ),
-    examples=examples,
-    textbox=gr.MultimodalTextbox(label="Query Input", file_types=["image", "video"], file_count="multiple"),
+    title="SmolVLM: Small yet Mighty 💫",
+    description="Play with [HuggingFaceTB/SmolVLM-Instruct](https://huggingface.co/HuggingFaceTB/SmolVLM-Instruct) in this demo. To get started, upload an image and text.",
+    textbox=gr.MultimodalTextbox(label="Query Input", file_types=["image"], file_count="multiple"),
     stop_btn="Stop Generation",
     multimodal=True,
-    cache_examples=False,
-    additional_inputs=[gr.Slider(minimum=100, maximum=500, step=50, value=200, label="Max Tokens")],
-    type="messages"
+    additional_inputs=[
+        gr.Radio(
+            ["Top P Sampling", "Greedy"],
+            value="Greedy",
+            label="Decoding strategy",
+            info="Higher values is equivalent to sampling more low-probability tokens.",
+        ),
+        gr.Slider(
+            minimum=0.0,
+            maximum=5.0,
+            value=0.4,
+            step=0.1,
+            interactive=True,
+            label="Sampling temperature",
+            info="Higher values will produce more diverse outputs.",
+        ),
+        gr.Slider(
+            minimum=8,
+            maximum=1024,
+            value=512,
+            step=1,
+            interactive=True,
+            label="Maximum number of new tokens to generate",
+        ),
+        gr.Slider(
+            minimum=0.01,
+            maximum=5.0,
+            value=1.2,
+            step=0.01,
+            interactive=True,
+            label="Repetition penalty",
+            info="1.0 is equivalent to no penalty",
+        ),
+        gr.Slider(
+            minimum=0.01,
+            maximum=0.99,
+            value=0.8,
+            step=0.01,
+            interactive=True,
+            label="Top P",
+            info="Higher values is equivalent to sampling more low-probability tokens.",
+        )
+    ],
+    cache_examples=False
 )
 
-demo.launch(debug=True)
+demo.launch(debug=True)
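
For reference, below is a minimal standalone sketch of the streaming inference path the updated app.py relies on (SmolVLM-Instruct load, chat-template prompt with image placeholders, threaded generate feeding a TextIteratorStreamer), with the Gradio and @spaces.GPU wiring stripped out. It assumes a CUDA device is available and that the checkpoint can be downloaded from the Hub; the image path and the max_new_tokens value are placeholders, not part of the commit.

from threading import Thread

import torch
from PIL import Image
from transformers import AutoModelForVision2Seq, AutoProcessor, TextIteratorStreamer

# Same checkpoint and dtype as the committed app.py.
processor = AutoProcessor.from_pretrained("HuggingFaceTB/SmolVLM-Instruct")
model = AutoModelForVision2Seq.from_pretrained(
    "HuggingFaceTB/SmolVLM-Instruct",
    torch_dtype=torch.bfloat16,
).to("cuda")

# Placeholder input: any local RGB image plus a text query.
images = [Image.open("example.jpg").convert("RGB")]
messages = [{
    "role": "user",
    "content": [{"type": "image"} for _ in images] + [{"type": "text", "text": "Describe this image."}],
}]

# Build the prompt with the chat template, then tensorize text and images together.
prompt = processor.apply_chat_template(messages, add_generation_prompt=True)
inputs = processor(text=prompt, images=[images], return_tensors="pt")
inputs = {k: v.to("cuda") for k, v in inputs.items()}

# Run generation in a background thread and consume tokens as they arrive,
# mirroring the yield-based loop in model_inference (the streamer is passed the
# processor, exactly as in the committed code).
streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
thread = Thread(target=model.generate, kwargs=dict(inputs, streamer=streamer, max_new_tokens=256))
thread.start()

buffer = ""
for new_text in streamer:
    buffer += new_text
    print(new_text, end="", flush=True)
thread.join()

In the Space itself, max_new_tokens, the decoding strategy, and the other sampling controls are supplied by the additional_inputs widgets declared in the diff above.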